LSTM 直觉问答:c 是加性高速公路、h 输出为何要过 tanh¶
- 章节:第25章 · 循环神经网络(25.2.1 长短期记忆网络)
- 用途:装我自己提出、用来辅助记忆的两个「为什么」。公式本身见 [[RNN全公式速查-SRNN-LSTM-GRU字母与函数对齐]],这条专管直觉。
- 出处:式25.23–25.25(p455);课本 p457「梯度传播不是通过矩阵连乘的线性组合,所以可以避免梯度消失和梯度爆炸」。tanh 压缩的「为什么」是直觉补充,非课本原话。
先把要用到的两行公式钉住: $\(\boldsymbol c_t=\boldsymbol f_t\odot\boldsymbol c_{t-1}+\boldsymbol i_t\odot\tilde{\boldsymbol c}_t\qquad\qquad \boldsymbol h_t=\boldsymbol o_t\odot\tanh(\boldsymbol c_t)\)$ $\(\tilde{\boldsymbol c}_t=\tanh(\boldsymbol U_c\boldsymbol h_{t-1}+\boldsymbol W_c\boldsymbol x_t+\boldsymbol b_c)\quad(\text{候选记忆,由 }\boldsymbol h_{t-1}\text{ 和 }\boldsymbol x_t\text{ 一起生成})\)$
我的问题①:h 和 h_{t-1} 的连接「没那么紧密」了,会不会有问题?¶
不是问题,恰恰是 LSTM 的精髓。
- 在 SRNN 里,\(\boldsymbol h_{t-1}\to\boldsymbol h_t\) 是直接连的(一根 \(\boldsymbol U\) 横着串),梯度只能走这条乘法老路 → 连乘≈矩阵连续自乘 → 梯度消失(见 [[RNN缺点-不能并行与S-RNN梯度问题]])。
- LSTM 故意把主干改道:相邻时刻最顺畅的通路不再是 \(\boldsymbol h\),而是 \(\boldsymbol c\) 那条加性高速公路 \(\boldsymbol c_{t-1}\to\boldsymbol c_t\)(来自 \(\boldsymbol c_t=\boldsymbol f_t\odot\boldsymbol c_{t-1}+\dots\) 的加号)。\(\boldsymbol h_t\) 退居「从 \(\boldsymbol c_t\) 取一瓢对外输出」的角色。
- 课本撑腰(p457 原话):「学习中由于位置之间的梯度传播不是通过矩阵连乘的线性组合,所以可以避免梯度消失和梯度爆炸。」
记忆的接力棒交给了 \(\boldsymbol c\)(加性、无损),这才让远处梯度传得动;\(\boldsymbol h\)「之间没那么紧密」是设计如此。
我的问题②:h_t = o_t ⊙ tanh(c_t),记忆 c_t 为什么还要再过一个 tanh?¶
前提先纠正:我以为「记忆是 0~1 之间的数」——不对。
- 0~1 的是门 \(\boldsymbol f,\boldsymbol i,\boldsymbol o\)(\(\sigma\) 出来,当开关)。
- 但门点乘的对象 \(\tilde{\boldsymbol c}_t\) 是 \(\tanh\) 出来的(范围 \(-1\sim1\)),而且 \(\boldsymbol c_t\) 是一路加累加的——加几十步后值可能涨到 5、10 甚至更大,既不在 0~1 也不在 -1~1。这正是「高速公路」的特点:为不丢信息,它不做强压缩,数值会越滚越大。
所以那个 tanh 的作用:\(\boldsymbol h_t\) 要对外输出、还要喂给下一步算门,如果把动辄 5、10 的 \(\boldsymbol c_t\) 直接丢出去,下一层/下一时刻容易被大数值带偏、不稳定。最后那个 \(\tanh\) 就是把可能很大的 \(\boldsymbol c_t\) 压回 \((-1,1)\) 的温和范围,再由输出门 \(\boldsymbol o_t\) 放行。(此「为什么」是直觉补充,课本只写式25.25 的 what。)
一句话记(侧重点)¶
- \(\boldsymbol c\) 负责「无损地记」(加性高速公路,值可以大);\(\boldsymbol h\) 负责「稳定地说」(过 tanh 压回 \(-1\sim1\) 再输出)。\(\boldsymbol c\) 与 \(\boldsymbol h\) 分工不同。
- 门 \(\sigma\in(0,1)\) 是开关;候选/记忆是数值、走 tanh;别把「门是 0~1」错套到「记忆是 0~1」。
- 候选记忆 \(\tilde{\boldsymbol c}_t\) 由 \(\boldsymbol h_{t-1}\) + \(\boldsymbol x_t\) 一起生成(别漏当前输入);本次新记忆 \(\boldsymbol c_t\) 生成本次隐状态 \(\boldsymbol h_t\)。