相比于 NiN,GoogLeNet 的 Inception 块有 4 条并行路径。注意这 4 条路径的输出在合并时是直接在通道(channels)维度上拼接,而不是逐元素相加。另外,与 NiN 完全去掉全连接层不同,GoogLeNet 在网络的最后重新引入了一个 Linear 层。
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

class Inception(nn.Module):
    """Inception block: four parallel paths concatenated along the channel axis.

    Every path keeps the spatial size of its input (1x1 convs, 3x3 conv with
    padding 1, 5x5 conv with padding 2, 3x3 max-pool with stride 1 and
    padding 1), so the four outputs can be concatenated on dim 1. The result
    has ``c1 + c2[1] + c3[1] + c4`` channels.

    Args:
        in_channels: Number of channels of the input feature map.
        c1: Output channels of path 1 (single 1x1 convolution).
        c2: Pair ``(reduce, out)`` for path 2: channels of the 1x1 convolution
            and of the following 3x3 convolution.
        c3: Pair ``(reduce, out)`` for path 3: channels of the 1x1 convolution
            and of the following 5x5 convolution.
        c4: Output channels of path 4 (3x3 max-pool followed by a 1x1
            convolution).
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, in_channels, c1, c2, c3, c4, **kwargs):
        super().__init__(**kwargs)
        # Path 1: a single 1x1 convolution.
        self.p1_1 = nn.Conv2d(in_channels, c1, kernel_size=1)
        # Path 2: 1x1 convolution followed by a 3x3 convolution.
        self.p2_1 = nn.Conv2d(in_channels, c2[0], kernel_size=1)
        self.p2_2 = nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1)
        # Path 3: 1x1 convolution followed by a 5x5 convolution.
        self.p3_1 = nn.Conv2d(in_channels, c3[0], kernel_size=1)
        self.p3_2 = nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2)
        # Path 4: 3x3 max-pooling followed by a 1x1 convolution.
        self.p4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.p4_2 = nn.Conv2d(in_channels, c4, kernel_size=1)

    def forward(self, x):
        """Run the four paths on ``x`` and concatenate their outputs on dim 1.

        Assumes ``x`` is a batched ``(N, in_channels, H, W)`` tensor, so that
        dim 1 is the channel axis.
        """
        p1 = F.relu(self.p1_1(x))
        p2 = F.relu(self.p2_2(F.relu(self.p2_1(x))))
        p3 = F.relu(self.p3_2(F.relu(self.p3_1(x))))
        # No activation between the pooling layer and its 1x1 convolution.
        p4 = F.relu(self.p4_2(self.p4_1(x)))
        # Concatenate (not add) the path outputs along the channel dimension.
        return torch.cat((p1, p2, p3, p4), dim=1)
# Stage 1: 7x7 stride-2 convolution on a single-channel input producing 64
# channels, then a 3x3 stride-2 max-pool.
b1 = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)
# Stage 2: 1x1 convolution (64 -> 64), 3x3 convolution (64 -> 192), then a
# 3x3 stride-2 max-pool.
b2 = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=1),
    nn.ReLU(),
    nn.Conv2d(64, 192, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)
# Stage 3: two Inception blocks (192 -> 256 -> 480 channels), then a 3x3
# stride-2 max-pool. Each block's output channel count is c1 + c2[1] + c3[1] + c4.
b3 = nn.Sequential(
    Inception(192, 64, (96, 128), (16, 32), 32),     # 64 + 128 + 32 + 32 = 256
    Inception(256, 128, (128, 192), (32, 96), 64),   # 128 + 192 + 96 + 64 = 480
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)

# Stage 4: five Inception blocks (480 -> 512 -> 512 -> 512 -> 528 -> 832
# channels), then a 3x3 stride-2 max-pool.
b4 = nn.Sequential(
    Inception(480, 192, (96, 208), (16, 48), 64),     # 192 + 208 + 48 + 64 = 512
    Inception(512, 160, (112, 224), (24, 64), 64),    # 160 + 224 + 64 + 64 = 512
    Inception(512, 128, (128, 256), (24, 64), 64),    # 128 + 256 + 64 + 64 = 512
    Inception(512, 112, (144, 288), (32, 64), 64),    # 112 + 288 + 64 + 64 = 528
    Inception(528, 256, (160, 320), (32, 128), 128),  # 256 + 320 + 128 + 128 = 832
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)

# Stage 5: two Inception blocks (832 -> 832 -> 1024 channels), adaptive
# average pooling to 1x1, then flatten to one 1024-element vector per sample.
b5 = nn.Sequential(
    Inception(832, 256, (160, 320), (32, 128), 128),  # 256 + 320 + 128 + 128 = 832
    Inception(832, 384, (192, 384), (48, 128), 128),  # 384 + 384 + 128 + 128 = 1024
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
)
# Full network: the five stages followed by a linear layer. The 1024 input
# features match the channel count produced by the last Inception block in b5
# (384 + 384 + 128 + 128) after pooling and flattening; the layer emits 10
# outputs per sample.
net = nn.Sequential(b1, b2, b3, b4, b5, nn.Linear(1024, 10))