update code and readme
This commit is contained in:
22
model.py
22
model.py
@@ -1,6 +1,8 @@
|
||||
import numpy as np
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from torchsummary import summary
|
||||
|
||||
|
||||
def conv_dw(inp, oup, stride):
|
||||
@@ -38,11 +40,15 @@ class ConvNet(nn.Module):
|
||||
self.conv10 = conv_dw(128, 128, 1) # 16x16x128
|
||||
self.conv11 = conv_dw(128, 128, 1) # 16x16x128
|
||||
self.conv12 = conv_dw(128, 256, 2) # 8x8x256
|
||||
self.conv13 = conv_dw(256, 256, 1) # 8x8x256
|
||||
self.conv14 = conv_dw(256, 256, 1) # 8x8x256
|
||||
self.conv15 = conv_dw(256, 512, 2) # 4x4x512
|
||||
self.conv16 = conv_dw(512, 512, 1) # 4x4x512
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(256 * 8 * 8, 4096),
|
||||
nn.Linear(512*4*4, 1024),
|
||||
nn.Dropout(0.2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(4096, num_classes),
|
||||
nn.Linear(1024, num_classes),
|
||||
)
|
||||
self.weight_init()
|
||||
|
||||
@@ -61,7 +67,12 @@ class ConvNet(nn.Module):
|
||||
x11 = self.conv11(x10)
|
||||
x11 = F.relu(x10 + x11)
|
||||
x12 = self.conv12(x11)
|
||||
x = x12.view(x12.size(0), -1)
|
||||
x13 = self.conv13(x12)
|
||||
x14 = self.conv14(x13)
|
||||
x14 = F.relu(x13 + x14)
|
||||
x15 = self.conv15(x14)
|
||||
x16 = self.conv16(x15)
|
||||
x = x16.view(x16.size(0), -1)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
@@ -83,3 +94,6 @@ class ConvNet(nn.Module):
|
||||
m.bias.data.zero_()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = ConvNet(3755).cuda()
|
||||
summary(model, input_size=(3, 64, 64), device='cuda')
|
||||
|
||||
Reference in New Issue
Block a user