8.7.bptt
没时间看懂。重要小结:
梯度截断
$$h_t = f(x_t, h_(t - 1), w_h)$$
$$o_t = g(h_t, w_o)$$
$$L = 1 / T sum_(t = 1)^(T) l(y_t, o_t)$$
欲求 $(diff L) / (diff w_h)$,链式法则展开为 $sum_(t = 1)^(T) (diff l(y_t, o_t)) / (diff o_t) * (diff o_t) / (diff h_t) * (diff h_t) / (diff w_h)$
前两项好求,但第三项 $h_t$ 既依赖于 $w_h$ 又依赖于 $h_(t - 1)$,而后者又依赖于 $w_h$,最终导出的梯度是一个带有求积的求和项,容易产生梯度爆炸 / 消失,故可在一定时间步后截断求和计算。
其他
- 说是 bptt 会在计算期间缓存中间值。