Skip to content

Latest commit

 

History

History
289 lines (177 loc) · 22.1 KB

File metadata and controls

289 lines (177 loc) · 22.1 KB

第二节 LSTM 与 GRU

在了解了 RNN 的基本架构及其固有的缺陷后,本节将探讨两种经典的 RNN 改进方案——长短期记忆网络 (LSTM) 与门控循环单元 (GRU),并剖析它们是如何通过精巧的结构设计来克服长距离依赖这一挑战的。

一、LSTM 与门控机制

常规 RNN 的问题是它内部状态的更新方式是“粗暴”的。每一步的新信息都会与旧信息(隐藏状态)无差别地混合,并通过权重矩阵 $W$ 进行变换。这种强制性的矩阵乘法,无论信息重要与否,都会在反向传播中形成梯度累乘,导致梯度信号的衰减或爆炸。LSTM 的设计哲学是赋予网络自行决定信息取舍的能力。它不再强制性地混合所有信息,而是引入了 “门控机制”(Gating Mechanism),让模型在训练过程中学会有选择地让信息通过、遗忘旧信息或输出信息 1

1.1 从单一状态到双轨并行

与 RNN 只有一个隐藏状态 $h_t$ 在时间步之间传递不同,LSTM 引入了两个独立的状态向量在时间轴上并行传递:

