how to

9.6.encoderdecoder

Aug 22, 2024
notesjulyfun技术学习d2l
1 Minutes
170 Words
1
from torch import nn
2
3
class Encoder(nn.Module):
4
"""编码器-解码器架构的基本编码器接口"""
5
def __init__(self, **kwargs):
6
super(Encoder, self).__init__(**kwargs)
7
8
def forward(self, X, *args):
9
raise NotImplementedError
10
11
class Decoder(nn.Module):
12
"""编码器-解码器架构的基本解码器接口"""
13
def __init__(self, **kwargs):
14
super(Decoder, self).__init__(**kwargs)
15
17 collapsed lines
16
def init_state(self, enc_outputs, *args):
17
raise NotImplementedError
18
19
def forward(self, X, state):
20
raise NotImplementedError
21
22
class EncoderDecoder(nn.Module):
23
"""编码器-解码器架构的基类"""
24
def __init__(self, encoder, decoder, **kwargs):
25
super(EncoderDecoder, self).__init__(**kwargs)
26
self.encoder = encoder
27
self.decoder = decoder
28
29
def forward(self, enc_X, dec_X, *args):
30
enc_outputs = self.encoder(enc_X, *args)
31
dec_state = self.decoder.init_state(enc_outputs, *args)
32
return self.decoder(dec_X, dec_state)

编码器有输入 env_X,同时解码器还会将编码器输出转化为 state,并和自己的额外输入 dec_X 一起输入到解码器中。这里并没有说输入会是独热编码。

Article title:9.6.encoderdecoder
Article author:Julyfun
Release time:Aug 22, 2024
Copyright 2025
Sitemap