MobileNet v1 paper walkthrough
Paper: https://arxiv.org/abs/1704.04861
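The core idea is to replace standard convolutions with depthwise separable convolutions. Each one is a 3x3 depthwise convolution followed by a 1x1 pointwise convolution.
For background on depthwise convolution, see https://www.cnblogs.com/sdu20112013/p/11759928.html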
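The paper explains the savings by comparing the cost of the two layer types. It uses D_K for the kernel size, M and N for the input and output channel counts, and D_F × D_F for the output map size. With that notation:

- A standard convolution costs D_K·D_K·M·N·D_F·D_F multiply-adds.
- A depthwise separable convolution costs D_K·D_K·M·D_F·D_F + M·N·D_F·D_F.
- The ratio between them is 1/N + 1/D_K².

The sketch below only evaluates those formulas. The function names and the 512-channel, 14×14 example layer are illustrative choices, not taken from the original post.

```python
# Cost formulas from the MobileNet paper (notation assumed as in the paper):
#   d_k: kernel size, m: input channels, n: output channels, d_f: output map size.

def standard_conv_cost(d_k, m, n, d_f):
    """Multiply-adds of a standard d_k x d_k convolution."""
    return d_k * d_k * m * n * d_f * d_f


def depthwise_separable_cost(d_k, m, n, d_f):
    """Depthwise d_k x d_k conv (one filter per input channel) + 1x1 pointwise conv."""
    depthwise = d_k * d_k * m * d_f * d_f
    pointwise = m * n * d_f * d_f
    return depthwise + pointwise


if __name__ == "__main__":
    # Illustrative layer: 3x3 kernel, 512 -> 512 channels, 14x14 output map.
    d_k, m, n, d_f = 3, 512, 512, 14
    std = standard_conv_cost(d_k, m, n, d_f)
    sep = depthwise_separable_cost(d_k, m, n, d_f)
    print(f"standard:  {std:,} mult-adds")
    print(f"separable: {sep:,} mult-adds")
    print(f"ratio:     {sep / std:.4f} vs 1/N + 1/D_K^2 = {1 / n + 1 / d_k ** 2:.4f}")
```

For 3x3 kernels the 1/D_K² term dominates, so the ratio is roughly 1/9. This matches the paper's claim of about 8 to 9 times less computation.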
Model structure:
The network is a plain stack of layers, similar to VGG, with no branches or shortcut connections.
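Computation by layer type:
The breakdown shows that computation is not strictly proportional to parameter count. For example, the fully connected layer holds a large share of the parameters but accounts for only a tiny fraction of the multiply-adds, while the 1x1 pointwise convolutions dominate the computation. Still, as a general trend, models with fewer parameters tend to run faster.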
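The per-layer-type breakdown can be approximately reproduced by tallying over the layer configuration used in the code later in this post. `layer_costs` is a hypothetical helper, and the sketch makes these assumptions:

- The input is 224×224.
- 3x3 convolutions use padding 1.
- FC weights and bias are counted.
- BatchNorm parameters are ignored.

The totals should come out near the paper's reported figures of about 569M multiply-adds and 4.2M parameters.

```python
# Hypothetical tally of multiply-adds and parameters per layer type for the
# MobileNet v1 stack; assumes a 224x224 input and ignores BatchNorm parameters.

def out_size(size, kernel, stride, padding):
    """Spatial output size of a convolution (same formula nn.Conv2d uses)."""
    return (size + 2 * padding - kernel) // stride + 1


# (in_channels, out_channels, stride) of each depthwise separable block.
DW_BLOCKS = ([(32, 64, 1), (64, 128, 2), (128, 128, 1), (128, 256, 2),
              (256, 256, 1), (256, 512, 2)]
             + [(512, 512, 1)] * 5
             + [(512, 1024, 2), (1024, 1024, 1)])


def layer_costs(input_size=224, num_classes=1000):
    madds = {"conv 3x3": 0, "conv dw 3x3": 0, "conv 1x1": 0, "fc": 0}
    params = dict.fromkeys(madds, 0)

    # First layer: standard 3x3 conv, 3 -> 32 channels, stride 2.
    size = out_size(input_size, 3, 2, 1)
    params["conv 3x3"] += 3 * 3 * 3 * 32
    madds["conv 3x3"] += 3 * 3 * 3 * 32 * size * size

    for cin, cout, stride in DW_BLOCKS:
        size = out_size(size, 3, stride, 1)
        # Depthwise: one 3x3 filter per input channel.
        params["conv dw 3x3"] += 3 * 3 * cin
        madds["conv dw 3x3"] += 3 * 3 * cin * size * size
        # Pointwise: 1x1 conv mixing channels; spatial size unchanged.
        params["conv 1x1"] += cin * cout
        madds["conv 1x1"] += cin * cout * size * size

    # Global average pool, then fully connected classifier (weights + bias).
    params["fc"] += 1024 * num_classes + num_classes
    madds["fc"] += 1024 * num_classes
    return madds, params


if __name__ == "__main__":
    madds, params = layer_costs()
    total_m, total_p = sum(madds.values()), sum(params.values())
    for name in madds:
        print(f"{name:12s} mult-adds {100 * madds[name] / total_m:6.2f}%  "
              f"params {100 * params[name] / total_p:6.2f}%")
    print(f"total: {total_m / 1e6:.1f}M mult-adds, {total_p / 1e6:.2f}M params")
```

In this tally the FC layer ends up with roughly a quarter of the parameters but well under 1% of the multiply-adds. That is the mismatch described above.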
Implementation
https://github.com/marvis/pytorch-mobilenet
Network definition:
```python
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # Standard 3x3 convolution + BN + ReLU (used only for the first layer).
        def conv_bn(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True)
            )

        # Depthwise separable block: 3x3 depthwise conv (groups=inp) followed by
        # a 1x1 pointwise conv, each with BN + ReLU.
        def conv_dw(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                nn.ReLU(inplace=True),

                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True),
            )

        self.model = nn.Sequential(
            conv_bn(   3,   32, 2),
            conv_dw(  32,   64, 1),
            conv_dw(  64,  128, 2),
            conv_dw( 128,  128, 1),
            conv_dw( 128,  256, 2),
            conv_dw( 256,  256, 1),
            conv_dw( 256,  512, 2),
            conv_dw( 512,  512, 1),
            conv_dw( 512,  512, 1),
            conv_dw( 512,  512, 1),
            conv_dw( 512,  512, 1),
            conv_dw( 512,  512, 1),
            conv_dw( 512, 1024, 2),
            conv_dw(1024, 1024, 1),
            # A 224x224 input is downsampled to 7x7 here; pool it to 1x1.
            nn.AvgPool2d(7),
        )
        self.fc = nn.Linear(1024, 1000)

    def forward(self, x):
        x = self.model(x)
        x = x.view(-1, 1024)  # flatten (N, 1024, 1, 1) -> (N, 1024)
        x = self.fc(x)
        return x
```
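Following the structure in the paper, the first layer is a standard convolution. It is followed by 13 depthwise separable blocks, then average pooling and a fully connected classifier.
Note how the groups parameter is used. When groups equals the number of input channels, each channel is convolved separately, which is a depthwise convolution. The default, groups=1, gives a standard convolution.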
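To make the groups semantics concrete, here is a naive pure-Python grouped convolution. `grouped_conv2d` is a hypothetical helper, restricted to stride 1, no padding and no bias for brevity. It follows nn.Conv2d's convention that the weight has shape (out_channels, in_channels / groups, K, K).

- With groups equal to the channel count, each filter sees only its own channel.
- With groups=1, each output channel sums over all input channels.

```python
# Naive grouped 2D convolution (stride 1, no padding, no bias), mirroring the
# meaning of nn.Conv2d's `groups`: channels are split into `groups` equal slices
# and each output channel only reads the input channels of its own slice.

def grouped_conv2d(x, w, groups):
    """x: [C_in][H][W], w: [C_out][C_in // groups][K][K] -> [C_out][H-K+1][W-K+1]."""
    c_in, c_out = len(x), len(w)
    assert c_in % groups == 0 and c_out % groups == 0
    in_per_group, out_per_group = c_in // groups, c_out // groups
    assert len(w[0]) == in_per_group, "weight must have C_in // groups input channels"
    k = len(w[0][0])
    h_out, w_out = len(x[0]) - k + 1, len(x[0][0]) - k + 1
    out = [[[0.0] * w_out for _ in range(h_out)] for _ in range(c_out)]
    for o in range(c_out):
        g = o // out_per_group                  # group this output channel belongs to
        for ci in range(in_per_group):
            src = x[g * in_per_group + ci]      # only input channels of group g
            for i in range(h_out):
                for j in range(w_out):
                    out[o][i][j] += sum(
                        w[o][ci][a][b] * src[i + a][j + b]
                        for a in range(k) for b in range(k))
    return out


if __name__ == "__main__":
    x = [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],     # channel 0
         [[0, 1, 0], [1, 0, 1], [0, 1, 0]]]     # channel 1
    ones = [[1, 1], [1, 1]]
    # Depthwise: groups == in_channels, one 2x2 filter per channel.
    print(grouped_conv2d(x, [[ones], [ones]], groups=2))
    # Standard: groups == 1, the single output channel sums over both inputs.
    print(grouped_conv2d(x, [[ones, ones]], groups=1))
```

In the depthwise case, output channel 0 depends only on input channel 0. This is why MobileNet needs the extra 1x1 pointwise convolution to mix information across channels.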
Training loop (sketch)
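```python
import torch
import torch.nn as nn
import torch.utils.data

# Hypothetical hyperparameters; the original reads these from `args`.
lr, momentum, weight_decay, epochs = 0.1, 0.9, 1e-4, 90
device = "cuda" if torch.cuda.is_available() else "cpu"

# create model (Net is the class defined above)
model = Net().to(device)
# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr,
                            momentum=momentum,
                            weight_decay=weight_decay)
# load data: `train_dataset` is a placeholder for any dataset yielding
# (3x224x224 image tensor, class index) pairs; batch size is hypothetical.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True)
# train
for epoch in range(epochs):
    model.train()
    for images, target in train_loader:
        images, target = images.to(device), target.to(device)
        # forward pass to get predictions
        output = model(images)
        # compute the loss
        loss = criterion(output, target)
        # backpropagate and update the network parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```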
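In the loop above, optimizer.step() applies the update rule documented for torch.optim.SGD to each parameter. The sketch below covers the settings used here: momentum and weight decay, with the defaults dampening=0 and nesterov=False. `sgd_step` is a hypothetical helper, and plain floats stand in for tensors.

```python
# Plain-Python sketch of torch.optim.SGD's documented update for
# momentum > 0, weight_decay >= 0, dampening = 0, nesterov = False.

def sgd_step(params, grads, bufs, lr, momentum, weight_decay):
    """Update `params` in place; `bufs` holds momentum buffers (None before step 1)."""
    for i, (p, g) in enumerate(zip(params, grads)):
        d_p = g + weight_decay * p              # L2 penalty folded into the gradient
        if bufs[i] is None:
            bufs[i] = d_p                       # first step: buffer starts at the gradient
        else:
            bufs[i] = momentum * bufs[i] + d_p  # running momentum
        params[i] = p - lr * bufs[i]


if __name__ == "__main__":
    # Minimize f(p) = p**2: recompute the gradient each step (the "backward"),
    # then apply the update (the "step").
    params, bufs = [1.0], [None]
    for step in range(5):
        grads = [2 * p for p in params]
        sgd_step(params, grads, bufs, lr=0.1, momentum=0.9, weight_decay=0.0)
        print(step, params[0])
```

The momentum buffer persists across steps, which is why PyTorch keeps it as optimizer state. The gradients, by contrast, are cleared by zero_grad() before each backward pass.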
Source: https://www.cnblogs.com/sdu20112013/p/11765507.html