大語言模型基礎-Transformer模型詳解和訓練

2023-10-25 06:00:12

一、Transformer概述

Transformer是由谷歌在17年提出並應用於神經機器翻譯的seq2seq模型,其結構完全通過自注意力機制完成對源語言序列和目標語言序列的全域性依賴建模

Transformer由編碼器解碼器構成。圖2.1展示了該結構,其左側和右側分別對應著編碼器(Encoder)和解碼器(Decoder)結構,它們均由若干個基本的 Transformer Encoder/Decoder Block(N×表示N次堆疊)。

二、Transformer結構與實現

2.1、嵌入表示層

對於輸入文字序列,首先通過輸入嵌入層(Input Embedding)將每個單詞轉換為其相對應的向量表示。通常直接對每個單詞建立一個向量表示。

注意:在翻譯問題中,有兩個詞彙表,分別對應源語言和目標語言。

由於Transfomer中沒有任何資訊能表示單詞間的相對位置關係,故需在詞嵌入中加入位置編碼(Positional Encoding)

具體來說,序列中每一個單詞所在的位置都對應一個向量。這一向量會與單詞表示對應相加並送入到後續模組中做進一步處理。

在訓練的過程當中,模型會自動地學習到如何利用這部分位置資訊。

2.1.1、詞元嵌入層

初始化詞彙表(對原始詞彙表用BPE(Byte Pair Encoding)進行壓縮分詞,得到最終的詞元list)

self.embedding = nn.Embedding(vocab_size, num_hiddens)

2.1.2、位置編碼

為了使用序列的順序資訊,通過在輸入表示中新增位置編碼(positional encoding)來注入絕對的或相對的位置資訊。

位置編碼可以通過學習得到也可以直接固定得到。接下將介紹基於正弦函數和餘弦函數的固定位置編碼。

假設輸入\(\mathbf{X} \in \mathbb{R}^{n \times d}\)表示包含一個序列中\(n\)個詞元的\(d\)維嵌入表示。 位置編碼使用相同形狀的位置嵌入矩陣\(\mathbf{P} \in \mathbb{R}^{n \times d}\) 輸出 \(\mathbf{X} +\mathbf{P}\), 矩陣第行\(pos\)、第列\(2i\)和列上\(2i+1\)的元素為:

\[\begin{split}\begin{aligned} p_{(pos, 2i)} &= \sin\left(\frac{pos}{10000^{2i/d}}\right),\\p_{(pos, 2i+1)} &= \cos\left(\frac{pos}{10000^{2i/d}}\right).\end{aligned}\end{split} \]

其中,\(pos\)表示單詞所在的位置,\(2i\)\(2i+ 1\)表示位置編碼向量中的對應維度,\(d\) 則對應位置編碼的總維度。

通過上面這種方式計算位置編碼有這樣幾個好處:

  • 首先,正餘弦函數的範圍是在 [-1,+1],匯出的位置編碼與原詞嵌入相加不會使得結果偏離過遠而破壞原有單詞的語意資訊。

  • 其次,依據三角函數的基本性質,可以得知第\(pos + k\)個位置的編碼是第\(pos\)個位置的編碼的線性組合,這就意味著位置編碼中蘊含著單詞之間的距離資訊。

class PositionalEncoding(nn.Module):
    """位置編碼"""
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # 建立一個足夠長的P
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(10000, torch.arange(
            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)

2.1、多頭自注意力(Multi-Head-self-Attention)

2.2.1、自注意力機制

1) 縮放點積注意力(scaled dot-product attention)
假設有查詢向量(query) $ \mathbf{q} \in \mathbb{R}^{1 \times d} $ 和 鍵向量(key) $ \mathbf{k} \in \mathbb{R}^{1 \times d} $,查詢向量和鍵向量點積的結果即為注意力得分。

\[a(\mathbf q, \mathbf k) = \mathbf{q} \mathbf{k}^\top \]

將縮放點積注意力推廣到批次矩陣形勢,其公式為:

\[\mathbf{Z} = \mathrm{softmax}\left(\frac{\mathbf Q \mathbf K^\top }{\sqrt{d}}\right) \mathbf V \in \mathbb{R}^{n\times d} \]

其中,\(\mathbf Q\in\mathbb R^{m\times d}\)\(\mathbf K \in\mathbb R^{n\times d}\)\(\mathbf V\in\mathbb R^{n\times d}\)

考慮到在\(d\)過大時,點積值較大會使得後續Softmax操作溢位導致梯度爆炸,不利於模型優化。故將注意力得分除以\(\sqrt{d}\)進行縮放。

注:當\(m=1\)時,就是傳統的注意力機制(1個\(q\), 多個\(k\),\(v\))。

