Skip to content

8.7.bptt

没时间看懂。重要小结:

梯度截断

$$h_t = f(x_t, h_(t - 1), w_h)$$

$$o_t = g(h_t, w_o)$$

$$L = 1 / T sum_(t = 1)^(T) l(y_t, o_t)$$

欲求 $(diff L) / (diff w_h)$,链式法则展开为 $sum_(t = 1)^(T) (diff l(y_t, o_t)) / (diff o_t) * (diff o_t) / (diff h_t) * (diff h_t) / (diff w_h)$

前两项好求,但第三项 $h_t$ 既依赖于 $w_h$ 又依赖于 $h_(t - 1)$,而后者又依赖于 $w_h$,最终导出的梯度是一个带有求积的求和项,容易产生梯度爆炸 / 消失,故可在一定时间步后截断求和计算。

其他

  • 说是 bptt 会在计算期间缓存中间值。