1.LSTM的原理
LSTM是RNN(迴圈神經網路)的變體,全名為長短期記憶網路(Long Short Term Memory networks)。
它的精髓在於引入了細胞狀態這樣一個概念,不同於RNN只考慮最近的狀態,LSTM的細胞狀態會決定哪些狀態應該被留下來,哪些狀態應該被遺忘。
具體與RNN的區別可參考這篇博文:LSTM與RNN的比較
先放一張LSTM網路的模型圖:
如上圖所示,可以看到這是一個網路,我們單拿出其中一個單元來進行分析,可見每一個單元都包含一系列運算,那麼這些運算的意義是什麼呢?下面我們來一一解釋每個單元的具體內容。
(1)遺忘門
ht-1 :前一個時刻的Cell的輸出
xt : 當前時刻的輸入
注意:中括號的意思是將ht-1與xt拼接起來,後面出現公式同理
遺忘門主要來判斷上一狀態中的輸出對現狀態的影響大小,遺忘門的輸出要通過一個Sigmoid函數,Sigmoid函數的輸出範圍是0~1,相當於得到一個權重,後面與Ct-1相乘,以此得到上一狀態輸出對現狀態的影響。
(2)輸入門
輸入門中會得到一個臨界的細胞狀態(Ct^),表示此狀態下的備選輸出,與it作用後就得到此次狀態需要輸出的內容。
由以上兩個門就可以輸出更新後的細胞狀態Ct,輸出公式如上圖所示,需要注意這裡的「 * 」符號為哈達瑪乘積,就是對應矩陣元素相乘。
(3)輸出門
輸出門具體運算過程如上圖所示。這樣就得到了這個時刻的輸出,把這個輸出再傳入下一狀態即可。
2.程式碼實現
初始化:
import torch
import torch.nn as nn
搭建一個LSTM單元:
class LSTMCell(nn.Module):
def __init__(self,input_size,hidden_size,cell_size,output_size):
super(LSTMCell,self).__init__()
self.hidden_size = hidden_size
self.cell_size = cell_size
#設定門輸入輸出資料的大小尺寸
self.gate = nn.Linear(input_size+hidden_size,cell_size)
self.output = nn.Linear(hidden_size,output_size)
self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()
#分類器-輸出
self.softmax = nn.LogSoftmax(dim=1)
def forward(self,input,hidden,cell):
#拼接資料,後置的0/1 確定橫向(1)還是豎向(0)拼接
combined = torch.cat((input,hidden),1)
#根據LSTM一個單元的網路圖得出三個門,並進行運算
f_gate = self.sigmoid(self.gate(combined))
i_gate = self.sigmoid(self.gate(combined))
#z_state看作為Cell的中間狀態
z_state = self.tanh(self.gate(combined))
o_gate = self.sigmoid(self.gate(combined))
#注意這下面的乘為哈達瑪乘積,矩陣對應元素相乘
cell = torch.add(torch.mul(f_gate,cell),torch.mul(i_gate,z_state))
hidden = torch.mul(self.tanh(cell),o_gate)
output = self.output(hidden)
output = self.softmax(output)
return output,hidden,cell
def initHidden(self):
return torch.zeros(1,self.hidden_size)
def initCell(self):
return torch.zeros(1,self.cell_size)
範例化LSTMCell,並傳入輸入、隱含狀態等進行驗證:
lstmcell = LSTMCell(input_size=10,hidden_size=20,cell_size=20,output_size=10)
input = torch.randn(32,10)
h_0 = torch.randn(32,20)
c_0 = torch.randn(32,20)
output,hn,cn = lstmcell(input,h_0,c_0)
print(output.size(),hn.size(),cn.size())
輸出結果:
torch.Size([32, 10]) torch.Size([32, 20]) torch.Size([32, 20])
end
(以上圖片來源於網路,若侵權請聯絡刪除)