nn.RNN
The interface is simple. Note that the hidden state `state` is stored in a separate object, outside the RNN object itself.
```python
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

num_hiddens = 256
# nn.RNN() holds the weight and bias parameters (W, b) internally
rnn_layer = nn.RNN(len(vocab), num_hiddens)

# state shape: (num_layers, batch_size, num_hiddens)
state = torch.zeros((1, batch_size, num_hiddens))
# X shape: (num_steps, batch_size, vocab_size) -- one-hot encoded
X = torch.rand(size=(num_steps, batch_size, len(vocab)))
# Y shape: (num_steps, batch_size, num_hiddens); state keeps its shape
Y, state = rnn_layer(X, state)
```
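To confirm the shapes in the comments, one can assert them directly. A minimal check, continuing from the variables above; the second half relies on the documented PyTorch behavior that an omitted initial state defaults to zeros:

```python
# The output and state shapes match the comments above
assert Y.shape == (num_steps, batch_size, num_hiddens)
assert state.shape == (1, batch_size, num_hiddens)

# If no initial state is passed, PyTorch defaults it to zeros,
# so the result is identical to passing the zero state explicitly
Y2, state2 = rnn_layer(X)
assert torch.allclose(Y, Y2)
```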
Encapsulation
```python
class RNNModel(nn.Module):
    """The RNN model."""
    def __init__(self, rnn_layer, vocab_size, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.num_hiddens = self.rnn.hidden_size
        # If the RNN is bidirectional (introduced later), num_directions
        # should be 2; otherwise it should be 1
        if not self.rnn.bidirectional:
            self.num_directions = 1
            self.linear = nn.Linear(self.num_hiddens, self.vocab_size)
        else:
            self.num_directions = 2
            self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)

    def forward(self, inputs, state):
        X = F.one_hot(inputs.T.long(), self.vocab_size)
        X = X.to(torch.float32)
        Y, state = self.rnn(X, state)
        # The fully connected layer first reshapes Y into
        # (num_steps * batch_size, num_hiddens); its output has shape
        # (num_steps * batch_size, vocab_size)
        output = self.linear(Y.reshape((-1, Y.shape[-1])))
        return output, state

    def begin_state(self, device, batch_size=1):
        if not isinstance(self.rnn, nn.LSTM):
            # nn.GRU (and nn.RNN) take a tensor as the hidden state
            return torch.zeros((self.num_directions * self.rnn.num_layers,
                                batch_size, self.num_hiddens),
                               device=device)
        else:
            # nn.LSTM takes a tuple as the hidden state
            return (torch.zeros((
                self.num_directions * self.rnn.num_layers,
                batch_size, self.num_hiddens), device=device),
                    torch.zeros((
                        self.num_directions * self.rnn.num_layers,
                        batch_size, self.num_hiddens), device=device))
```
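As a usage sketch, the model can be instantiated around `rnn_layer` and run on one batch. `d2l.try_gpu` is the helper the book uses; the commented-out lines assume the chapter's training and prediction utilities (`d2l.train_ch8`, `d2l.predict_ch8`) are available:

```python
device = d2l.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab)).to(device)

# One untrained forward pass: prime the state, then feed a batch of token indices
state = net.begin_state(device=device, batch_size=batch_size)
for X_batch, _ in train_iter:
    output, state = net(X_batch.to(device), state)
    break
print(output.shape)  # torch.Size([num_steps * batch_size, vocab_size])

# Training and prediction as in the d2l chapter (assumed available):
# num_epochs, lr = 500, 1
# d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)
# d2l.predict_ch8('time traveller', 10, net, vocab, device)
```

The output is flattened across time steps so that a single cross-entropy loss over the vocabulary can be applied to all predictions at once.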