1from torch import nn2
3class Encoder(nn.Module):4 """编码器-解码器架构的基本编码器接口"""5 def __init__(self, **kwargs):6 super(Encoder, self).__init__(**kwargs)7
8 def forward(self, X, *args):9 raise NotImplementedError10
11class Decoder(nn.Module):12 """编码器-解码器架构的基本解码器接口"""13 def __init__(self, **kwargs):14 super(Decoder, self).__init__(**kwargs)15
17 collapsed lines
16 def init_state(self, enc_outputs, *args):17 raise NotImplementedError18
19 def forward(self, X, state):20 raise NotImplementedError21
22class EncoderDecoder(nn.Module):23 """编码器-解码器架构的基类"""24 def __init__(self, encoder, decoder, **kwargs):25 super(EncoderDecoder, self).__init__(**kwargs)26 self.encoder = encoder27 self.decoder = decoder28
29 def forward(self, enc_X, dec_X, *args):30 enc_outputs = self.encoder(enc_X, *args)31 dec_state = self.decoder.init_state(enc_outputs, *args)32 return self.decoder(dec_X, dec_state)
编码器有输入 env_X
,同时解码器还会将编码器输出转化为 state,并和自己的额外输入 dec_X
一起输入到解码器中。这里并没有说输入会是独热编码。