10.2.注意力汇聚
语法
- n 批量矩阵乘法
X = torch.ones((2, 1, 4))
Y = torch.ones((2, 4, 6))
torch.bmm(X, Y).shape
# torch.Size([2, 1, 6])
Nadaraya-Watson 核回归思想
若训练数据包含若干 $x_i, y_i$ 对,则测试输入为 $x$ 时,将对每个训练数据分配一个权重(注意力权重),最终测试输出为 $f(x) = sum_(i - 1)^n alpha (x, x_i) y_i$。Nadaraya 的想法如下: $$alpha(x, x_i) = K(x - x_i)$$ - $x$ 被称为查询,$x_i, y_i$ 为键值对,这就是在查询 $x$ 和 $x_i$ 键之间的注意力权重
其中 $K$ 可设计。例如,设计为 $K(u) = 1 / (sqrt(2 pi)) exp(- u^2 / 2)$,则推导出 $$f(x) = sum_(i = 1)^n "softmax"(-1 / 2 (x - x_i)^2)y_i$$
这对 $x$ 的分布有要求。因此改进为带参数的注意力汇聚模型,额外学习一个 $w$ 参数使得:
$$ f(x) = sum_(i = 1)^n "softmax"(-1 / 2 ((x - x_i)w)^2)y_i $$
使用平方损失函数 nn.MESLoss
和随机梯度下降 nn.optim.SGD(lr=0.5)
进行训练。得到的注意力热图为:
(此图中任意一条横线表示一个测试输入 $x$ 下对整条 $x$ 轴的权重分配)
学到的 $w$ 很小,使注意力更多分配到最近的 $x$。
done.