一、唠唠近况
开学我就大三了,时间飞逝呀!两个月的暑假接近尾声,还算是充实。近况,一方面暑假准备的论文已经交由老师修改了,最后应该是由我检查全文语法错误而结束,最终投出去。另一方面,又又又在准备下一篇论文,初步想法是在原网络中融入LSTM。经过两天的网上冲浪,CSDN看了,视频也看了,最后发现还ChristopherOlah大佬的一篇Understanding LSTM Networks – colah’s blog最经典了,因此做部分记录,大部分都是翻译的啦。
二、了解RNN网络
在传统RNN中,模块将具有非常简单的结构,例如单个tanh层。
LSTM是针对传统RNN存在"长期依赖"问题而设计的网络架构,这并不是说RNN你能处理"长期依赖"。事实上通过认为的选定参数,RNN绝对有能力处理这种“长期依赖”,但这是一件吃力不讨好的事情。LSTM通过三个"门"的设计,连接长短期记忆,从而帮助RNN减轻了处理"长期依赖"问题的压力。下面是ChristopherOlah对RNN处理长、短期问题的综述:
有时,我们只需要查看最近的信息即可执行当前任务。例如,考虑一个语言模型,试图根据前一个单词预测下一个单词。如果我们试图预测“云在天空中”中的最后一个词,我们不需要任何进一步的上下文 - 很明显下一个词将是天空。在这种情况下,相关信息与需要的地方之间的距离很小,RNN可以学习使用过去的信息。
但在某些情况下,我们需要更多的背景信息。考虑尝试预测文本中的最后一个词“我在法国长大…我能说一口流利的法语”。最近的信息表明,下一个单词可能是一种语言的名称,但如果我们想缩小哪种语言的范围,我们需要更远的法国语境。相关信息与需要信息的点之间的距离完全有可能变得非常大。
不幸的是,随着差距的扩大,RNN变得无法学习连接信息。
三、了解LSTM网络
相较于传统RNN网络,LSTM也有这种链状结构,但重复模块具有不同的结构。不是只有一个神经网络层,而是有四个交互层,以一种非常特殊的方式相互作用。
在上图中,每条线承载着一个完整的向量,从一个节点的输出到其他节点的输入。粉红色圆圈表示逐点运算,如向量加法,而黄色框是学习神经网络层。行合并表示串联,而行分叉表示其内容被复制并且副本将转到不同的位置。
四、LSTM的核心思想
LSTM的核心思想,简单理解就是在原网络增加一条主线。观察LSTM和RNN的架构,不难发现LSTM上方添加了一条很长的主线,该主线的信息决定了输出门的输出。我们称这条主线为细胞状态,它只有一些轻微的线性作用,信息很容易原封不动地沿着它流动。然后,LSTM能够删除或添加信息到细胞状态,由称为门的结构仔细调节。门是一种选择性地让信息通过的方式。由 sigmoid 神经网络层和逐点乘法运算组成。
五、LSTM的三个门
5.1 遗忘门
LSTM 的第一步是决定要从细胞状态中忘记哪些信息,由sigmoid层做出判断。通过查看[h(t−1),xt]信息来输出一个0-1的向量,该向量里面的0-1值表示细胞状态Ct−1中哪些信息将被遗忘或者保留。下图是遗忘门的图示:
5.2 输入门
下一步是决定将在细胞状态中存储哪些新信息。这分为两部分。首先,输入门的sigmoid层决定我们将更新哪些值。接下来,tanh 层创建一个新候选值的向量,可以添加到状态中。在下一步中,我们将结合这两者来创建状态更新。
将旧状态乘以ft,忘记了之前决定忘记的事情。然后我们添加it∗Ct.这是新的候选值,按决定更新每个状态值的程度进行缩放。
5.3 输出门
LSTM的输出是基于细胞状态的,但是过滤版本。首先,我们运行一个 sigmoid 层,它决定我们要输出细胞状态的哪些部分。然后,我们将单元格状态通过tanh(将值规范到-1到1),并将其乘以 Sigmoid 门的输出,这样就只输出了我们决定的部分。
六、Pytorch实战
在Pytorch中当然有直接调包的nn.LSTM可供使用,但是为了更好的理解LSTM,所以最好是自己实现一个:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
| import torch import torch.nn as nn
class LSTMCell3D(nn.Module): def __init__(self, input_size, hidden_size): super(LSTMCell3D, self).__init__() self.hidden_size = hidden_size self.input_size = input_size self.W_i = nn.Linear(input_size, hidden_size) self.U_i = nn.Linear(hidden_size, hidden_size) self.W_f = nn.Linear(input_size, hidden_size) self.U_f = nn.Linear(hidden_size, hidden_size) self.W_c = nn.Linear(input_size, hidden_size) self.U_c = nn.Linear(hidden_size, hidden_size) self.W_o = nn.Linear(input_size, hidden_size) self.U_o = nn.Linear(hidden_size, hidden_size) def forward(self, x, h, c): i_t = torch.sigmoid(self.W_i(x) + self.U_i(h)) f_t = torch.sigmoid(self.W_f(x) + self.U_f(h)) c_tilda = torch.tanh(self.W_c(x) + self.U_c(h)) c_next = f_t * c + i_t * c_tilda o_t = torch.sigmoid(self.W_o(x) + self.U_o(h)) h_next = o_t * torch.tanh(c_next) return h_next, c_next
class LSTM3D(nn.Module): def __init__(self, input_size, hidden_size, num_layers): super(LSTM3D, self).__init__() self.hidden_size = hidden_size self.num_layers = num_layers self.layers = nn.ModuleList() self.layers.append(LSTMCell3D(input_size, hidden_size)) for _ in range(num_layers - 1): self.layers.append(LSTMCell3D(hidden_size, hidden_size)) def forward(self, x): batch_size, seq_len, input_size = x.size() hidden_states = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(x.device) cell_states = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(x.device) for t in range(seq_len): for layer in range(self.num_layers): if layer == 0: h, c = self.layers[layer](x[:, t, :], hidden_states[layer], cell_states[layer]) else: h, c = self.layers[layer](h, hidden_states[layer], cell_states[layer]) hidden_states[layer] = h cell_states[layer] = c return h, c
input_size = 10 hidden_size = 20 num_layers = 3 batch_size = 5 seq_len = 10
lstm = LSTM3D(input_size, hidden_size, num_layers) x = torch.randn(batch_size, seq_len, input_size) output, _ = lstm(x) print(output.size())
|