一、唠唠近况

开学我就大三了,时间飞逝呀!两个月的暑假接近尾声,还算是充实。近况,一方面暑假准备的论文已经交由老师修改了,最后应该是由我检查全文语法错误而结束,最终投出去。另一方面,又又又在准备下一篇论文,初步想法是在原网络中融入LSTM。经过两天的网上冲浪,CSDN看了,视频也看了,最后发现还ChristopherOlah大佬的一篇Understanding LSTM Networks — colah’s blog最经典了,因此做部分记录,大部分都是翻译的啦。

二、了解RNN网络

在传统RNN中,模块将具有非常简单的结构,例如单个tanh层。

LSTM是针对传统RNN存在”长期依赖”问题而设计的网络架构,这并不是说RNN你能处理”长期依赖”。事实上通过认为的选定参数,RNN绝对有能力处理这种“长期依赖”,但这是一件吃力不讨好的事情。LSTM通过三个”门”的设计,连接长短期记忆,从而帮助RNN减轻了处理”长期依赖”问题的压力。下面是ChristopherOlah对RNN处理长、短期问题的综述:

有时,我们只需要查看最近的信息即可执行当前任务。例如,考虑一个语言模型,试图根据前一个单词预测下一个单词。如果我们试图预测“云在天空中”中的最后一个词,我们不需要任何进一步的上下文 - 很明显下一个词将是天空。在这种情况下,相关信息与需要的地方之间的距离很小,RNN可以学习使用过去的信息。

但在某些情况下,我们需要更多的背景信息。考虑尝试预测文本中的最后一个词“我在法国长大……我能说一口流利的法语”。最近的信息表明,下一个单词可能是一种语言的名称,但如果我们想缩小哪种语言的范围,我们需要更远的法国语境。相关信息与需要信息的点之间的距离完全有可能变得非常大。

不幸的是,随着差距的扩大,RNN变得无法学习连接信息。

Neural networks struggle with long term dependencies.

三、了解LSTM网络

相较于传统RNN网络,LSTM也有这种链状结构,但重复模块具有不同的结构。不是只有一个神经网络层,而是有四个交互层,以一种非常特殊的方式相互作用。

A LSTM neural network.

在上图中,每条线承载着一个完整的向量,从一个节点的输出到其他节点的输入。粉红色圆圈表示逐点运算,如向量加法,而黄色框是学习神经网络层。行合并表示串联,而行分叉表示其内容被复制并且副本将转到不同的位置。

四、LSTM的核心思想

LSTM的核心思想,简单理解就是在原网络增加一条主线。观察LSTM和RNN的架构,不难发现LSTM上方添加了一条很长的主线,该主线的信息决定了输出门的输出。我们称这条主线细胞状态,它只有一些轻微的线性作用,信息很容易原封不动地沿着它流动。然后,LSTM能够删除或添加信息到细胞状态,由称为的结构仔细调节。是一种选择性地让信息通过的方式。由 sigmoid 神经网络层和逐点乘法运算组成。

五、LSTM的三个门

5.1 遗忘门

LSTM 的第一步是决定要从细胞状态中忘记哪些信息,由sigmoid层做出判断。通过查看$[h{(t-1)},x_t]$信息来输出一个0-1的向量,该向量里面的0-1值表示细胞状态$C{t-1}$中哪些信息将被遗忘或者保留。下图是遗忘门的图示:

5.2 输入门

下一步是决定将在细胞状态中存储哪些新信息。这分为两部分。首先,输入门的sigmoid层决定我们将更新哪些值。接下来,tanh 层创建一个新候选值的向量,可以添加到状态中。在下一步中,我们将结合这两者来创建状态更新。

将旧状态乘以$f_t$,忘记了之前决定忘记的事情。然后我们添加$i_t*C{_t}$.这是新的候选值,按决定更新每个状态值的程度进行缩放。

5.3 输出门

LSTM的输出是基于细胞状态的,但是过滤版本。首先,我们运行一个 sigmoid 层,它决定我们要输出细胞状态的哪些部分。然后,我们将单元格状态通过tanh(将值规范到-1到1),并将其乘以 Sigmoid 门的输出,这样就只输出了我们决定的部分。

六、Pytorch实战

在Pytorch中当然有直接调包的nn.LSTM可供使用,但是为了更好的理解LSTM,所以最好是自己实现一个:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import torch.nn as nn

class LSTMCell3D(nn.Module):
def __init__(self, input_size, hidden_size):
super(LSTMCell3D, self).__init__()
self.hidden_size = hidden_size
self.input_size = input_size

self.W_i = nn.Linear(input_size, hidden_size)
self.U_i = nn.Linear(hidden_size, hidden_size)
self.W_f = nn.Linear(input_size, hidden_size)
self.U_f = nn.Linear(hidden_size, hidden_size)
self.W_c = nn.Linear(input_size, hidden_size)
self.U_c = nn.Linear(hidden_size, hidden_size)
self.W_o = nn.Linear(input_size, hidden_size)
self.U_o = nn.Linear(hidden_size, hidden_size)

def forward(self, x, h, c):
i_t = torch.sigmoid(self.W_i(x) + self.U_i(h))
f_t = torch.sigmoid(self.W_f(x) + self.U_f(h))
c_tilda = torch.tanh(self.W_c(x) + self.U_c(h))
c_next = f_t * c + i_t * c_tilda
o_t = torch.sigmoid(self.W_o(x) + self.U_o(h))
h_next = o_t * torch.tanh(c_next)
return h_next, c_next

class LSTM3D(nn.Module):
def __init__(self, input_size, hidden_size, num_layers):
super(LSTM3D, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers

self.layers = nn.ModuleList()
self.layers.append(LSTMCell3D(input_size, hidden_size))
for _ in range(num_layers - 1):
self.layers.append(LSTMCell3D(hidden_size, hidden_size))

def forward(self, x):
batch_size, seq_len, input_size = x.size()
hidden_states = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(x.device)
cell_states = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(x.device)

for t in range(seq_len):
for layer in range(self.num_layers):
if layer == 0:
h, c = self.layers[layer](x[:, t, :], hidden_states[layer], cell_states[layer])
else:
h, c = self.layers[layer](h, hidden_states[layer], cell_states[layer])

hidden_states[layer] = h
cell_states[layer] = c

return h, c

# 使用示例
input_size = 10
hidden_size = 20
num_layers = 3
batch_size = 5
seq_len = 10

lstm = LSTM3D(input_size, hidden_size, num_layers)
x = torch.randn(batch_size, seq_len, input_size)
output, _ = lstm(x)
print(output.size())