(1)细胞状态(Cell State, $c_t$: 这是 LSTM 的核心,原始论文中称之为 “恒定误差旋转木马”(Constant Error Carousel, CEC)。可以把它想象成一条“信息高速公路”或“传送带”,负责在整个序列中传递长期记忆。信息可以直接在这条传送带上流动,仅经过按元素的加权与相加,没有额外的矩阵连乘,从而极大地缓解了梯度消失问题。

(2)隐藏状态(Hidden State, $h_t$: 与 RNN 中的隐藏状态类似,代表了当前时间步的短期记忆最终输出。$h_t$ 的计算依赖于当前的细胞状态 $c_t$

LSTM 通过门控单元,来控制细胞状态 $c_t$ 这条“高速公路”在每个时间点应该遗忘什么旧内容,以及应该写入什么新内容。

1.2 门的结构

LSTM 中的“门”是一种让信息选择性通过的结构,设计灵感来源于数字电路中的逻辑门。它的实现非常简单,就是一个以 Sigmoid 为激活函数的全连接层。这个全连接层的输入通常是当前时间步的输入 $x_t$ 和上一个时间步的隐藏状态 $h_{t-1}$ 的拼接向量,经线性变换后通过 Sigmoid 函数,最终输出一个元素值在 (0, 1) 区间内的向量。这个向量会与另一个向量进行按元素乘法,当门输出向量的某个元素接近 1 时意味着“允许”对应维度的信息完全通过,而当其接近 0 时则意味着“阻止”对应维度的信息通过,即“遗忘”或“忽略”它。

从另一个角度看,门控机制也是为了解决权重冲突问题。其中输入门保护细胞状态不受无关输入的干扰,输出门则保护其他单元不受当前细胞状态中无关记忆的干扰。LSTM 内部署了三个这样的门,来精确控制信息的流动。

二、LSTM 单元结构

一个 LSTM 单元在 $t$ 时刻接收三个输入,当前输入 $x_t$、前一时刻的隐藏状态 $h_{t-1}$ 和前一时刻的细胞状态 $c_{t-1}$。然后,如图 3-2 所示它通过内部的三个门和一个 tanh 层,计算出新的细胞状态 $c_t$ 和隐藏状态 $h_t$

LSTM 单元结构

图 3-2 LSTM 单元结构

为了方便后续公式的表述,首先将当前输入 $x_t$ 和前一时刻的隐藏状态 $h_{t-1}$ 拼接起来,记为 $[h_{t-1}, x_t]$。LSTM 内部的每一次线性变换,都是作用在这个拼接后的向量上,只是各自使用不同的权重矩阵。

这种拼接操作是一种常见的计算优化。将两个向量拼接后再进行一次矩阵乘法,与对两个向量分别进行矩阵乘法然后相加,其结果是等价的。例如, $W \cdot [h_{t-1}, x_t]$ 等价于 $W_h \cdot h_{t-1} + W_x \cdot x_t$,其中 $W$ 被相应地拆分为 $W_h$$W_x$ 两部分。这样做可以利用深度学习框架中优化过的大矩阵乘法操作,提升计算效率。

在深入公式之前,可以将上图中的计算模块与公式对应起来,以便理解信息流,图中各符号的含义如下:

  • 每一个 σ (Sigmoid) 符号都对应一个的计算,即遗忘门、输入门和输出门。
  • tanh 符号有两个,分别负责生成候选记忆($\tilde{c}_t$),以及对细胞状态 $c_t$ 进行处理,将其值压缩到 [-1, 1] 区间以计算最终的隐藏状态 $h_t$
  • (圆圈)符号代表按元素乘法,这是门控机制发挥作用的关键。
  • +(圆圈)符号代表按元素加法,用于更新细胞状态。

2.1 第一步:遗忘门(Forget Gate)

LSTM 的第一步是决定我们从细胞状态中丢弃什么信息。这个决定由被称为“遗忘门”的 Sigmoid 层完成。它会审视 $h_{t-1}$$x_t$,然后为 $c_{t-1}$ 中的每个数值输出一个介于 0 和 1 之间的“遗忘系数”。当该系数接近 1 时表示“几乎完全保留”,当其接近 0 时表示“几乎完全丢弃”。

$$ f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) \tag{3.7} $$

1997 年提出的原始 LSTM 结构中并没有遗忘门,这一机制是由 Gers 等人在 2000 年引入的 2

2.2 第二步:输入门与候选记忆(Input Gate & Candidate Memory)

下一步是决定我们在细胞状态中存储什么新信息。这由两部分共同完成:

(1)输入门:首先,一个 Sigmoid 层(即图中的“输入门”)决定了更新哪些值。

$$ i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \tag{3.8} $$

(2)候选记忆:然后,一个 tanh 层(即图中的“候选值”模块)创建一个新的候选记忆向量 $\tilde{c}_t$,这是准备添加到细胞状态中的新内容。这部分的计算与常规 RNN 的计算非常相似。

$$ \tilde{c}_t = \tanh(W_c \cdot [h_{t-1}, x_t] + b_c) \tag{3.9} $$

2.3 第三步:更新细胞状态 (Cell State Update)

现在,可以更新细胞状态了,我们需把旧状态 $c_{t-1}$ 更新为新状态 $c_t$,此步骤对应于图中细胞状态传送带上的核心操作。首先,将旧状态 $c_{t-1}$ 与遗忘门 $f_t$ 的输出进行逐元素相乘,丢弃掉决定要忘记的部分;然后,将输入门 $i_t$ 与候选记忆 $\tilde{c}_t$ 逐元素相乘,筛选出需要加入的新信息;最后,将这两部分相加,得到新的细胞状态 $c_t$

$$ c_t = (f_t \odot c_{t-1}) + (i_t \odot \tilde{c}_t) \tag{3.10} $$

2.4 第四步:输出门(Output Gate)

最后,我们需要决定输出什么。输出将基于我们的细胞状态,但会是一个过滤后的版本。输出的生成分为两步:

(1)输出门:一个 Sigmoid 层(“输出门”)决定了细胞状态的哪些部分将被输出。

$$ o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \tag{3.11} $$

(2)计算隐藏状态:将刚刚更新的细胞状态 $c_t$ 通过 tanh 函数(将其值规范化到 -1 和 1 之间),并将其与输出门 $o_t$ 的结果逐元素相乘,最终只输出我们决定的那部分信息作为隐藏状态 $h_t$

$$ h_t = o_t \odot \tanh(c_t) \tag{3.12} $$

这个 $h_t$ 将作为当前时间步的最终输出,并传递给下一个时间步。

三、LSTM 如何缓解长距离依赖

现在,回到最初的问题。LSTM 是如何通过这套复杂的门控机制来缓解梯度消失问题的?关键在于细胞状态 $c_t$ 的更新法则。在 RNN 中,梯度在时间步之间反向传播时,每一步都必须乘以同一个权重矩阵 $W$。当序列很长时,这种连乘操作($W \cdot W \cdot W \dots$)极易导致梯度消失或爆炸。而在 LSTM 中,梯度的传播路径被分成了两条。一条是通过隐藏状态 $h_t$ 传递的,与 RNN 类似,依然会经过激活函数的导数和权重矩阵,存在梯度衰减的风险。但另一条,也是更重要的一条,是通过细胞状态 $c_t$ 传递的。我们来考察损失 $L$ 对前一时刻细胞状态 $c_{t-1}$ 的梯度,根据链式法则:

$$ \frac{\partial L}{\partial c_{t-1}} = \frac{\partial L}{\partial c_t} \frac{\partial c_t}{\partial c_{t-1}} \tag{3.13} $$

其中,细胞状态的更新公式为 $c_t = (f_t \odot c_{t-1}) + (i_t \odot \tilde{c}t)$。可以看到,$c_t$ 对 $c{t-1}$ 的偏导数直接就是遗忘门 $f_t$(另一项与 $c_{t-1}$ 无关,导数为0)。因此:

$$ \frac{\partial L}{\partial c_{t-1}} = \frac{\partial L}{\partial c_t} \odot f_t \tag{3.14} $$

这个关系非常关键。它表明,从 $t$ 时刻的细胞状态反向传播到 $t-1$ 时刻,梯度仅仅是按元素乘以了遗忘门 $f_t$ 的值,而没有经过权重矩阵的乘法。如果将这个链条一直追溯到更早的 $k$ 时刻,梯度就变成了:

$$ \frac{\partial L}{\partial c_k} = \frac{\partial L}{\partial c_t} \odot (f_t \odot f_{t-1} \odot \dots \odot f_{k+1}) \tag{3.15} $$

由此,梯度的“高速公路”得以建立,从序列末端到开端的梯度传递主要取决于一系列遗忘门 $f_t$ 的连乘。由于 $f_t$ 是一个独立的门控单元,它的值是在每次计算中动态生成的。如果模型在训练中发现某个早期信息非常重要,它可以通过学习将中间所有时间步的遗忘门 $f_t$ 的值都设置为接近 1。在这种情况下,梯度就可以几乎无衰减地从序列末端传播到序列开端。这本质上构建了一种可学习的依赖关系,与常规 RNN 的结构性缺陷不同,LSTM 将长距离依赖问题转化成了一个可学习的问题。模型能够通过优化损失函数,自行调整门控单元的参数,来决定哪些信息需要长期记忆(通过将 $f_t$ 设置为 $\approx 1$ 来“记住”),哪些信息可以被舍弃(通过将 $f_t$ 设置为 $\approx 0$ 来“遗忘”)。

所以,我们不说 LSTM 解决了梯度消失问题,而是极大地缓解了它。因为即使 $f_t$ 的值很接近1(例如0.99),在足够长的序列上连乘后,梯度依然会衰减。但相比 RNN 几十步就会出现问题的窘境,LSTM 已经实现了巨大的飞跃。

四、从零实现一个 LSTM

本节完整代码

为了更深刻地理解 LSTM 内部复杂的信息流动,可以像实现 RNN 一样,基于公式,用 NumPy 手写一个 LSTM 的前向传播过程。这里我们同样实现一个不含偏置项的简化版 LSTM,计算公式如下:

  • 遗忘门: $f_t = \sigma(U_f x_t + W_f h_{t-1})$
  • 输入门: $i_t = \sigma(U_i x_t + W_i h_{t-1})$
  • 候选记忆: $\tilde{c}t = \tanh(U_c x_t + W_c h{t-1})$
  • 细胞状态更新: $c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$
  • 输出门: $o_t = \sigma(U_o x_t + W_o h_{t-1})$
  • 隐藏状态更新: $h_t = o_t \odot \tanh(c_t)$

4.1 基于 NumPy 实现 LSTM

为了将上述公式转化为可执行的代码,我们可以遵循以下步骤:

(1)初始化: 我们第一步需要初始化两个零向量 h_prevc_prev,分别作为处理序列开始前的“短期记忆”和“长期记忆”。

(2)逐帧处理: 接着,使用 for 循环遍历序列中的每一个时间步,对每个时间步的输入进行处理。

(3)核心计算: 在循环内部,严格遵循 LSTM 的四个步骤进行计算。我们首先要计算遗忘门 f_t,决定从旧的细胞状态 c_prev 中忘记多少信息;然后计算输入门 i_t候选记忆 c_tilde_t,准备要写入的新信息;随后通过公式 c_t = f_t * c_prev + i_t * c_tilde_t,结合遗忘和记忆操作,得到新的细胞状态 c_t;最后计算输出门 o_t,并结合 tanh(c_t) 生成新的隐藏状态 h_t

(4)状态更新: 在每一步计算结束后,执行 h_prev, c_prev = h_t, c_t,将当前计算出的状态传递给下一个时间步,完成“循环”过程。

具体的代码实现如下:

def manual_lstm_numpy(x_np, weights):
    U_f, W_f, U_i, W_i, U_c, W_c, U_o, W_o = weights
    B_local, T_local, _ = x_np.shape
    h_prev = np.zeros((B_local, H), dtype=np.float32)
    c_prev = np.zeros((B_local, H), dtype=np.float32)
    
    steps = []
    # 按时间步循环
    for t in range(T_local):
        x_t = x_np[:, t, :]
        
        # 1. 遗忘门
        f_t = sigmoid(x_t @ U_f + h_prev @ W_f)
        
        # 2. 输入门与候选记忆
        i_t = sigmoid(x_t @ U_i + h_prev @ W_i)
        c_tilde_t = np.tanh(x_t @ U_c + h_prev @ W_c)
        
        # 3. 更新细胞状态
        c_t = f_t * c_prev + i_t * c_tilde_t
        
        # 4. 输出门与隐藏状态
        o_t = sigmoid(x_t @ U_o + h_prev @ W_o)
        h_t = o_t * np.tanh(c_t)
        
        steps.append(h_t)
        h_prev, c_prev = h_t, c_t
        
    outputs = np.stack(steps, axis=1)
    return outputs, h_prev, c_prev

通过这个实现,我们可以直观地看到 LSTM 是如何通过门控机制,在每个时间步对信息流进行控制的。

4.2 LSTM 实现层面的优化

在实际的深度学习框架(如 PyTorch、TensorFlow)中,为了提高计算效率,LSTM 的实现通常会进行一项优化。回顾 LSTM 的计算过程,遗忘门、输入门、候选记忆和输出门都对拼接向量 $[h_{t-1}, x_t]$ 进行了独立的线性变换:

  • $f_t$ : $W_f \cdot [h_{t-1}, x_t] + b_f$
  • $i_t$ : $W_i \cdot [h_{t-1}, x_t] + b_i$
  • $\tilde{c}t$ : $W_c \cdot [h{t-1}, x_t] + b_c$
  • $o_t$ : $W_o \cdot [h_{t-1}, x_t] + b_o$

这相当于对同一个输入进行了四次独立的线性层计算。为了优化,框架会将这四个权重矩阵和偏置项在内部进行拼接,将与输入 $x_t$ 相关的权重 $W_{fx}, W_{ix}, W_{cx}, W_{ox}$ 拼接成一个大的权重矩阵 $W_{x}$,同时将与隐藏状态 $h_{t-1}$ 相关的权重 $W_{fh}, W_{ih}, W_{ch}, W_{oh}$ 拼接成一个大的权重矩阵 $W_{h}$。这样,四次独立的矩阵乘法就可以被合并成两次更大规模的矩阵乘法,然后再将结果切分成四份,分别送入各自的激活函数。这种方式能更好地利用 GPU 的并行计算能力,提升运算速度。

五、门控循环单元(GRU)

GRU 由 Cho 等人在 2014 年提出 3。它最初是在 RNN Encoder-Decoder 框架下被提出的,是为了解决统计机器翻译中的短语表示问题。实验表明,这种结构能够很好地捕捉短语的语义和句法结构,并且相比 LSTM 更易于训练。

5.1 GRU 的主要改进

GRU 对 LSTM 做了两大核心改动,一个是合并细胞状态与隐藏状态,也就是不再区分细胞状态 $c_t$ 和隐藏状态 $h_t$,只有一个同时包含长期记忆并作为输出的状态向量 $h_t$ 在时间步之间传递,与 RNN 的结构类似。另一个是简化门控结构,将 LSTM 的三个门简化为了两个门。其中**更新门(Update Gate, $z_t$)**的作用类似于 LSTM 中耦合的遗忘门和输入门,同时决定了保留多少旧信息以及接收多少新信息;**重置门(Reset Gate, $r_t$)**则决定在计算候选状态时忽略多少旧信息。这种设计让每个隐藏单元能够自适应地捕捉不同时间尺度的依赖关系,从而更灵活地处理长短句法结构。

5.2 GRU 单元结构与公式

一个 GRU 单元如图 3-3 在 $t$ 时刻接收两个输入,当前输入 $x_t$ 和前一时刻的隐藏状态 $h_{t-1}$。然后,它通过内部的两个门,计算出新的隐藏状态 $h_t$

GRU 单元结构

图 3-3 GRU 单元结构

其计算过程可以分解为以下四步:

(1)重置门($r_t$)

决定如何将新输入 $x_t$ 与之前的记忆 $h_{t-1}$ 结合,这个门控制着哪些旧信息可以被用来计算“候选记忆”。计算公式如下:

$$ r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r) \tag{3.16} $$