import math
import torch
from torch import nn

class DotProductAttention(nn.Module):
    """縮放點積注意力"""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # queries的形狀:(batch_size,查詢的個數,d)
    # keys的形狀:(batch_size,「鍵-值」對的個數,d)
    # values的形狀:(batch_size,「鍵-值」對的個數,值的維度)
    # valid_lens的形狀:(batch_size,)或者(batch_size,查詢的個數)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # 設定transpose_b=True為了交換keys的最後兩個維度
        scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)

為批次處理資料或在自迴歸處理時避免資訊洩露等情況,在Token序列中填充[mask]Token,從而使一些值不納入注意力匯聚計算。這裡可指定一個有效序列長度(即Token個數), 以便在計算softmax時過濾掉超出指定範圍的位置。

注:該縮放點積注意力的實現使用了dropout進行正則化。

masked_softmax函數實現了掩碼\(softmax\)操作(masked softmax operation), 其中任何超出有效長度的位置都被掩蔽並置為\(0\)(將掩碼位置的注意力係數變為無窮小\(-inf\)\(Softmax\)後的值為一個接近\(0\)的值)

def masked_softmax(X, valid_lens):
    """通過在最後一個軸上掩蔽元素來執行softmax操作"""
    # X:3D張量,valid_lens:1D或2D張量
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # 最後一軸上被掩蔽的元素使用一個非常大的負值替換,從而其softmax輸出為0
        X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
                              value=-1e9)
        return nn.functional.softmax(X.reshape(shape), dim=-1)

def sequence_mask(X, valid_len, value=0):
    """在序列中遮蔽不相關的項"""
    maxlen = X.size(1)
    mask = torch.arange((maxlen), dtype=torch.float32,
                        device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X

2)自注意力

當n=m時,且\(\mathbf{Q}\)\(\mathbf{K}\)\(\mathbf{V}\)均源於輸入\(\mathbf{X} \in\mathbb R^{n\times d}\)經過不同的線性變換時,縮放點積注意力即推廣為自注意力。

這時,每個查詢都會關注所有的鍵值對並生成一個注意力輸出。 由於查詢、鍵和值來自同一組輸,故稱為Self-Attention。

2.2.2、多頭自注意力

