A vanilla RNN suffers from short-term memory: when the input sequence is long, it runs into vanishing and exploding gradients during backpropagation. Many improvements to the RNN have been proposed to address this; the most effective is the introduction of gating mechanisms, which gives us LSTM and GRU.

Both LSTM and GRU are kinds of RNN (Recurrent Neural Network) and are improvements over the vanilla RNN.

Basic RNN

The basic structure of a vanilla RNN is shown in the figure below:

The computation cell at time step $t$ is illustrated as follows:

(figure: basic RNN cell at time step $t$)

In summary, the computation steps for a basic RNN cell are:

1. Use $a^{\langle t-1 \rangle}$ and $x^{\langle t \rangle}$ to compute the new hidden state:

$ a^{\langle t \rangle} = \tanh(W_{aa} a^{\langle t-1 \rangle} + W_{ax} x^{\langle t \rangle} + b_a) $

2. Use the new hidden state $a^{\langle t \rangle}$ to make a prediction (a code sketch of both steps follows):

$\hat{y}^{\langle t \rangle} = softmax(W_{ya} a^{\langle t \rangle} + b_y)$
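
A minimal NumPy sketch of these two steps (the parameter names `Waa`, `Wax`, `Wya`, `ba`, `by` and the shapes are illustrative assumptions):

```python
import numpy as np

def softmax(z):
    # numerically stable softmax over the first axis
    e = np.exp(z - z.max(axis=0, keepdims=True))
    return e / e.sum(axis=0, keepdims=True)

def rnn_cell_forward(x_t, a_prev, Waa, Wax, Wya, ba, by):
    """One basic RNN step: new hidden state and prediction.

    x_t:    input at time t,       shape (n_x, m)
    a_prev: previous hidden state, shape (n_a, m)
    """
    a_t = np.tanh(Waa @ a_prev + Wax @ x_t + ba)   # step 1: new hidden state
    y_hat_t = softmax(Wya @ a_t + by)              # step 2: prediction
    return a_t, y_hat_t
```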

So, for an input sequence $x$ of length $T_x$ and its corresponding output sequence $y$, the unrolled RNN looks like this:

Input sequence: $x = (x^{\langle 1 \rangle}, x^{\langle 2 \rangle}, \dots, x^{\langle T_x \rangle})$

Output sequence: $y = (y^{\langle 1 \rangle}, y^{\langle 2 \rangle}, \dots, y^{\langle T_x \rangle})$

(figure: RNN unrolled over $T_x$ time steps)
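
As a rough sketch of the unrolled computation (NumPy; shapes and names are assumptions), the same cell update is applied at every time step while the hidden state is carried forward:

```python
import numpy as np

def rnn_forward(x, a0, Waa, Wax, ba):
    """Run a basic RNN over a whole sequence.

    x:  inputs, shape (n_x, m, T_x); a0: initial hidden state, shape (n_a, m)
    Returns all hidden states, shape (n_a, m, T_x).
    """
    n_a, m = a0.shape
    T_x = x.shape[2]
    a = np.zeros((n_a, m, T_x))
    a_t = a0
    for t in range(T_x):
        # the same parameters are reused at every time step
        a_t = np.tanh(Waa @ a_t + Wax @ x[:, :, t] + ba)
        a[:, :, t] = a_t
    return a
```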

However, the vanilla RNN has a serious limitation: because of vanishing or exploding gradients, it struggles to model dependencies between states that are far apart in time (long-range dependencies). Computing parameter gradients in an RNN also differs from a feed-forward network; the two main approaches are backpropagation through time (BPTT) and real-time recurrent learning (RTRL). RNNs are generally hard to train, and the difficulty comes not from the activation function but from these vanishing and exploding gradients during learning.

Gradients in an RNN are usually computed and stored via backpropagation through time (BPTT), which is simply backpropagation applied to the recurrent network. When the number of time steps is large, or for time steps early in the sequence, the gradients tend to vanish or explode. Gradient clipping can cope with exploding gradients, but it cannot fix vanishing gradients, so a vanilla RNN has trouble capturing dependencies between time steps that are far apart.
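
The gradient clipping mentioned above can be sketched as clip-by-global-norm (NumPy; the threshold `max_norm` is an arbitrary illustrative choice):

```python
import numpy as np

def clip_by_global_norm(grads, max_norm=5.0):
    """Rescale a list of gradient arrays so their global L2 norm is <= max_norm."""
    global_norm = np.sqrt(sum((g ** 2).sum() for g in grads))
    if global_norm > max_norm:
        scale = max_norm / (global_norm + 1e-12)
        grads = [g * scale for g in grads]
    return grads
```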

To address this, gated recurrent neural networks were introduced precisely to better capture dependencies across large time-step distances; they control the flow of information through learnable gates. The gated recurrent unit (GRU) is one commonly used gated RNN; another is long short-term memory (LSTM).

LSTM

An LSTM introduces three gates, the input gate, the forget gate, and the output gate, plus a memory cell with the same shape as the hidden state (some references treat the memory cell as a special kind of hidden state), which lets it record additional information.

The basic structure of an LSTM (the update gate in the figure is the same as the input gate):

The main equations corresponding to the figure above:


$$\Gamma_f^{\langle t \rangle} = \sigma(W_f[a^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_f)\tag{1} $$

$$\Gamma_u^{\langle t \rangle} = \sigma(W_u[a^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_u)\tag{2} $$

$$ \tilde{c}^{\langle t \rangle} = \tanh(W_c[a^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_c)\tag{3} $$

$$ c^{\langle t \rangle} = \Gamma_f^{\langle t \rangle}* c^{\langle t-1 \rangle} + \Gamma_u^{\langle t \rangle} *\tilde{c}^{\langle t \rangle} \tag{4} $$

$$ \Gamma_o^{\langle t \rangle}= \sigma(W_o[a^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_o)\tag{5}$$

$$ a^{\langle t \rangle} = \Gamma_o^{\langle t \rangle}* \tanh(c^{\langle t \rangle})\tag{6} $$


Forget gate

For the sake of this illustration, let's assume we are reading words in a piece of text and want to use an LSTM to keep track of grammatical structures, such as whether the subject is singular or plural. If the subject changes from a singular word to a plural word, we need a way to get rid of our previously stored memory value of the singular/plural state. In an LSTM, the forget gate lets us do this:

$$\Gamma_f^{\langle t \rangle} = \sigma(W_f[a^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_f)\tag{1} $$

Here, $W_f$ are weights that govern the forget gate's behavior. We concatenate $[a^{\langle t-1 \rangle}, x^{\langle t \rangle}]$ and multiply by $W_f$. The equation above produces a vector $\Gamma_f^{\langle t \rangle}$ with values between 0 and 1. This forget gate vector will be multiplied element-wise by the previous cell state $c^{\langle t-1 \rangle}$. So if one of the values of $\Gamma_f^{\langle t \rangle}$ is 0 (or close to 0), the LSTM should remove that piece of information (e.g. the singular subject) in the corresponding component of $c^{\langle t-1 \rangle}$. If one of the values is 1, it keeps the information. (Note that $a$ is also often written as $h$, as in the figure below.)

Update gate

Once we forget that the subject being discussed is singular, we need a way to update it to reflect that the new subject is now plural. Here is the formula for the update gate:

$$\Gamma_u^{\langle t \rangle} = \sigma(W_u[a^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_u)\tag{2} $$

Similar to the forget gate, here $\Gamma_u^{\langle t \rangle}$ is again a vector of values between 0 and 1. This will be multiplied element-wise with $\tilde{c}^{\langle t \rangle}$, in order to compute $c^{\langle t \rangle}$.

Updating the cell state

To update the new subject we need to create a new vector of numbers that we can add to our previous cell state. The equation we use is:

$$ \tilde{c}^{\langle t \rangle} = \tanh(W_c[a^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_c)\tag{3} $$

Finally, the new cell state is:

$$ c^{\langle t \rangle} = \Gamma_f^{\langle t \rangle}* c^{\langle t-1 \rangle} + \Gamma_u^{\langle t \rangle} *\tilde{c}^{\langle t \rangle} \tag{4} $$

Output gate

To decide which outputs we will use, we apply the following two formulas:

$$ \Gamma_o^{\langle t \rangle}= \sigma(W_o[a^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_o)\tag{5}$$

$$ a^{\langle t \rangle} = \Gamma_o^{\langle t \rangle}* \tanh(c^{\langle t \rangle})\tag{6} $$

Where in equation 5 you decide what to output using a sigmoid function, and in equation 6 you multiply that by the $\tanh$ of the current cell state.

In summary, the computation steps for an LSTM cell are:

  1. First concatenate $a^{\langle t-1 \rangle}$ and $x^{\langle t \rangle}$: $concat = \begin{bmatrix} a^{\langle t-1 \rangle} \\ x^{\langle t \rangle} \end{bmatrix}$
  2. Apply equations 1-6 above.
  3. Use $a^{\langle t \rangle}$ to compute the prediction $\hat{y}^{\langle t \rangle} = softmax(W_{ya} a^{\langle t \rangle} + b_y)$ (see the sketch below).
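
A minimal NumPy sketch of one LSTM cell step, following equations (1)-(6) and the prediction above (the parameter names `Wf`, `Wu`, `Wc`, `Wo`, `Wy` and the shapes are illustrative assumptions):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def softmax(z):
    e = np.exp(z - z.max(axis=0, keepdims=True))
    return e / e.sum(axis=0, keepdims=True)

def lstm_cell_forward(x_t, a_prev, c_prev, Wf, bf, Wu, bu, Wc, bc, Wo, bo, Wy, by):
    """One LSTM step: gates, cell state update, hidden state, prediction."""
    concat = np.concatenate([a_prev, x_t], axis=0)   # step 1: stack a<t-1> on top of x<t>
    gamma_f = sigmoid(Wf @ concat + bf)              # (1) forget gate
    gamma_u = sigmoid(Wu @ concat + bu)              # (2) update (input) gate
    c_tilde = np.tanh(Wc @ concat + bc)              # (3) candidate cell state
    c_t = gamma_f * c_prev + gamma_u * c_tilde       # (4) new cell state
    gamma_o = sigmoid(Wo @ concat + bo)              # (5) output gate
    a_t = gamma_o * np.tanh(c_t)                     # (6) new hidden state
    y_hat_t = softmax(Wy @ a_t + by)                 # step 3: prediction
    return a_t, c_t, y_hat_t
```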

Some more intuitive LSTM illustrations:

The structure of a single LSTM cell:

Typically, $h^{\langle t \rangle}$ and $c^{\langle t \rangle}$ are passed, together with $x^{\langle t+1 \rangle}$, as the input to the next time step.

Multi-layer LSTM:
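
In practice, a stacked LSTM is usually built with a framework rather than by hand; a minimal PyTorch sketch (the sizes here are arbitrary examples):

```python
import torch
import torch.nn as nn

# a 2-layer LSTM: the hidden states of layer 1 become the inputs of layer 2
lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, batch_first=True)

x = torch.randn(4, 7, 10)       # (batch, T_x, input_size)
output, (h_n, c_n) = lstm(x)    # h_0 and c_0 default to zeros

print(output.shape)             # (4, 7, 20): top-layer hidden state at every time step
print(h_n.shape, c_n.shape)     # (2, 4, 20): final h and c for each layer
```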

GRU

A GRU merges the forget gate and the input gate into a single update gate (forget gate + input gate => update gate). It has only two gates (a reset gate and an update gate), so it has fewer parameters and training is more robust.

The GRU equations:

(figure: GRU equations)

The $a^{\langle t \rangle}$ and $c^{\langle t \rangle}$ in the equations above are the same thing as $h_{t}$ in the notation below; both denote the hidden state at time $t$.


The GRU takes the previous hidden state $h_{t-1}$ and the current input $x_{t}$ as input, and outputs the new hidden state $h_{t}$.
A GRU contains two gates, the reset gate and the update gate:

  • The reset gate $r_{t}$ is used when computing the candidate hidden state $\tilde{h}_{t}$; it controls how much of the previous hidden state $h_{t-1}$ is kept (i.e. whether the computation of the candidate state $\tilde{h}_{t}$ depends on the previous state $h_{t-1}$).
  • The update gate $z_{t}$ controls how much of the candidate hidden state $\tilde{h}_{t}$ is mixed in to produce the output $h_{t}$ (i.e. how much information the current state $h_{t}$ keeps from the previous state $h_{t-1}$ without a nonlinear transformation, and how much new information it accepts from the candidate state $\tilde{h}_{t}$); see the sketch after this list.
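
A minimal NumPy sketch of one GRU step in the $h_t$ notation above (the parameter names are assumptions; it follows the common convention in which $z_t$ weights the candidate state and $1 - z_t$ weights the previous state, though some texts swap the two factors):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_cell_forward(x_t, h_prev, Wz, bz, Wr, br, Wh, bh):
    """One GRU step: update gate, reset gate, candidate state, new hidden state."""
    concat = np.concatenate([h_prev, x_t], axis=0)
    z_t = sigmoid(Wz @ concat + bz)                  # update gate
    r_t = sigmoid(Wr @ concat + br)                  # reset gate
    # candidate state: the reset gate decides how much of h_{t-1} is used
    h_tilde = np.tanh(Wh @ np.concatenate([r_t * h_prev, x_t], axis=0) + bh)
    # blend the old state and the candidate (some texts swap z_t and 1 - z_t here)
    h_t = (1.0 - z_t) * h_prev + z_t * h_tilde
    return h_t
```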

REFERENCE

• Illustrated Guide to LSTM's and GRU's: A step by step explanation
• Understanding LSTM Networks
• RECURRENT NEURAL NETWORKS
• Animated RNN, LSTM and GRU
• RNN, LSTM and GRU tutorial
• 三次简化一张图:一招理解LSTM/GRU门控机制
• LSTMs for Time Series in PyTorch
• Simple LSTM - PyTorch version