Langchain-Chatchat專案:5.1-ChatGLM3-6B工具呼叫

2023-11-07 06:00:39

  在語意、數學、推理、程式碼、知識等不同角度的資料集上測評顯示,ChatGLM3-6B-Base 具有在10B以下的基礎模型中最強的效能。ChatGLM3-6B採用了全新設計的Prompt格式,除正常的多輪對話外。同時原生支援工具呼叫(Function Call)、程式碼執行(Code Interpreter)和Agent任務等複雜場景。本文主要通過天氣查詢例子介紹了在tool_registry.py中註冊新的工具來增強模型能力。

  可以直接呼叫LangChain自帶的工具(比如,ArXiv),也可以呼叫自定義的工具。LangChain自帶的部分工具[2],如下所示:

一.自定義天氣查詢工具
1.Weather類
  可以參考Tool/Weather.py以及Tool/Weather.yaml檔案,繼承BaseTool類,過載_run()方法,如下所示:

class Weather(BaseTool):  # 天氣查詢工具
    name = "weather"
    description = "Use for searching weather at a specific location"

    def __init__(self):
        super().__init__()

    def get_weather(self, location):
        api_key = os.environ["SENIVERSE_KEY"]
        url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={location}&language=zh-Hans&unit=c"
        response = requests.get(url)
        if response.status_code == 200:
            data = response.json()
            weather = {
                "temperature": data["results"][0]["now"]["temperature"],
                "description": data["results"][0]["now"]["text"],
            }
            return weather
        else:
            raise Exception(
                f"Failed to retrieve weather: {response.status_code}")

    def _run(self, para: str) -> str:
        return self.get_weather(para)

2.weather.yaml檔案
  weather.yaml檔案內容,如下所示:

name: weather
description: Search the current weather of a city
parameters:
  type: object
  properties:
    city:
      type: string
      description: City name
  required:
    - city

二.自定義天氣查詢工具呼叫
  自定義天氣查詢工具呼叫,在main.py中匯入Weather工具。如下所示:

run_tool([Weather()], llm, [
    "今天北京天氣怎麼樣?",
    "What's the weather like in Shanghai today",
])

  其中,run_tool()函數實現如下所示:

def run_tool(tools, llm, prompt_chain: List[str]):
    loaded_tolls = []  # 用於儲存載入的工具
    for tool in tools:  # 逐個載入工具
        if isinstance(tool, str):
            loaded_tolls.append(load_tools([tool], llm=llm)[0])  # load_tools返回的是一個列表
        else:
            loaded_tolls.append(tool)  # 如果是自定義的工具,直接新增到列表中
    agent = initialize_agent(  # 初始化agent
        loaded_tolls, llm,
        agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,  # agent型別:使用結構化聊天的agent
        verbose=True,
        handle_parsing_errors=True
    )
    for prompt in prompt_chain:  # 逐個輸入prompt
        agent.run(prompt)

1.load_tools()函數
  根據工具名字載入相應的工具,如下所示:

def load_tools(
    tool_names: List[str],
    llm: Optional[BaseLanguageModel] = None,
    callbacks: Callbacks = None,
    **kwargs: Any,
) -> List[BaseTool]:

2.initialize_agent()函數
  根據工具列表和LLM載入一個agent executor,如下所示:

def initialize_agent(
    tools: Sequence[BaseTool],
    llm: BaseLanguageModel,
    agent: Optional[AgentType] = None,
    callback_manager: Optional[BaseCallbackManager] = None,
    agent_path: Optional[str] = None,
    agent_kwargs: Optional[dict] = None,
    *,
    tags: Optional[Sequence[str]] = None,
    **kwargs: Any,
) -> AgentExecutor:

  其中,agent預設為AgentType.ZERO_SHOT_REACT_DESCRIPTION。本文中使用為AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,一種為聊天模型優化的zero-shot react agent,該agent能夠呼叫具有多個輸入的工具。
3.run()函數
  執行鏈的便捷方法,這個方法與Chain.__call__之間的主要區別在於,這個方法期望將輸入直接作為位置引數或關鍵字引數傳遞,而Chain.__call__期望一個包含所有輸入的單一輸入字典。如下所示:

def run(
    self,
    *args: Any,
    callbacks: Callbacks = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> Any:

4.結果分析
  結果輸出,如下所示:

> Entering new AgentExecutor chain...
======
======

Action: 
``
{"action""weather""action_input""北京"}
``
Observation: {'temperature''20''description''晴'}
Thought:======
======

Action: 
``
{"action""Final Answer""action_input""根據查詢結果,北京今天的天氣是晴,氣溫為20℃。"}
``

> Finished chain.


> Entering new AgentExecutor chain...
======
======

Action: 
``
{"action""weather""action_input""Shanghai"}
``
Observation: {'temperature''20''description''晴'}
Thought:======
======

Action: 
``
{"action""Final Answer""action_input""根據最新的天氣資料,今天上海的天氣情況是晴朗的,氣溫為20℃。"}
``

> Finished chain.

  剛開始的時候沒有找到識別實體city的地方,後面偵錯ChatGLM3/langchain_demo/ChatGLM3.py->_call()時發現了一個巨長的prompt,這不就是zero-prompt(AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION)嗎?順便吐槽下LangChain的程式碼真的不好偵錯。


三.註冊工具增強LLM能力
1.註冊工具
  可以通過在tool_registry.py中註冊新的工具來增強模型的能力。只需要使用@register_tool裝飾函數即可完成註冊。對於工具宣告,函數名稱即為工具的名稱,函數docstring即為工具的說明;對於工具的引數,使用Annotated[typ: type, description: str, required: bool]標註引數的型別、描述和是否必須。將get_weather()函數進行註冊,如下所示:

@register_tool
def get_weather(  # 工具函數
        city_name: Annotated[str, 'The name of the city to be queried', True],
) -> str:
    """
    Get the current weather for `city_name`
    "
""

    if not isinstance(city_name, str):  # 引數型別檢查
        raise TypeError("City name must be a string")

    key_selection = {  # 選擇的鍵
        "current_condition": ["temp_C""FeelsLikeC""humidity""weatherDesc""observation_time"],
    }
    import requests
    try:
        resp = requests.get(f"https://wttr.in/{city_name}?format=j1")
        resp.raise_for_status()
        resp = resp.json()
        ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}
    except:
        import traceback
        ret = "Error encountered while fetching weather data!\n" + traceback.format_exc()

    return str(ret)

  具體工具註冊實現方式@register_tool裝飾函數,如下所示:

def register_tool(func: callable):  # 註冊工具
    tool_name = func.__name__  # 工具名
    tool_description = inspect.getdoc(func).strip()  # 工具描述
    python_params = inspect.signature(func).parameters  # 工具引數
    tool_params = []  # 工具引數描述
    for name, param in python_params.items():  # 遍歷引數
        annotation = param.annotation  # 引數註解
        if annotation is inspect.Parameter.empty:
            raise TypeError(f"Parameter `{name}` missing type annotation")  # 引數缺少註解
        if get_origin(annotation) != Annotated:  # 引數註解不是Annotated
            raise TypeError(f"Annotation type for `{name}` must be typing.Annotated")  # 引數註解必須是Annotated

        typ, (description, required) = annotation.__origin__, annotation.__metadata__  # 引數型別, 引數描述, 是否必須
        typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__  # 引數型別名
        if not isinstance(description, str):  # 引數描述必須是字串
            raise TypeError(f"Description for `{name}` must be a string")
        if not isinstance(required, bool):  # 是否必須必須是布林值
            raise TypeError(f"Required for `{name}` must be a bool")

        tool_params.append({  # 新增引數描述
            "name": name,
            "description": description,
            "type": typ,
            "required": required
        })
    tool_def = {  # 工具定義
        "name": tool_name,
        "description": tool_description,
        "params": tool_params
    }

    print("[registered tool] " + pformat(tool_def))  # 列印工具定義
    _TOOL_HOOKS[tool_name] = func  # 註冊工具
    _TOOL_DESCRIPTIONS[tool_name] = tool_def  # 新增工具定義

    return func

2.呼叫工具
  參考檔案ChatGLM3/tool_using/openai_api_demo.py,如下所示:

def main():
    messages = [  # 對話資訊
        system_info,
        {
            "role""user",
            "content""幫我查詢北京的天氣怎麼樣",
        }
    ]
    response = openai.ChatCompletion.create(  # 呼叫OpenAI API
        model="chatglm3",
        messages=messages,
        temperature=0,
        return_function_call=True
    )
    function_call = json.loads(response.choices[0].message.content)  # 獲取函數呼叫資訊
    logger.info(f"Function Call Response: {function_call}")  # 列印函數呼叫資訊

    tool_response = dispatch_tool(function_call["name"], function_call["parameters"])  # 呼叫函數
    logger.info(f"Tool Call Response: {tool_response}")  # 列印函數呼叫結果

    messages = response.choices[0].history  # 獲取歷史對話資訊
    messages.append(
        {
            "role""observation",
            "content": tool_response,  # 呼叫函數返回結果
        }
    )

    response = openai.ChatCompletion.create(  # 呼叫OpenAI API
        model="chatglm3",
        messages=messages,
        temperature=0,
    )
    logger.info(response.choices[0].message.content)  # 列印對話結果

參考文獻:
[1]https://github.com/THUDM/ChatGLM3/tree/main
[2]https://python.langchain.com/docs/integrations/tools