class MultiHeadAttention(nn.Module):
    """多頭注意力"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # queries,keys,values的形狀:
        # (batch_size,查詢或者「鍵-值」對的個數,num_hiddens)
        # valid_lens 的形狀:
        # (batch_size,)或(batch_size,查詢的個數)
        # 經過變換後,輸出的queries,keys,values 的形狀:
        # (batch_size*num_heads,查詢或者「鍵-值」對的個數,
        # num_hiddens/num_heads)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # 在軸0,將第一項(標量或者向量)複製num_heads次,
            # 然後如此複製第二項,然後諸如此類。
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # output的形狀:(batch_size*num_heads,查詢的個數,
        # num_hiddens/num_heads)
        output = self.attention(queries, keys, values, valid_lens)

        # output_concat的形狀:(batch_size,查詢的個數,num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)

為了能夠使多個頭平行計算, 上面的MultiHeadAttention類將使用下面定義的兩個轉置函數。 具體來說,transpose_output函數反轉了transpose_qkv函數的操作。

```python
def transpose_qkv(X, num_heads):
    """為了多注意力頭的平行計算而變換形狀"""
    # 輸入X的形狀:(batch_size,查詢或者「鍵-值」對的個數,num_hiddens)
    # 輸出X的形狀:(batch_size,查詢或者「鍵-值」對的個數,num_heads,
    # num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # 輸出X的形狀:(batch_size,num_heads,查詢或者「鍵-值」對的個數,
    # num_hiddens/num_heads)
    X = X.transpose(0, 2, 1, 3)

    # 最終輸出的形狀:(batch_size*num_heads,查詢或者「鍵-值」對的個數,
    # num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])

def transpose_output(X, num_heads):
    """逆轉transpose_qkv函數的操作"""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.transpose(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5)
print(attention)

batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
attention(X, X, X, valid_lens).shape

2.3、前饋網路

位置感知的前饋網路對序列中的所有位置的表示進行變換時使用的是同一個2層全連線網路,故稱其為positionwise的前饋網路。

\[{FFN}(\mathbf x) = Relu(\mathbf{x} \mathbf{W}_1 + \mathbf{b}_1)\mathbf{W}_2 + \mathbf{b}_2 \]

在下面的實現中,輸入X的形狀(批次大小,時間步數或序列長度,隱單元數或特徵維度)將被一個兩層的感知機轉換成形狀為(批次大小,時間步數,ffn_num_outputs)的輸出張量。

class PositionWiseFFN(nn.Module):
    """基於位置的前饋網路"""
    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,
                 **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        return self.dense2(self.relu(self.dense1(X)))

2.4、殘差連線和層規一化

add&norm元件是由殘差連線和緊隨其後的層規一化組成的,它被用來進一步提升訓練的穩定性。

1)殘差連線
殘差連線引入輸入直接到輸出的通路,便於梯度回傳從而緩解在優化過程中由於網路過深引起的梯度消失問題。

\[\mathbf{x}^{l+1} = f(\mathbf{x}^l) + \mathbf{x}^l \]

2)層歸一化
層歸一化(Layer Normalization)是基於特徵維度進行規範化,將資料進行標準化(乘以縮放係數、加上平移係數,保留其非線效能力。

\[{LN}(\mathbf x) = \alpha (\frac{\mathbf x - \mu }{\sigma}) + \beta \]

層歸一化可以有效地緩解優化過程中潛在的不穩定、收斂速度慢等問題。

以下程式碼對比不同維度的層規範化和批次規範化的效果。

ln = nn.LayerNorm(2)
bn = nn.BatchNorm1d(2)
X = torch.tensor([[1, 2], [2, 3]], dtype=torch.float32)
# 在訓練模式下計算X的均值和方差
print('layer norm:', ln(X), '\nbatch norm:', bn(X))

層歸一化實現

class NormLayer(nn.Module):
    def __init__(self, d_model, eps = 1e-6):
        super().__init__()
        self.size = d_model
        # 層歸一化包含兩個可以學習的引數
        self.alpha = nn.Parameter(torch.ones(self.size))
        self.bias = nn.Parameter(torch.zeros(self.size))
        self.eps = eps

    def forward(self, x):
        norm = self.alpha * (x - x.mean(dim=-1, keepdim=True)) \
        / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias
        return norm

使用殘差連線和層規一化來實現AddNorm類

class AddNorm(nn.Module):
    """殘差連線後進行層規範化"""
    def __init__(self, normalized_shape, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        return self.ln(self.dropout(Y) + X)

2.5、編碼器

現在可以基於編碼器的基礎元件實現編碼器的一個層。

下面的EncoderBlock類包含兩個子層:多頭自注意力基於位置的前饋網路,這兩個子層都使用了殘差連線和緊隨的層規一化。

class EncoderBlock(nn.Module):
    """Transformer編碼器塊"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, use_bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        self.attention = MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout,
            use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(
            ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))

下面的Transformer編碼器中,堆疊了num_layers個EncoderBlock類的範例。

由於這裡使用的是值範圍在-1和1之間的固定位置編碼,因此通過學習得到的輸入的嵌入表示的值需要先乘以嵌入維度的平方根進行重新縮放,然後再與位置編碼相加。

class TransformerEncoder(Encoder):
    """Transformer編碼器"""
    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block"+str(i),
                EncoderBlock(key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens,
                             num_heads, dropout, use_bias))

    def forward(self, X, valid_lens, *args):
        # 因為位置編碼值在-1和1之間,
        # 因此嵌入值乘以嵌入維度的平方根進行縮放,
        # 然後再與位置編碼相加。
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights[
                i] = blk.attention.attention.attention_weights
        return X

2.6、解碼器

1) 掩碼多頭注意力
解碼器的每個Transformer塊的第一個自注意力子層額外增加了注意力掩碼,對應圖中的掩碼多頭注意力(Masked Multi-Head Attention)部分。

因為在翻譯的過程中,編碼器用於編碼已知的源語言序列的資訊,因而它只需要考慮如何融合上下文語意資訊即可。而解碼端則負責生成目標語言序列,這一自迴歸的過程意味著,在生成每一個單詞時,僅有當前單詞之前的目標語言序列是可觀測的。

增加的Mask是用來避免模型在訓練階段直接看到後續的文字序列(資訊洩露)進而無法得到有效地訓練。

2) 交叉注意力
解碼器端還增加了一個多頭注意力(Multi-Head Attention)模組,使用交叉注意力(Cross-attention)方法,同時接收來自編碼器端的輸出以及當前 Transformer 塊的前一個掩碼注意力層的輸出。

Query是通過解碼器前一層的輸出進行投影的,而Key和Value是使用編碼器的輸出進行投影的。它的作用是在翻譯的過程當中,為了生成合理的目標語言序列需要觀測待翻譯的源語言序列是什麼。

基於上述的編碼器和解碼器結構,待翻譯的源語言文字,首先經過編碼器端的每個Transformer塊對其上下文語意的層層抽象,最終輸出每一個源語言單詞上下文相關的表示。

