這一章我們聊聊有哪些方案可以不用微調直接讓大模型支援超長文字輸入,注意這裡主要針對無限輸入場景。之前在BERT系列中我們就介紹過稀疏注意力和片段遞迴的一些長文字建模方案長文字建模 BigBird & Longformer & Reformer & Performer,不過以上方案無一例外都需要在訓練階段引入。針對當前大模型微調成本高的問題,更多研究放到如何在模型外部支援長文字輸入。先討論下為啥當前的大模型會在推理時存在輸入長度的限制,主要有以下幾點原因
Attention矩陣計算複雜度:在不引入稀疏注意力的條件下,Attention矩陣的記憶體和計算複雜度是\(O(序列長度^2)\),文字長度的上升會帶來視訊記憶體的指數增長。
訓練耗時:訓練階段的文字長度會顯著影響訓練速度, 因此2048一般是當前預訓練常見的最大長度。
針對以上問題本章介紹4種方案:顯式搜尋的知識庫外掛方案,隱式搜尋的Unlimiformer, 並行輸入的pcw和並行解碼NBCE。
- paper: Unleashing Infinite-Length Input Capacity for Large-scale Language Models with Self-Controlled Memory System
- 看到最無敵的應用,文字和表格解析超厲害https://chatdoc.com/?viaurl=ainavpro.com
- ChatGPT程式碼實現: https://github.com/arc53/DocsGPT
- ChatGLM程式碼實現: https://github.com/imClumsyPanda/langchain-ChatGLM
- 適用於大規模知識問答場景
這塊可能是GPT後比較火的方向,有一陣每天都能看到類似的新應用,從GPT讀論文,再到百科問答,搭配langchain框架,在DocQA,KBQA的場景簡直無往不利, 以上分別給出了基於ChatGPT和ChatGLM的兩個實現方案。
實現的步驟基本可以被下圖概括
搜尋法最大的優點是實現簡單,不過也有許多限制就是隻能支援NLU任務,以及會破壞輸入文字的上下文連續性,和文字順序。但在大規模知識問答這塊算是現在看到最好的方案。
- Unlimiformer: Long-Range Transformers with Unlimited Length Input
- https://github.com/abertsch72/unlimiformer
- 適用於Encoder-Decoder模型,長文字摘要等場景
特意起了個隱式搜尋的標題,是因為和上面的文字搜尋實現有異曲同工之妙,本質的差異只是以上是離散文字塊的搜尋。而Unlimiformer是在解碼階段對超長輸入,token粒度的輸出層embedding進行檢索,選擇最相關的Top Token計算Attention。
首先對於超長輸入,unlimiformr採用以上提到的重疊切分的方法,重疊率50%,這樣可以更好保留上文和文字連貫性,例如第一段文字是1-500字,第二段重疊250字取250-750字。然後使用Encoder對每段文字進行獨立編碼,繞過Attention的平方複雜度問題。最後輸出每段文字的Embedding,注意這裡不是文字整體embedidng, 而是後半部分(250~500字)每個Token最上層的Embedding,並寫入向量索引,這裡用的是Faiss。
在解碼層,每一步解碼,query都會檢索注意力最高的Top-k個輸入Token,作為編碼器部分的資訊用於解碼器的解碼。這裡簡單回憶下Attention計算, Top-K個Token就是讓以下注意力取值最高的key。
考慮Decoder的每一層(N層)中的每一個head(L個頭)都需要和Encoder的輸出層進行互動, 檢索Top Key,如果儲存每一層每個head的Key,需要構建\(O(L*N*seqlen)\)的向量儲存。對此作者進行了優化,改變了以下QK的計算順序,用每一層每個頭Key的對映矩陣對Q進行對映,這樣只需要儲存一份seq_len的編碼向量(\(h_{encoder}\)),在每一層檢索時用對映後的Q進行檢索既可,其實就是時間換空間
unlimiformer提供了程式碼實現,核心程式碼抽出來看下有兩塊
for context_start_ind, context_end_ind, update_start_ind, update_end_ind in window_indices:
chunk = input_ids[:, context_start_ind:context_end_ind]
chunk_attention_mask = attention_mask[:, context_start_ind:context_end_ind]
hidden_states = self.model(chunk, attention_mask=chunk_attention_mask, labels=dummy_labels, return_dict=True)
last_hidden = hidden_states.encoder_last_hidden_state # (batch, chunked_source_len, dim)
to_add = last_hidden[:, update_start_ind:update_end_ind].detach()
to_apply_mask = chunk_attention_mask[:, update_start_ind:update_end_ind]
def attention_forward_hook(self, module, input, output):
# output: (batch, time, 3 * heads * attention_dim)
with torch.no_grad():
query = self.process_query(output)[:,-1] # (batch * beam, head, dim)
query = query[:, self.head_nums] # (batch * beam, head, dim)
#這是前面提到的計算優化使用每層每個head的Key對映矩陣對Query進行對映用於搜尋
attention_layer_list = self.attention_layer_to_capture(self.layer_begin, self.layer_end)
k_proj_layer = [layers[0] for layers in attention_layer_list][self.cur_decoder_layer_index]
# modify query by k_projs
k_proj = k_proj_layer.weight
k_proj = k_proj.view(1, self.num_heads, query.shape[-1], k_proj.shape[0]) # (1, num_heads, attn_dim, embed_dim)
datastore_query = query.unsqueeze(-2) # (batch * beam, num_heads, 1, attn_dim)
datastore_query = torch.matmul(datastore_query, k_proj) # (batch * beam, num_heads, 1, embed_dim)
datastore_query = datastore_query.squeeze(-2) # (batch * beam, num_heads, embed_dim)
datastore_query = datastore_query.view((self.datastore.batch_size, -1, datastore_query.shape[2])) # (batch, beam * num_heads, embed_dim)
# 這裡進行Top Key的檢索:得到Key的索引,Embedding和得分
top_search_key_scores, top_search_key_indices = self.datastore.search(datastore_query, k=self.actual_model_window_size)
embeddings = torch.take_along_dim(input=self.embeddings.unsqueeze(1),
indices=top_search_key_indices.unsqueeze(-1).to(self.embeddings.device), dim=-2)
##後面就是常規的對Embedding進行Key和Value的對映然後做Attention了
和前面的文字檢索對比,unlimiformer的儲存成本會更高,因為要儲存token粒度的Embedding資訊,更適用於on-the-fly的長文字推理使用,例如針對單一檔案的QA,只儲存當前檔案,而前面文字塊檢索方案更適合一些大規模知識,批次的檔案的儲存。
但其實unlimiformer直接對Token進行離散召回,這一點我讓我有些困惑,這樣單一token的檢索召回,真的不會破壞上文連續性麼?還是說Encoder編碼方式已經保證了檢索召回大概率會召回成段的Token,又或者說每個Token的Embedding內已經充分編碼了連續上下文的資訊,召回離散Token也不會出現割裂的語意資訊?哈哈考慮unlimiformer只支援Encoder-Decoder的框架,和我們用的Decoder框架不適配,我決定不細糾結了!有在中文嘗試過效果的童鞋可以分享下~
- Parallel Context Windows for Large Language Models
- https://github.com/AI21Labs/Parallel-Context-Windows
- 適用於Decoder模型,以及小規模內容理解場景
同樣是對超長文字進行切塊,然後獨立編碼,PCW使用的是Decoder框架。和unlimiformer只使用Top-Key進行解碼,PCW在解碼過程中對全部輸入上文進行Attention。對比Encoder-Decoder框架,因為輸入和輸出都在Decoder側,PCW需要解決兩個問題:位置編碼和注意力矩陣如何調整, 下圖基本概括了這兩個細節
1. 位置編碼:輸入文字截斷後,每段文字的位置編碼相同。考慮所最長的文字長度為C,則輸入文字最大的位置編碼id是$P_C$,則解碼器第一個字的位置編碼id是$P_{C+1}$,然後順序向後編碼。其實就是丟棄了上文多段文字之間的位置關係,解碼時只知道上文多段文字都是在解碼器之前,但無法區分文字之間的位置。不過因為上文每段文字複用了相同的位置編碼,因此位置編碼的長度大幅降低,也就降低了對位置編碼外推性的需求。
position_ids = attention_mask.long().cumsum(-1) - 1
n_task_tokens = position_ids.shape[1] - sum_windows_size
# 保證解碼器的位置編碼比最長上文要長度+1
position_ids[0, -n_task_tokens:] = torch.arange(max_window_size, max_window_size + n_task_tokens, 1)
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values: # i.e., first token is already generated
position_ids = position_ids[:, -1].unsqueeze(-1)
elif windows_key_values: # i.e., we are in the first token generation #其實就是取-n_task_tokens:
position_ids = position_ids[:, sum_windows_size:]
def combine_past_key_values(past_lst: List[Tuple[Tuple[torch.Tensor]]],
contains_bos_token: bool = True) -> Tuple[Tuple[torch.Tensor]]:
# 這裡past_lst是每段文字的past-key-value
# GPT是n_layer * 2(key+value) * tensor(seq_len,batch,n_head,n_hidden)
# 注意不同模型past-key-value的shape不同
# Chatglm是n_layer * 2(key+value) * tensor(seq_len,batch, n_head, n_hidden)
return tuple(
(torch.cat([c[i][0] for c in past_lst], dim=2),
torch.cat([c[i][1] for c in past_lst], dim=2))
for i in range(len(past_lst[0])))
res['past_attention_mask'] = torch.cat([window['attention_mask'] for window in windows], dim=1)
combined_attention_mask = torch.cat((cache['past_attention_mask'], encoded_task_text['attention_mask']), dim=1)
考慮ChatGLM本身是二維的Attention矩陣和位置編碼,特殊的BOS和GMASK,我重寫了PCW,但是在長文字QA問題上表現比較一般,表現在當上文多段文字無明顯關係的時候例如多個完全無關的新聞,在進行問答的時候,正確答案中會混雜很多無關的文字變短,以及這個問題當上文片段變多,或者指令問題變多的時候會變得越來越嚴重,直到開始完全胡說八道。當然不排除我寫bug了哈哈哈,但我自己是真的沒查出來。
不過也有一種可能,是PCW是在輸入層就開始對超長上文進行Attention,因為不同上文的位置編碼相同,一定程度上會讓解碼注意力變得非常分散,導致注意力的熵值變高,解碼的不確定性變大,更容易出現亂碼。
- 蘇劍林. (May. 23, 2023). 《NBCE:使用樸素貝葉斯擴充套件LLM的Context處理長度 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/9617
- 蘇劍林. (May. 31, 2023). 《關於NBCE方法的一些補充說明和分析 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/9632
- https://github.com/bojone/NBCE
- 適用於Encoder-Decoder模型,長文字內容理解如摘要問答等場景
壓軸的必須是蘇神的NBCE!這裡我把看完部落格後的理解進行簡單的總結,詳細推理請看去蘇神的科學空間!答應我一定要去看!每次看蘇神推導,都會覺得數學之魂在燃燒!
NBCE的原理簡單解釋如下圖,和PCW相同是對每段上文進行獨立編碼,但差異在於PCW是在輸入層進行融合,而NBCE是在輸出層對每一個Step輸出的預測token的概率矩陣進行融合,更大程度上避免了注意力被分散,保證瞭解碼的合理性。
這裡我們簡單說下如何在輸出層進行融合,把找超長文字chunk成多段文字後($s_1,s_2,...s_k$),基於樸素貝葉斯的簡化假設, 基於多段文字進行並行解碼的預測概率可以簡化如下,也就是每段文字條件解碼概率之和減去無條件解碼概率 $$ log(P(T|s_1,..s_k)) = \sum_{i=1}^Klog(p(T|s_i)) -(n-1)log(p(T)) + const $$
既然說了是簡化假設,因此可以對上式進行一些調優,核心是讓模型對上文的解碼更加準確,降低無關上文帶來的解碼噪聲,比較重要的優化包括
以上解碼概率求和,其實是對k段文字生成的\(vocab * K\)的概率矩陣,沿K做AvergePooling,得到最終\(vocab*1\)的解碼概率。但考慮LM訓練其實是擬合one-hot(出現概率最高的詞),也就是除了概率最高的幾個token之外其餘token的預測概率都不靠譜。如果直接取平均的多路打分,很容易投出一個在各段文字上打分都不高不低的token,上文越多這個問題越明顯。但其實在閱讀理解例如抽取,QA問題的解碼策略上我們要的是在某段文字上打分置信度最高的token,因為答案往往只來自一個上文片段。
因此蘇神給出了兩種準確率更高的解碼方案,一個是MaxPooling+GreedySearch,其實就是對\(vocab*k\)的概率矩陣取全域性概率最高的token,另一個是最小熵+RandomSampling,也就是從多段上文中取1個預測置信度最高的上文進行解碼。這裡其實是和PCW最大的差異,也就是在解碼層進行融合,並通過熵值較低的融合策略來保證解碼的準確率。
以及後面蘇神還通過Top-P來進一步過濾尾部的噪聲,以及通過控制每一步解碼的轉移概率,來讓解碼器不會在不同上文片段之間反覆切換,而是保證連續的解碼片段大概率來自相同的上文片段。
基於上文來進行解碼的一個核心是為了降低模型回答胡說八道的概率。例如在金融場景我們直接問chatgpt基金贖回費用是多少 vs 我們基於某個基金的介紹問模型該基金的贖回費用是多少,後者得到的答案一定是更準確的。而其實以上二者的差異在於條件(上文)解碼和無條件解碼, 因此可以通過diff無條件編碼的方式來提高解碼對上文的依賴程度(reliablity)。如下圖
因此蘇神把把n變成超參Beta, 控制條件概率和無條件概率的佔比,Beta越高解碼和上文的關聯度越高,QA等場景的解碼準確率越高,生成自由度越低。
當前NBCE的侷限性在於無法處理上文片段之間的位置關係,以及無法處理解碼需要依賴多個上文片段的場景。後者感覺可以通過預測概率矩陣的相關性修改Pooling方式,而前者
基於蘇神提供的程式碼,在chatglm上做了嘗試,只需要簡單調整下輸入輸出的部分就可以直接使用。我在論文,書籍,和新聞上進行摘要,實體抽取和QA問答後發現,INT8量化的模型效果似乎要略優於FP16, 顯著優於INT4。INT8量化下,10K左右的輸入,視訊記憶體佔用基本可以限制在單卡A100(40g),大家可以自行嘗試下~
@torch.inference_mode()
def generate(max_tokens):
device = torch.device('cuda')
"""Naive Bayes-based Context Extension 演示程式碼
"""
inputs = tokenizer(batch, padding='longest', return_tensors='pt').to(device)
input_ids = inputs.input_ids
n = input_ids.shape[0]
with torch.no_grad():
for i in range(max_tokens):
# 模型輸出
model_input = model.prepare_inputs_for_generation(input_ids)
outputs = model(**model_input,
return_dict=True,
use_cache=True
)
"""
中間程式碼不變
"""
# 把唯一的回答擴充到每一個batch進行下一輪的解碼
next_tokens = next_tokens.unsqueeze(-1).tile(n, 1)
input_ids = torch.cat([input_ids, next_tokens], dim=-1)
# 更新past-key-values, 更新attention_mask, 更新position_ids
model_kwargs = model._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
)
Reference