From 2ffcef9ac7c83a423c7dca84cf199a75da192e15 Mon Sep 17 00:00:00 2001 From: JiageWang <1076050774@qq.com> Date: Sat, 24 Aug 2019 21:26:38 +0800 Subject: [PATCH] update code and readme --- model.py | 54 ------------------------------------------------------ train.py | 2 +- 2 files changed, 1 insertion(+), 55 deletions(-) diff --git a/model.py b/model.py index 680b777..54d93a9 100644 --- a/model.py +++ b/model.py @@ -84,57 +84,3 @@ class ConvNet(nn.Module): m.weight.data.normal_(0, 0.01) m.bias.data.zero_() -class ConvNet2(nn.Module): - def __init__(self, num_classes): - super(ConvNet2, self).__init__() - self.features = nn.Sequential( - nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1), - nn.LeakyReLU(inplace=True), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), - nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), - nn.LeakyReLU(inplace=True), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), - nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), - nn.LeakyReLU(inplace=True), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1), - nn.LeakyReLU(inplace=True), - nn.MaxPool2d(kernel_size=2, stride=2), - ) - self.classifier = nn.Sequential( - nn.Linear(512*4*4, 4096), - # nn.Dropout(), - nn.ReLU(inplace=True), - nn.Linear(4096, num_classes), - ) - self.weight_init() - - def forward(self, x): - x = self.features(x) - x = x.view(x.size(0), -1) - x = self.classifier(x) - return x - - def weight_init(self): - for layer in self.features: - self._layer_init(layer) - for layer in self.classifier: - self._layer_init(layer) - - - def _layer_init(self, m): - # 使用isinstance来判断m属于什么类型 - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, np.sqrt(2. / n)) - elif isinstance(m, nn.BatchNorm2d): - # m中的weight,bias其实都是Variable,为了能学习参数以及后向传播 - m.weight.data.fill_(1) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - # n = m.weight.size(1) - m.weight.data.normal_(0, 0.01) - m.bias.data.zero_() - # init.xavier_normal_(m.weight) \ No newline at end of file diff --git a/train.py b/train.py index 56d0caf..edb8fe7 100644 --- a/train.py +++ b/train.py @@ -97,7 +97,7 @@ if __name__ == "__main__": print("测试集数据:", dataset.test_size) trainloader, testloader = dataset.get_loader(batch_size) - net = ConvNet2(num_classes) + net = ConvNet(num_classes) print('网络结构:\n', net) if torch.cuda.is_available(): net = net.cuda()