解碼器端以自迴歸的方式生成目標語言文字,即在每個時間步t,根據編碼器端輸出的源語言文字表示,以及前 t -1 個時刻生成的目標語言文字,生成當前時刻的目標語言單詞

class DecoderBlock(nn.Module):
    """解碼器中第i個塊"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, i, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)
        self.i = i
        self.attention1 = MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,
                                   num_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        enc_outputs, enc_valid_lens = state[0], state[1]
        # 訓練階段,輸出序列的所有詞元都在同一時間處理,
        # 因此state[2][self.i]初始化為None。
        # 預測階段,輸出序列是通過詞元一個接著一個解碼的,
        # 因此state[2][self.i]包含著直到當前時間步第i個塊解碼的輸出表示
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = torch.cat((state[2][self.i], X), axis=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape
            # dec_valid_lens的開頭:(batch_size,num_steps),
            # 其中每一行是[1,2,...,num_steps]
            dec_valid_lens = torch.arange(
                1, num_steps + 1, device=X.device).repeat(batch_size, 1)
        else:
            dec_valid_lens = None

        # 自注意力
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # 編碼器-解碼器注意力。
        # enc_outputs的開頭:(batch_size,num_steps,num_hiddens)
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state

2.7、Transformer

class EncoderDecoder(nn.Module):
    """編碼器-解碼器架構的基礎類別"""
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)

三、Transformer訓練

損失

class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    """帶遮蔽的softmax交叉熵損失函數"""
    # pred的形狀:(batch_size,num_steps,vocab_size)
    # label的形狀:(batch_size,num_steps)
    # valid_len的形狀:(batch_size,)
    def forward(self, pred, label, valid_len):
        weights = torch.ones_like(label)
        weights = sequence_mask(weights, valid_len)
        self.reduction='none'
        unweighted_loss = super(MaskedSoftmaxCELoss, self).forward(
            pred.permute(0, 2, 1), label)
        weighted_loss = (unweighted_loss * weights).mean(dim=1)
        return weighted_loss

def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):
    """訓練序列到序列模型"""
    def xavier_init_weights(m):
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.GRU:
            for param in m._flat_weights_names:
                if "weight" in param:
                    nn.init.xavier_uniform_(m._parameters[param])

    net.apply(xavier_init_weights)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = MaskedSoftmaxCELoss()
    net.train()
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                     xlim=[10, num_epochs])
    for epoch in range(num_epochs):
        timer = d2l.Timer()
        metric = d2l.Accumulator(2)  # 訓練損失總和,詞元數量
        for batch in data_iter:
            optimizer.zero_grad()
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
            bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],
                          device=device).reshape(-1, 1)
            dec_input = torch.cat([bos, Y[:, :-1]], 1)  # 強制教學
            Y_hat, _ = net(X, dec_input, X_valid_len)
            l = loss(Y_hat, Y, Y_valid_len)
            l.sum().backward()      # 損失函數的標量進行「反向傳播」
            d2l.grad_clipping(net, 1)
            num_tokens = Y_valid_len.sum()
            optimizer.step()
            with torch.no_grad():
                metric.add(l.sum(), num_tokens)
        if (epoch + 1) % 10 == 0:
            animator.add(epoch + 1, (metric[0] / metric[1],))
    print(f'loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} '
        f'tokens/sec on {str(device)}')

訓練語料為句子對

import torch
from torch import nn

num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
lr, num_epochs, device = 0.005, 200, d2l.try_gpu()
ffn_num_hiddens, num_heads = 64, 4

train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)

encoder = TransformerEncoder(
    len(src_vocab), num_hiddens, ffn_num_hiddens, num_heads, num_layers,
    dropout)
decoder = TransformerDecoder(
    len(tgt_vocab), num_hiddens, ffn_num_hiddens, num_heads, num_layers,
    dropout)
net = EncoderDecoder(encoder, decoder)
train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

# Test
engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation, dec_attention_weight_seq = d2l.predict_seq2seq(
        net, eng, src_vocab, tgt_vocab, num_steps, device, True)
    print(f'{eng} => {translation}, ',
          f'bleu {d2l.bleu(translation, fra, k=2):.3f}')

參考連結

【1】大規模語言模型:從理論到實踐, 第二章
【2】動手學深度學習 https://zh.d2l.ai/chapter_attention-mechanisms/transformer.html
【3】NoteBook: https://colab.research.google.com/github/d2l-ai/d2l-pytorch-colab/blob/master/chapter_attention-mechanisms-and-transformers/transformer.ipynb

後續Todo

1、分詞器
2、GPT系列
3、Llama2系列:原理、微調、預訓練、SFT、RLHF
4、BERT系列
5、LLM訓練推理加速