(2)更新门($z_t$)

控制前一时刻的状态信息 $h_{t-1}$ 有多少能够被直接带入到当前时刻,这与 LSTM 的遗忘门功能相似。计算公式如下:

$$ z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z) \tag{3.17} $$

(3)候选记忆($\tilde{h}_t$)

计算当前时间步的候选隐藏状态。这个计算受到了重置门 $r_t$ 的影响。$r_t$ 与 $h_{t-1}$ 逐元素相乘,如果 $r_t$ 的某个元素接近 0,则表示在计算候选记忆时,将完全忽略掉 $h_{t-1}$ 对应维度的信息。计算公式如下:

$$ \tilde{h}_t = \tanh(W_h \cdot [r_t \odot h_{t-1}, x_t] + b_h) \tag{3.18} $$

(4)最终隐藏状态($h_t$)

结合更新门 $z_t$、旧状态 $h_{t-1}$ 和候选记忆 $\tilde{h}t$,生成当前时间步的最终输出 $h_t$。其中 $z_t \odot h{t-1}$ 表示对旧状态 $h_{t-1}$ 中需要保留的信息, $(1-z_t) \odot \tilde{h}_t$ 表示从候选记忆 $\tilde{h}t$ 中需要选择的新信息。这个更新机制非常巧妙,更新门 $z_t$ 的值在 0 到 1 之间,可以看作是一个“开关”。当 $z_t$ 接近 1 时,模型倾向于保留更多的旧信息 $h{t-1}$;当 $z_t$ 接近 0 时,模型则倾向于保留更多的新信息。计算公式如下:

