Files
hand-writing-recognition/model.py
2019-08-24 21:26:38 +08:00

87 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import math
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import numpy as np
def conv_dw(inp, oup, stride):
    """Build a depthwise-separable convolution block.

    The block is a 3x3 depthwise convolution (one group per input channel)
    followed by batch norm and ReLU6, then a 1x1 pointwise convolution
    followed by batch norm. No activation is applied after the pointwise
    stage, so the block's output is linear.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Stride of the depthwise convolution.

    Returns:
        An ``nn.Sequential`` holding the five layers in order.
    """
    layers = [
        # Depthwise stage: each channel is filtered independently.
        nn.Conv2d(
            in_channels=inp,
            out_channels=inp,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=inp,
            bias=False,
        ),
        nn.BatchNorm2d(inp),
        nn.ReLU6(inplace=True),
        # Pointwise stage: 1x1 convolution mixes channels, left linear.
        nn.Conv2d(
            in_channels=inp,
            out_channels=oup,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        ),
        nn.BatchNorm2d(oup),
    ]
    return nn.Sequential(*layers)
def conv_bn(inp, oup, stride):
    """Build a standard 3x3 convolution block.

    The block is a bias-free 3x3 convolution with padding 1, followed by
    batch norm and ReLU6.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Stride of the convolution.

    Returns:
        An ``nn.Sequential`` holding the three layers in order.
    """
    layers = [
        nn.Conv2d(
            in_channels=inp,
            out_channels=oup,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        ),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    ]
    return nn.Sequential(*layers)
class ConvNet(nn.Module):
    """MobileNet-style CNN classifier built from depthwise-separable blocks.

    The stem is two plain 3x3 conv blocks (``conv_bn``) followed by ten
    depthwise-separable blocks (``conv_dw``). Three of those blocks use
    stride 2, so the spatial size is divided by 8 overall. The first
    ``Linear`` layer of the classifier expects ``256 * 8 * 8`` features,
    which corresponds to 3-channel 64x64 inputs.

    Args:
        num_classes: Size of the output logit vector.
    """

    def __init__(self, num_classes):
        super(ConvNet, self).__init__()
        # Shape comments are HxWxC of each block's output for a 64x64 input.
        self.conv1 = conv_bn(3, 8, 1)  # 64x64x8
        self.conv2 = conv_bn(8, 16, 1)  # 64x64x16
        self.conv3 = conv_dw(16, 32, 1)  # 64x64x32
        self.conv4 = conv_dw(32, 32, 2)  # 32x32x32
        self.conv5 = conv_dw(32, 64, 1)  # 32x32x64
        self.conv6 = conv_dw(64, 64, 2)  # 16x16x64
        self.conv7 = conv_dw(64, 128, 1)  # 16x16x128
        self.conv8 = conv_dw(128, 128, 1)  # 16x16x128
        self.conv9 = conv_dw(128, 128, 1)  # 16x16x128
        self.conv10 = conv_dw(128, 128, 1)  # 16x16x128
        self.conv11 = conv_dw(128, 128, 1)  # 16x16x128
        self.conv12 = conv_dw(128, 256, 2)  # 8x8x256
        self.classifier = nn.Sequential(
            nn.Linear(256 * 8 * 8, 4096),
            nn.Dropout(0.2),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )
        # NOTE: custom initialisation is not applied here; layers keep
        # PyTorch's default init unless the caller invokes weight_init().

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Input batch; the layer sizes imply shape (N, 3, 64, 64).

        Returns:
            Tensor of shape (N, num_classes) with unnormalised scores.
        """
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)
        x4 = self.conv4(x3)
        x5 = self.conv5(x4)
        x6 = self.conv6(x5)
        x7 = self.conv7(x6)
        x8 = self.conv8(x7)
        # Residual connection around conv9.
        x9 = self.conv9(x8)
        x9 = F.relu(x8 + x9)
        # Residual connection around conv11.
        x10 = self.conv10(x9)
        x11 = self.conv11(x10)
        x11 = F.relu(x10 + x11)
        x12 = self.conv12(x11)
        # Flatten everything except the batch dimension for the classifier.
        flat = x12.view(x12.size(0), -1)
        return self.classifier(flat)

    def weight_init(self):
        """Re-initialise every sub-module's parameters via ``_layer_init``."""
        for layer in self.modules():
            self._layer_init(layer)

    def _layer_init(self, m):
        """Initialise one module in place according to its type.

        Conv2d weights are drawn from N(0, sqrt(2 / n)) with
        ``n = kernel_h * kernel_w * out_channels``; BatchNorm2d is reset to
        scale 1 / shift 0; Linear weights are drawn from N(0, 0.01). Biases
        are zeroed whenever the layer has one. Other module types are left
        untouched.

        Args:
            m: The module to initialise.
        """
        # Use isinstance to dispatch on the module type.
        if isinstance(m, nn.Conv2d):
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            init.normal_(m.weight, 0, math.sqrt(2. / n))
            # Layers built with bias=False have no bias tensor to reset.
            if m.bias is not None:
                init.constant_(m.bias, 0.)
        elif isinstance(m, nn.BatchNorm2d):
            # weight/bias are learnable parameters, but both are None when
            # the layer was created with affine=False.
            if m.weight is not None:
                init.constant_(m.weight, 1.)
            if m.bias is not None:
                init.constant_(m.bias, 0.)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, 0, 0.01)
            if m.bias is not None:
                init.constant_(m.bias, 0.)