class Seq2SeqAttentionDecoder(AttentionDecoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
        self.attention = d2l.AdditiveAttention(
            num_hiddens, num_hiddens, num_hiddens, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(
            embed_size + num_hiddens, num_hiddens, num_layers,
            dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        # Shape of outputs from the encoder: (num_steps, batch_size, num_hiddens).
        # Shape of hidden_state: (num_layers, batch_size, num_hiddens)
        outputs, hidden_state = enc_outputs
        # Permute outputs to (batch_size, num_steps, num_hiddens)
        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)

    def forward(self, X, state):
        # Shape of enc_outputs: (batch_size, num_steps, num_hiddens).
        # Shape of hidden_state: (num_layers, batch_size, num_hiddens)
        enc_outputs, hidden_state, enc_valid_lens = state
        # Shape of the output X: (num_steps, batch_size, embed_size)
        X = self.embedding(X).permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        for x in X:
            # Shape of query: (batch_size, 1, num_hiddens)
            query = torch.unsqueeze(hidden_state[-1], dim=1)
            # Shape of context: (batch_size, 1, num_hiddens)
            context = self.attention(
                query, enc_outputs, enc_outputs, enc_valid_lens)
            # Concatenate on the feature dimension
            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
            # Reshape x to (1, batch_size, embed_size + num_hiddens)
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
        # After the fully connected layer transformation, shape of outputs:
        # (num_steps, batch_size, vocab_size)
        outputs = self.dense(torch.cat(outputs, dim=0))
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state,
                                          enc_valid_lens]

    @property
    def attention_weights(self):
        return self._attention_weights
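As a quick sanity check of the tensor shapes, the following minimal sketch runs a dummy minibatch through an encoder-decoder pair. It assumes the standard imports for this chapter and that d2l.Seq2SeqEncoder from the earlier sequence-to-sequence section is available; the hyperparameters and the all-zeros input are illustrative values only, not part of the code above.

import torch
from torch import nn  # needed by the decoder class defined above
from d2l import torch as d2l

# Illustrative toy sizes for a shape check only
encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,
                             num_layers=2)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16,
                                  num_layers=2)
decoder.eval()
# Dummy minibatch of token indices: (batch_size=4, num_steps=7)
X = torch.zeros((4, 7), dtype=torch.long)
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
# output: (batch_size, num_steps, vocab_size) -> torch.Size([4, 7, 10])
# state[1] is the decoder hidden state: (num_layers, batch_size, num_hiddens)
print(output.shape, state[1].shape)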
Article title: 10.4.bahdanau-attention
Article author: Julyfun
Release time: Dec 22, 2024