我的博客
GPT Assistant 初体验(附示例代码)

GPT Assistant 初体验(附示例代码)

背景

11 月 6 号,OpenAI 举办了首次开发者大会,堪称第一届 AI 春晚,其中 GPT Assistant 的发布备受关注,因为它基本上就是一个官方出品的 Agent 框架。基于 GPT Assistant,你可以给 GPT 接上 Code Interpreter 能力(模型自己可以写代码并且执行)、向量检索能力(可以简单理解为官方版的 ChatPDF)或者其他自定义函数。简单说,以后开发 Agent 可能不需要 Langchain 框架了。

OpenAI 官方文档写得非常简单,有很多地方直接省略。对于不熟悉的新手来说,不太容易跑起来。本文档的作用仅仅是填补几个官方文档中的 gap,希望能帮助读者节省一些时间。

如何运行

为了使用 GPT Assistant,你需要先完成以下 3 步工作:

  1. 安装/更新你的 openai python SDK,因为 GPT Assistant 仅在最新的 openai v1.1 中可用 pip3 install --upgrade openai
  2. 创建.env 文件,填入 OPENAI API KEY: OPENAI_API_KEY=你的API密钥
  3. 在.env 同路径下创建 main.py,并运行:、 如果你没有安装过 dotenv 的话,还需要手动安装下 python-dotenv: pip3 install python-dotenv 这样做可以让你的代码更安全,防止密钥泄露。

你的 main.py 的代码应该如下所示:

from openai import Client
from dotenv import load_dotenv
import time
 
# 将.env加载为环境变量,这样openai SDK可以自动获取环境变量中的API KEY
load_dotenv()
 
# 初始化一个client实例,官方文档没有写这部分,有的人可能会一头雾水
client = Client()
 
# 初始化一个assistant
assistant = client.beta.assistants.create(
    name="Helpful Assistant",
    instructions="You are a helpful assitant.",
    tools=[{"type": "code_interpreter"}, {"type": "retrieval"}],
    model="gpt-4-1106-preview",
)
 
"""
初始化一个线程,这里的线程指的是一次对话,
thread会通过剪切等手段帮你管理上下文,确保context长度不超过模型的限制
"""
thread = client.beta.threads.create()
 
# 向thread里放入一个消息
message = client.beta.threads.messages.create(
    thread_id=thread.id,
    role="user",
    content="请告诉我如何把一头大象放进冰箱里。",
)
 
# 将消息发送出去
run = client.beta.threads.runs.create(
    thread_id=thread.id,
    assistant_id=assistant.id,
    instructions="",
)
 
"""
官方文档中获取生成结果的部分写的非常简略,这里提供一个简单的示例。
用一个轮询来查询回复是否生成完毕,每0.5秒查询一次,若运行完毕则输出结果。
这个方案很不优雅,相信以后openai会提供官方的结果获取方式。
"""
while True:
 
    # run被创建完还要retrieve一下才会获取结果,每0.5秒查看一下,这个retrieve的作用主要是更新run的状态
    run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)
    if run.status == "completed":
        messages = client.beta.threads.messages.list(thread_id=thread.id)
 
        #输出最新消息
        print(messages.data[0].content[0].text.value)
        break
    else:
        time.sleep(0.5)
        print("waiting for reply...")

注意,以上是基于 openai python SDK 的调用方式,如果你使用 http 方式调用,那么你需要自己添加一个 header: OpenAI-Beta: assistants=v1 这样 openai 才会允许你调用 GPT Assistant。

向量检索

向量检索是近期的大热门方向,围绕着向量检索诞生了许多热门项目,例如 PrivateGPT, ChatDoc, ChatPDF 等等,GPT Assistant 的 retrieval 工具是一个官方版的向量检索工具,以后开发基于向量检索的问答机器人不需要再基于 langchain 了。

要快速体验 GPT Assitant 的向量检索工具,可以先在同路径下准备一份 knowledge.pdf 文件(或者其他任何支持的文件格式,openai 已经支持了十几种格式:https://platform.openai.com/docs/assistants/tools/supported-files) (opens in a new tab)

from openai import Client
import time
from dotenv import load_dotenv
 
def ask_question(client, thread_id, question):
    return client.beta.threads.messages.create(
        thread_id=thread_id,
        role="user",
        content=question,
    )
 
def wait_for_reply(client, thread_id, run):
    max_retries = 100
    for _ in range(max_retries):
        run = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run.id)
        if run.status == "completed":
            messages = client.beta.threads.messages.list(thread_id=thread_id)
            return messages.data[0].content[0].text.value
        else:
            print(f"waiting for reply. took {_* 0.5:04.2f} seconds")
            time.sleep(0.5)
    raise TimeoutError("No reply received within the expected time.")
 
def main():
    load_dotenv()
    client = Client()
    print("client created!")
    file = client.files.create(file=open("knowledge.pdf", "rb"), purpose="assistants")
    print("file uploaded!")
    assistant = client.beta.assistants.create(
        instructions="You are a customer support chatbot. Use your knowledge base to best respond to customer queries.",
        model="gpt-4-1106-preview",
        tools=[{"type": "retrieval"}],
        file_ids=[file.id],
    )
    print("assistant created!")
    thread = client.beta.threads.create()
    ask_question(client, thread.id, "帮我总结一下这份文件的内容,请用中文回答。")
    print("question asked!")
    run = client.beta.threads.runs.create(
        thread_id=thread.id,
        assistant_id=assistant.id,
        instructions="",
    )
    try:
        reply = wait_for_reply(client, thread.id, run)
        print(reply)
    except TimeoutError as e:
        print(str(e))
 
if __name__ == "__main__":
    main()