1import torch2from torch import nn3from d2l import torch as d2l4
# Two-layer MLP: flatten the 28x28 input image, one hidden layer of 256
# units with ReLU, then a 10-class output layer of logits.
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)

def init_weights(m):
    """Draw Linear-layer weights from N(0, 0.01^2); leave other modules alone."""
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights)
# Training hyperparameters.
batch_size, lr, num_epochs = 256, 0.1, 10

# Softmax cross-entropy over the 10 output logits.
loss = nn.CrossEntropyLoss()

# Plain minibatch SGD over all network parameters.
trainer = torch.optim.SGD(net.parameters(), lr=lr)

# Each iteration of the loaders yields X with shape [n, d] or [n, c, h, w]
# and y with shape [n] — e.g. (torch.Size([256, 1, 28, 28]), torch.Size([256])).
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
class Accumulator:
    """Maintain running float totals across a fixed number of slots."""

    def __init__(self, n):
        # One running total per slot.
        self.data = [0.0] * n

    def add(self, *args):
        # Add each argument onto its corresponding slot. Note: zip truncates,
        # so the slot list shrinks if fewer args than slots are passed.
        updated = []
        for total, value in zip(self.data, args):
            updated.append(total + float(value))
        self.data = updated

    def reset(self):
        # Zero every slot, keeping the slot count.
        self.data = [0.0 for _ in self.data]

    def __getitem__(self, idx):
        return self.data[idx]
def evaluate_accuracy(net, data_iter):
    """Return the fraction of examples in `data_iter` that `net` classifies correctly."""
    if isinstance(net, torch.nn.Module):
        net.eval()  # switch the model to evaluation mode
    # Slot 0: number of correct predictions; slot 1: number of examples seen.
    metric = Accumulator(2)
    with torch.no_grad():
        for features, labels in data_iter:
            metric.add(accuracy(net(features), labels), labels.numel())
    return metric[0] / metric[1]
def train_epoch_ch3(net, train_iter, loss, updater):
    """Train `net` for one epoch; return (mean train loss, train accuracy)."""
    if isinstance(net, torch.nn.Module):
        net.train()  # switch the model to training mode
    # Slot 0: summed loss; slot 1: correct predictions; slot 2: examples seen.
    metric = Accumulator(3)
    for features, labels in train_iter:
        # Forward pass, then compute gradients and update parameters.
        predictions = net(features)
        batch_loss = loss(predictions, labels)
        if isinstance(updater, torch.optim.Optimizer):
            # Built-in optimizer: zero grads, backprop the mean loss, step.
            updater.zero_grad()
            batch_loss.mean().backward()
            updater.step()
        else:
            # Custom updater: backprop the summed loss; it receives the batch size.
            batch_loss.sum().backward()
            updater(features.shape[0])
        metric.add(float(batch_loss.sum()), accuracy(predictions, labels),
                   labels.numel())
    # Mean loss and accuracy over the whole epoch.
    return metric[0] / metric[2], metric[1] / metric[2]
def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):  #@save
    """Train a model for `num_epochs` epochs, plotting progress each epoch.

    Bug fixes vs the scraped original: the `d2l.Animator(...)` constructor
    had been truncated to an orphaned `legend=[...])` fragment (a syntax
    error), and the convergence asserts ran inside the epoch loop, so they
    would abort training on early epochs before the model reached 70%
    accuracy. The animator line is restored and the asserts now run once,
    after the final epoch.
    """
    # Live plot of train loss / train acc / test acc, one point per epoch.
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],
                            legend=['train loss', 'train acc', 'test acc'])
    for epoch in range(num_epochs):
        train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
        test_acc = evaluate_accuracy(net, test_iter)
        animator.add(epoch + 1, train_metrics + (test_acc,))
    # Post-training sanity checks on the final-epoch metrics.
    train_loss, train_acc = train_metrics
    assert train_loss < 0.5, train_loss
    assert train_acc <= 1 and train_acc > 0.7, train_acc
    assert test_acc <= 1 and test_acc > 0.7, test_acc
# Kick off training with the SGD optimizer configured above.
train_ch3(net, train_iter, test_iter, loss, num_epochs, updater=trainer)
Article title: iter 每次迭代 X([n, d] or [n, c, h, w]), y([n])
Article author: Julyfun
Release time: Aug 19, 2024
Copyright 2025
Sitemap