LLM探索:為ChatGLM2的gRPC後端增加連續對話功能

2023-10-19 12:00:26

前言

之前我做 AIHub 的時候通過 gRPC 的方式接入了 ChatGLM 等開源大模型,對於大模型這塊我搞了個 StarAI 框架,相當於簡化版的 langchain ,可以比較方便的把各種大模型和相關配套組合在一起使用。

主要思路還是用的 OpenAI 介面的那套,降低學習成本,但之前為了快速開發,就只搞了個簡單的 gRPC 介面,還差個多輪對話功能沒有實現,這次就來完善一下這個功能。

簡述

系統分為LLM後端和使用者端兩部分,LLM後端使用 gRPC 提供介面,使用者端就是我用 Blazor 開發的 AIHub

所以這次涉及到這幾個地方的修改

  • proto
  • 使用者端 - C# 程式碼
  • AIHub頁面 - Blazor 的 razor 程式碼
  • gRPC 伺服器端 - Python 程式碼

修改 proto

來改造一下 proto 檔案

\syntax = "proto3";

import "google/protobuf/wrappers.proto";

option csharp_namespace = "AIHub.RPC";

package aihub;

service ChatHub {
  rpc Chat (ChatRequest) returns (ChatReply);
  rpc StreamingChat (ChatRequest) returns (stream ChatReply);
}

message ChatRequest {
  string prompt = 1;
  repeated Message history = 2;
  int32 max_length = 3;
  float top_p = 4;
  float temperature = 5;
}

message Message {
  string role = 1;
  string content = 2;
}

message ChatReply {
  string response = 1;
}

增加了 Message 型別,在 ChatRequest 聊天請求中增加了 history 欄位作為對話歷史。

修改 C# 的 gRPC 使用者端程式碼

上面的 proto 寫完之後編譯專案,會重新生成使用者端的 C# 程式碼,現在來修改一下我們的呼叫程式碼

可以看到 ChatRequest 多了個 RepeatedField<Message> 型別的 history 屬性,這個屬性是唯讀的,所以每次聊天的時候傳入對話歷史只能使用新增的方式。

為了方便使用,我封裝了以下方法來建立 ChatRequest 物件

private ChatRequest GetRequest(string prompt, List<Message>? history = null) {
  var request = new ChatRequest {
    Prompt = prompt,
    MaxLength = 2048,
    TopP = 0.75f,
    Temperature = 0.95f
  };

  if (history != null) {
    request.History.AddRange(history);
  }

  return request;
}

繼續改寫兩個聊天的方法,增加個一個 history 引數

public async Task<string> Chat(string prompt, List<Message>? history = null) {
  var resp = await _client.ChatAsync(GetRequest(prompt, history));
  return RenderText(resp.Response);
}

public async IAsyncEnumerable<string> StreamingChat(string prompt, List<Message>? history = null) {
  using var call = _client.StreamingChat(GetRequest(prompt, history));
  await foreach (var resp in call.ResponseStream.ReadAllAsync()) {
    yield return RenderText(resp.Response);
  }
}

搞定。

修改 gRPC 伺服器端的 Python 程式碼

先來看看 ChatGLM2 是如何傳入對話的

對官方提供的 demo 進行偵錯,發現傳入模型的 history 是列表裡面包著一個個元組,表示一個個對話,奇奇怪怪的格式。

history = [('問題1', '回答1'), ('問題2', '回答2')]

但是 AIHub 的對話是按照 OpenAI 的思路來做的,是這樣的格式:

history = [
  {'role': 'user', 'content': '問題1'},
  {'role': 'assistant', 'content': '回答1'},
  {'role': 'user', 'content': '問題2'},
  {'role': 'assistant', 'content': '回答2'},
]

現在需要把 OpenAI 對話格式轉換為 ChatGLM 的格式

直接上程式碼吧

def messages_to_tuple_history(messages: List[chat_pb2.Message]):
    """把聊天記錄列表轉換成 ChatGLM 需要的 list 巢狀 tuple 形式"""
    history = []
    current_completion = ['', '']
    is_enter_completion = False

    
    for item in messages:
        if not is_enter_completion and item.role == 'user':
            is_enter_completion = True

        if is_enter_completion:
            if item.role == 'user':
                if len(current_completion[0]) > 0:
                    current_completion[0] = f"{current_completion[0]}\n\n{item.content}"
                else:
                    current_completion[0] = item.content
            if item.role == 'assistant':
                if len(current_completion[1]) > 0:
                    current_completion[1] = f"{current_completion[1]}\n\n{item.content}"
                else:
                    current_completion[1] = item.content

                is_enter_completion = False
                history.append((current_completion[0], current_completion[1]))
                current_completion = ['', '']

    return history

目前只處理了 user 和 assistant 兩種角色,其實 OpenAI 還有 system 和 function ,system 比較好處理,可以做成以下形式

[('system prompt1', ''), ('system prompt2', '')]

不過我還沒測試,暫時也用不上這個東西,所以就不寫在程式碼裡了。

接著繼續修改兩個對話的方法

class ChatService(chat_pb2_grpc.ChatHubServicer):
    def Chat(self, request: chat_pb2.ChatRequest, context):
        response, history = model.chat(
            tokenizer,
            request.prompt,
            history=messages_to_tuple_history(request.history),
            max_length=request.max_length,
            top_p=request.top_p,
            temperature=request.temperature)
        torch_gc()
        return chat_pb2.ChatReply(response=response)

    def StreamingChat(self, request: chat_pb2.ChatRequest, context):
        current_length = 0
        for response, history in model.stream_chat(
                tokenizer,
                request.prompt,
                history=messages_to_tuple_history(request.history),
                max_length=request.max_length,
                top_p=request.top_p,
                temperature=request.temperature,
                return_past_key_values=False):

            print(response[current_length:], end="", flush=True)
            yield chat_pb2.ChatReply(response=response)
            current_length = len(response)

        torch_gc()

對了,每次對話完成記得回收視訊記憶體

def torch_gc():
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

這樣就搞定了。

PS: Python 紀錄檔元件可以用 loguru ,很好用,我最近剛發現的。

小結

gRPC 方式呼叫開發起來還是有點麻煩的,主要是偵錯比較麻煩,我正在考慮是否改成統一 OpenAI 介面方式的呼叫,GitHub 上有人貢獻了 ChatGLM 的 OpenAI 相容介面,後續可以看看。

不過在視覺這塊,還是得繼續搞 gRPC ,傳輸效率比較好。大模型可以使用 HTTP 的 EventSource 是因為資料量比較小,次要原因是對話是單向的,即:使用者向模型提問,模型不會主動向使用者傳送資訊。