$$ h_t = z_t \odot h_{t-1} + (1 - z_t) \odot \tilde{h}_t \tag{3.19} $$

六、LSTM 的常见变体

标准的 LSTM 结构本身已经非常强大,但在其发展过程中,研究者们也提出了一些有趣的变体,用于简化计算或增强性能。其中两种著名的变体是“窥孔连接”和“耦合的输入/遗忘门”。

6.1 窥孔连接(Peephole Connections)

在标准 LSTM 中,三个门(遗忘、输入、输出)的决策依据仅来自当前输入 $x_t$ 和前一时刻的隐藏状态 $h_{t-1}$。但正如 Gers 和 Schmidhuber 在**《Recurrent Nets that Time and Count》一文中指出的,这种机制有一个潜在的弱点,门控单元无法直接“看到”它们控制的细胞状态 4。特别是当输出门关闭时,隐藏状态 $h_{t-1}$ 接近于 0,此时门控单元失去了关于细胞内部状态的重要信息。为了解决这个问题,研究者提出了窥孔连接**,允许门控单元直接访问细胞状态:遗忘门和输入门在做决策时会“窥视” 前一时刻的细胞状态 $c_{t-1}$,而输出门在做决策时则会“窥视” 当前刚刚更新的细胞状态 $c_t$

