
RNN() internally holds parameters such as W and b

Aug 16, 2024

nn.RNN

The interface is simple. Note that the hidden state `state` is stored as a separate object outside the RNN object itself.

```python
import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

num_hiddens = 256
# RNN() internally holds parameters such as W and b
rnn_layer = nn.RNN(len(vocab), num_hiddens)

# state shape: [num_layers, batch_size, num_hiddens]
state = torch.zeros((1, batch_size, num_hiddens))
# X.shape = [num_steps, batch_size, vocab_size] (one-hot)
X = torch.rand(size=(num_steps, batch_size, len(vocab)))
# Y.shape = [num_steps, batch_size, num_hiddens]; the state shape is unchanged
Y, state = rnn_layer(X, state)
```
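
A quick shape check (the printed sizes follow directly from the values above):

```python
print(Y.shape)      # torch.Size([35, 32, 256]): (num_steps, batch_size, num_hiddens)
print(state.shape)  # torch.Size([1, 32, 256]): (num_layers, batch_size, num_hiddens)
```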

Encapsulation

```python
from torch.nn import functional as F

class RNNModel(nn.Module):
    """The RNN model."""
    def __init__(self, rnn_layer, vocab_size, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.num_hiddens = self.rnn.hidden_size
        # If the RNN is bidirectional (introduced later), num_directions
        # should be 2; otherwise it should be 1
        if not self.rnn.bidirectional:
            self.num_directions = 1
            self.linear = nn.Linear(self.num_hiddens, self.vocab_size)
        else:
            self.num_directions = 2
            self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)

    def forward(self, inputs, state):
        X = F.one_hot(inputs.T.long(), self.vocab_size)
        X = X.to(torch.float32)
        Y, state = self.rnn(X, state)
        # The fully connected layer first reshapes Y to
        # (num_steps * batch_size, num_hiddens);
        # its output shape is (num_steps * batch_size, vocab_size)
        output = self.linear(Y.reshape((-1, Y.shape[-1])))
        return output, state

    def begin_state(self, device, batch_size=1):
        if not isinstance(self.rnn, nn.LSTM):
            # nn.GRU takes a tensor as the hidden state
            return torch.zeros((self.num_directions * self.rnn.num_layers,
                                batch_size, self.num_hiddens),
                               device=device)
        else:
            # nn.LSTM takes a tuple of tensors as the hidden state
            return (torch.zeros((
                self.num_directions * self.rnn.num_layers,
                batch_size, self.num_hiddens), device=device),
                    torch.zeros((
                        self.num_directions * self.rnn.num_layers,
                        batch_size, self.num_hiddens), device=device))
```
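
A minimal usage sketch, reusing `rnn_layer` and `vocab` from the first snippet; the random `inputs` tensor and the device choice here are stand-ins for real data:

```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = RNNModel(rnn_layer, vocab_size=len(vocab)).to(device)

state = net.begin_state(device=device, batch_size=batch_size)
# inputs holds integer token indices with shape (batch_size, num_steps)
inputs = torch.randint(0, len(vocab), (batch_size, num_steps), device=device)
output, state = net(inputs, state)
print(output.shape)  # (num_steps * batch_size, len(vocab))
```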

