how to

iter 每次迭代 X([n, d] or [n, c, h, w]), y([n])

Aug 19, 2024
notes / julyfun / 技术学习 / d2l
2 Minutes
303 Words
1
import torch
2
from torch import nn
3
from d2l import torch as d2l
4
5
# Multilayer perceptron: flatten each input, then 784 -> 256 -> 10 with a
# ReLU between the two linear layers.
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)
7
def init_weights(m):
    """Draw the weights of an ``nn.Linear`` layer from a normal distribution.

    Written as a callback for ``nn.Module.apply``, which invokes it once per
    submodule. Weights are re-initialized in place with ``std=0.01`` (mean is
    left at ``nn.init.normal_``'s default); biases are not touched.

    Args:
        m: Any module. Modules whose type is not exactly ``nn.Linear`` are
            ignored.
    """
    # Exact type match on purpose: subclasses of nn.Linear are skipped.
    if type(m) is nn.Linear:
        nn.init.normal_(m.weight, std=0.01)
10
# Run the custom initializer on every submodule of the network.
net.apply(init_weights)

# Training hyperparameters.
batch_size = 256
lr = 0.1
num_epochs = 10

# Objective and optimizer: cross-entropy loss, SGD over all network parameters.
loss = nn.CrossEntropyLoss()
trainer = torch.optim.SGD(net.parameters(), lr=lr)

# Fashion-MNIST data iterators (train, test) from the d2l helper.
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
57 collapsed lines
16
# Each iteration of a data iterator yields X ([n, d] or [n, c, h, w]) and y ([n]).
# For example: (torch.Size([256, 1, 28, 28]), torch.Size([256]))
18
19
class Accumulator:
    """Keep running sums for ``n`` quantities, addressable by index."""

    def __init__(self, n):
        """Create ``n`` running totals, all starting at ``0.0``."""
        self.data = [0.0] * n

    def add(self, *args):
        """Add each positional value to the running total at the same index.

        Every value is converted with ``float`` before it is added. Totals and
        values are paired with ``zip``, so the stored list ends up as long as
        the shorter of the two: surplus values are ignored, and passing fewer
        values than there are totals drops the trailing totals.
        """
        self.data = [total + float(value) for total, value in zip(self.data, args)]

    def reset(self):
        """Set every running total back to ``0.0`` (the count is unchanged)."""
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        """Return the running total stored at position ``idx``."""
        return self.data[idx]
31
32
def evaluate_accuracy(net, data_iter):
    """Compute the accuracy of ``net`` over all batches in ``data_iter``.

    Args:
        net: Callable mapping a batch ``X`` to predictions. If it is a
            ``torch.nn.Module`` it is put into evaluation mode first and is
            left in that mode when the function returns.
        data_iter: Iterable yielding ``(X, y)`` batches.

    Returns:
        Number of correct predictions divided by the number of labels seen.

    Raises:
        ZeroDivisionError: If ``data_iter`` yields no labels.
    """
    if isinstance(net, torch.nn.Module):
        net.eval()  # switch the model to evaluation mode
    metric = Accumulator(2)  # number of correct predictions, number of predictions
    with torch.no_grad():
        for X, y in data_iter:
            # `accuracy` is not defined in this file; use the d2l helper,
            # which counts the correct predictions in the batch.
            metric.add(d2l.accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]
40
41
def train_epoch_ch3(net, train_iter, loss, updater):
    """Train ``net`` for one pass over ``train_iter``.

    Args:
        net: Callable mapping a batch ``X`` to predictions. If it is a
            ``torch.nn.Module`` it is put into training mode first.
        train_iter: Iterable yielding ``(X, y)`` batches.
        loss: Callable ``loss(y_hat, y)`` returning either per-sample losses
            or a scalar. A scalar is treated as the batch mean unless the
            callable exposes ``reduction`` with a value other than ``"mean"``.
        updater: Either a ``torch.optim.Optimizer`` or a callable invoked as
            ``updater(batch_size)`` after the backward pass.

    Returns:
        Tuple ``(average loss per sample, training accuracy)``.
    """
    # Switch the model to training mode.
    if isinstance(net, torch.nn.Module):
        net.train()
    # Sum of training loss, number of correct predictions, number of samples.
    metric = Accumulator(3)
    for X, y in train_iter:
        # Compute gradients and update the parameters.
        y_hat = net(X)
        losses = loss(y_hat, y)
        if isinstance(updater, torch.optim.Optimizer):
            updater.zero_grad()
            losses.mean().backward()
            updater.step()
        else:
            losses.sum().backward()
            updater(X.shape[0])
        # The first accumulator slot must hold a *sum* over samples because it
        # is divided by the sample count below. A mean-reduced scalar loss is
        # therefore scaled back up by the batch size; otherwise the reported
        # loss would be too small by roughly a factor of the batch size.
        if losses.ndim == 0 and getattr(loss, "reduction", "mean") == "mean":
            batch_loss = float(losses) * y.numel()
        else:
            batch_loss = float(losses.sum())
        # `accuracy` is not defined in this file; use the d2l helper, which
        # counts the correct predictions in the batch.
        metric.add(batch_loss, d2l.accuracy(y_hat, y), y.numel())
    # Average training loss and training accuracy.
    return metric[0] / metric[2], metric[1] / metric[2]
61
62
def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):  #@save
    """Train ``net`` for ``num_epochs`` epochs and sanity-check the result.

    Each epoch runs ``train_epoch_ch3`` on ``train_iter`` and then
    ``evaluate_accuracy`` on ``test_iter``. After the last epoch the final
    metrics are checked with assertions.

    Args:
        net: Model to train.
        train_iter: Iterable of ``(X, y)`` training batches.
        test_iter: Iterable of ``(X, y)`` evaluation batches.
        loss: Loss callable forwarded to ``train_epoch_ch3``.
        num_epochs: Number of passes over ``train_iter``; must be at least 1,
            otherwise no metrics exist to check.
        updater: Optimizer or update callable forwarded to
            ``train_epoch_ch3``.

    Raises:
        AssertionError: If the final training loss is not below 0.5, or the
            final training/test accuracy is not in ``(0.7, 1]``.
    """
    # The dangling `legend=[...])` fragment left over from a removed plotting
    # call was dropped; nothing in this function plots.
    for _ in range(num_epochs):
        train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
        test_acc = evaluate_accuracy(net, test_iter)
    # Only the metrics of the final epoch are validated.
    train_loss, train_acc = train_metrics
    assert train_loss < 0.5, train_loss
    assert train_acc <= 1 and train_acc > 0.7, train_acc
    assert test_acc <= 1 and test_acc > 0.7, test_acc
71
72
# Run training with the model, data iterators, loss, epoch count and optimizer
# defined above; train_ch3 asserts on the final loss and accuracies.
train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
Article title: iter 每次迭代 X([n, d] or [n, c, h, w]), y([n])
Article author: Julyfun
Release time: Aug 19, 2024
Copyright 2025
Sitemap