实验表明,带有窥孔连接的 LSTM 在处理需要精确计时和计数的任务(如学习生成具有特定时间间隔的脉冲序列)时,性能显著优于标准 LSTM。公式上的变化体现是在计算每个门时,额外加入一个与细胞状态相关的项:

  • $f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + V_f \odot c_{t-1} + b_f)$
  • $i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + V_i \odot c_{t-1} + b_i)$
  • $o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + V_o \odot c_t + b_o)$

其中 $V_f, V_i, V_o$ 是新增的对角权重矩阵,代表窥孔连接的权重。由于细胞状态的维度通常与隐藏状态一致,所以这里的运算通常是按元素相乘,这意味着每个门的激活只受到其对应细胞单元状态的影响,保持了计算的局部性。

6.2 耦合的输入与遗忘门(Coupled Input and Forget Gate)

该变体的思想是遗忘旧信息和写入新信息是紧密耦合的两个决策。应该遗忘多少旧信息,恰恰是因为准备写入等量的新信息。基于此,它将输入门和遗忘门合并为一个决策。不再单独计算输入门 $i_t$,而是直接令 $i_t = 1 - f_t$。当遗忘门 $f_t$ 的某个元素为 0.8(保留 80% 的旧信息)时,对应的输入门元素就必须是 0.2(只允许 20% 的新信息进入)。细胞状态的更新公式因此变得更加简洁:

