torch.Size([2, 1, 6])
语法
- n 批量矩阵乘法
X = torch.ones((2, 1, 4))
Y = torch.ones((2, 4, 6))
torch.bmm(X, Y).shape
# torch.Size([2, 1, 6])
Nadaraya-Watson 核回归思想
若训练数据包含若干
被称为查询, 为键值对,这就是在查询 和 键之间的注意力权重
其中
这对
使用平方损失函数 nn.MESLoss
和随机梯度下降 nn.optim.SGD(lr=0.5)
进行训练。得到的注意力热图为:
(此图中任意一条横线表示一个测试输入
学到的
done.