Skip to content

5.2.参数管理

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
X = torch.rand(size=(2, 4))
net(X)

访问参数

  • net[2].state_dict(): 第三层参数字典
  • net[2].bias net[2].bais.data: 第三层偏置捏
  • net.named_parameters(): 所有参数字典
  • net.state_dict()['2.bias'].data: 第三层偏置
  • print(net): 显示网络结构