$$ c_t = (f_t \odot c_{t-1}) + ((1 - f_t) \odot \tilde{c}_t) \tag{3.20} $$

这种方式不仅使得模型逻辑更直观,还减少了模型的参数量。Greff 等人在大规模实验研究中验证了这一点,发现 CIFG 可以在不降低模型性能的前提下有效减少计算开销 5。除了验证 CIFG 的有效性外,该研究还对 LSTM 的架构进行了详尽的探索,得出了一些对工程实践极具价值的结论。例如,核心组件中的遗忘门输出激活函数是 LSTM 中最关键的组件,移除它们会显著降低性能;各超参数独立性较高,这意味着可以单独调整学习率等参数,而无需进行复杂的组合搜索;在在线随机梯度下降训练中动量作用有限,无论是好是坏,动量对 LSTM 的性能影响都微乎其微。

练习

要做哦,别偷懒 😇

  • 在前面的练习中,我们构建了一个基于全连接层的文本分类模型。现在,尝试将其改造为使用 LSTM 网络结构,以更好地捕捉文本中的序列信息。(可以参考基于 LSTM 的文本分类

参考文献

Footnotes

  1. Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory. Neural Computation, 9(8), 1735-1780.

  2. Gers, F. A., Schmidhuber, J., & Cummins, F. (2000). Learning to forget: Continual prediction with LSTM. Neural Computation, 12(10), 2451-2471.

  3. Cho, K., Van Merriënboer, B., Gulcehre, C., Bahdanau, D., Bougares, F., Schwenk, H., & Bengio, Y. (2014). Learning phrase representations using RNN encoder-decoder for statistical machine translation. Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing (EMNLP), 1724-1734.

  4. Gers, F. A., Schmidhuber, J. (2000). Recurrent nets that time and count. Proceedings of the IEEE-INNS-ENNS International Joint Conference on Neural Networks (IJCNN), 3, 189-194.

  5. Greff, K., Srivastava, R. K., Koutník, J., Steunebrink, B. R., & Schmidhuber, J. (2017). LSTM: A search space odyssey. IEEE Transactions on Neural Networks and Learning Systems, 28(10), 2222-2232.