语法
- n 批量矩阵乘法
1X = torch.ones((2, 1, 4))2Y = torch.ones((2, 4, 6))3torch.bmm(X, Y).shape4# torch.Size([2, 1, 6])Nadaraya-Watson 核回归思想
若训练数据包含若干 对,则测试输入为 时,将对每个训练数据分配一个权重(注意力权重),最终测试输出为 。Nadaraya 的想法如下:
- 被称为查询, 为键值对,这就是在查询 和 键之间的注意力权重
其中 可设计。例如,设计为 ,则推导出
这对 的分布有要求。因此改进为带参数的注意力汇聚模型,额外学习一个 参数使得:
使用平方损失函数 nn.MESLoss 和随机梯度下降 nn.optim.SGD(lr=0.5) 进行训练。得到的注意力热图为:

(此图中任意一条横线表示一个测试输入 下对整条 轴的权重分配)
学到的 很小,使注意力更多分配到最近的 。
done.