how to

torch.Size([2, 1, 6])

Sep 26, 2024
notesjulyfun技术学习d2l
2 Minutes
287 Words

语法

  • n 批量矩阵乘法
1
X = torch.ones((2, 1, 4))
2
Y = torch.ones((2, 4, 6))
3
torch.bmm(X, Y).shape
4
# torch.Size([2, 1, 6])

Nadaraya-Watson 核回归思想

若训练数据包含若干 𝑥𝑖,𝑦𝑖 对,则测试输入为 𝑥 时,将对每个训练数据分配一个权重(注意力权重),最终测试输出为 𝑓(𝑥)=𝑛𝑖1𝛼(𝑥,𝑥𝑖)𝑦𝑖。Nadaraya 的想法如下: 𝛼(𝑥,𝑥𝑖)=𝐾(𝑥𝑥𝑖)

  • 𝑥 被称为查询𝑥𝑖,𝑦𝑖 为键值对,这就是在查询 𝑥𝑥𝑖 键之间的注意力权重

其中 𝐾 可设计。例如,设计为 𝐾(𝑢)=12𝜋exp(𝑢22),则推导出 𝑓(𝑥)=𝑛𝑖=1softmax(12(𝑥𝑥𝑖)2)𝑦𝑖

这对 𝑥 的分布有要求。因此改进为带参数的注意力汇聚模型,额外学习一个 𝑤 参数使得:

𝑓(𝑥)=𝑛𝑖=1softmax(12((𝑥𝑥𝑖)𝑤)2)𝑦𝑖

使用平方损失函数 nn.MESLoss 和随机梯度下降 nn.optim.SGD(lr=0.5) 进行训练。得到的注意力热图为:

default

(此图中任意一条横线表示一个测试输入 𝑥 下对整条 𝑥 轴的权重分配)

学到的 𝑤 很小,使注意力更多分配到最近的 𝑥

done.

Article title:torch.Size([2, 1, 6])
Article author:Julyfun
Release time:Sep 26, 2024
Copyright 2025
Sitemap