Skip to content

7.3.nin

关于卷积层的提示

注意输入通道和输出通道是全连接的,即:

若单层图像的核有 X 个(取决于图像长宽,kernel_size, padding 和 stride),输入通道 n,输出通道为 m,则核函数有 X * n * m 个

核参数则有 $X \times n \times m \times \texttt{kernel_size} \times \texttt{kernel_size}$ 个

ref: https://zh.d2l.ai/chapter_convolutional-neural-networks/channels.html

特色

相比 vgg,nin 主要特色两个:

  • block 内部是 Conv + ReLU + Conv(kernel_size = 1) + ReLU + Conv(kernel_size = 1) + ReLU + MaxPool,给每个像素做了通道到通道的全连接
  • 最后直接是 384 通道到分类数通道的 nin block,没有 Linear

全局平均汇聚层将图像的大小压缩至 1 * 1,不改变 n 和 channels,也不改变维数。

def nin_block(in_channels, out_channels, kernel_size, strides, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU())

net = nn.Sequential(
    nin_block(1, 96, kernel_size=11, strides=4, padding=0),
    nn.MaxPool2d(3, stride=2),
    nin_block(96, 256, kernel_size=5, strides=1, padding=2),
    nn.MaxPool2d(3, stride=2),
    nin_block(256, 384, kernel_size=3, strides=1, padding=1),
    nn.MaxPool2d(3, stride=2),
    nn.Dropout(0.5),
    # 标签类别数是10
    nin_block(384, 10, kernel_size=3, strides=1, padding=1),
    nn.AdaptiveAvgPool2d((1, 1)),
    # 将四维的输出转成二维的输出,其形状为(批量大小,10)
    nn.Flatten())
X = torch.rand(size=(1, 1, 224, 224))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t', X.shape)
Copy to clipboard
Sequential output shape:     torch.Size([1, 96, 54, 54]) # 包含整个 nin_block
MaxPool2d output shape:      torch.Size([1, 96, 26, 26])
Sequential output shape:     torch.Size([1, 256, 26, 26])
MaxPool2d output shape:      torch.Size([1, 256, 12, 12])
Sequential output shape:     torch.Size([1, 384, 12, 12])
MaxPool2d output shape:      torch.Size([1, 384, 5, 5])
Dropout output shape:        torch.Size([1, 384, 5, 5])
Sequential output shape:     torch.Size([1, 10, 5, 5])
AdaptiveAvgPool2d output shape:      torch.Size([1, 10, 1, 1])
Flatten output shape:        torch.Size([1, 10])