Update code and README

This commit is contained in:
JiageWang
2019-08-26 15:25:12 +08:00
parent 2beb4adb0f
commit 43ea8da649
3 changed files with 45 additions and 32 deletions

28
hwdb.py
View File

@@ -1,4 +1,5 @@
import os
import random
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
@@ -6,7 +7,7 @@ import matplotlib.pyplot as plt
class HWDB(object):
def __init__(self, transform, path='./data'):
def __init__(self,path, transform):
# 预处理过程
traindir = os.path.join(path, 'train')
@@ -16,6 +17,8 @@ class HWDB(object):
self.testset = datasets.ImageFolder(testdir, transform)
self.train_size = len(self.trainset)
self.test_size = len(self.testset)
self.num_classes = len(self.trainset.classes)
self.class_to_idx = self.trainset.class_to_idx
def get_sample(self, index=0):
sample = self.trainset[index]
@@ -30,21 +33,14 @@ class HWDB(object):
if __name__ == '__main__':
transform = transforms.Compose([
# transforms.Grayscale(),
transforms.Resize((64, 64)),
transforms.ToTensor(),
])
dataset = HWDB(transform=transform, path=r'data')
print(dataset.train_size)
print(dataset.test_size)
for i in [1020, 120, 2000, 6000, 1000]:
img, label = dataset.get_sample(i)
img = img[0]
print(label)
plt.imshow(img, cmap='gray')
plt.show()
train_loader, test_loader = dataset.get_loader()
# for (img, label) in train_loader:
# print(img)
# print(label)
dataset = HWDB(path=r'data', transform=transform)
print("训练集数量:", dataset.train_size)
print("测试集数量:", dataset.test_size)
print("类别数量:", dataset.num_classes)
index = random.randint(0, dataset.train_size)
img = dataset.get_sample(index)[0][0]
plt.imshow(img, cmap='gray')
plt.show()

View File

@@ -1,6 +1,8 @@
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchsummary import summary
def conv_dw(inp, oup, stride):
@@ -38,11 +40,15 @@ class ConvNet(nn.Module):
self.conv10 = conv_dw(128, 128, 1) # 16x16x128
self.conv11 = conv_dw(128, 128, 1) # 16x16x128
self.conv12 = conv_dw(128, 256, 2) # 8x8x256
self.conv13 = conv_dw(256, 256, 1) # 8x8x256
self.conv14 = conv_dw(256, 256, 1) # 8x8x256
self.conv15 = conv_dw(256, 512, 2) # 4x4x512
self.conv16 = conv_dw(512, 512, 1) # 4x4x512
self.classifier = nn.Sequential(
nn.Linear(256 * 8 * 8, 4096),
nn.Linear(512*4*4, 1024),
nn.Dropout(0.2),
nn.ReLU(inplace=True),
nn.Linear(4096, num_classes),
nn.Linear(1024, num_classes),
)
self.weight_init()
@@ -61,7 +67,12 @@ class ConvNet(nn.Module):
x11 = self.conv11(x10)
x11 = F.relu(x10 + x11)
x12 = self.conv12(x11)
x = x12.view(x12.size(0), -1)
x13 = self.conv13(x12)
x14 = self.conv14(x13)
x14 = F.relu(x13 + x14)
x15 = self.conv15(x14)
x16 = self.conv16(x15)
x = x16.view(x16.size(0), -1)
x = self.classifier(x)
return x
@@ -83,3 +94,6 @@ class ConvNet(nn.Module):
m.bias.data.zero_()
if __name__ == "__main__":
model = ConvNet(3755).cuda()
summary(model, input_size=(3, 64, 64), device='cuda')

View File

@@ -4,8 +4,10 @@ import os
import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from torchvision import transforms
from torchsummary import summary
from hwdb import HWDB
from model import ConvNet
@@ -74,42 +76,43 @@ def train(epoch, net, criterion, optimizer, train_loader, writer, save_iter=100)
if __name__ == "__main__":
data_path = r'data'
data_path = r'C:\Users\Administrator\Desktop\hand-writing-recognition\data'
log_path = r'logs/batch_100_100_lr0.02'
save_path = r'checkpoints/'
if not os.path.exists(save_path):
os.mkdir(save_path)
# 超参数
epochs = 20
batch_size = 100
lr = 0.02
lr = 0.0001
# 读取分类类别
f = open('char_dict', 'rb')
class_dict = pickle.load(f)
with open('char_dict', 'rb') as f:
class_dict = pickle.load(f)
num_classes = len(class_dict)
# 读取数据
transform = transforms.Compose([
# transforms.Grayscale(),
transforms.Resize((64, 64)),
transforms.ToTensor(),
])
dataset = HWDB(transform=transform, path=data_path)
dataset = HWDB(path=data_path, transform=transform)
print("训练集数据:", dataset.train_size)
print("测试集数据:", dataset.test_size)
trainloader, testloader = dataset.get_loader(batch_size)
net = ConvNet(num_classes)
print('网络结构:\n', net)
# net.load_state_dict(torch.load('checkpoints/handwriting_iter_004.pth'))
print('网络结构:\n')
summary(net, input_size=(3, 64, 64), device='cuda')
if torch.cuda.is_available():
net = net.cuda()
else:
print('cuda not available')
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=lr)
writer = SummaryWriter('logs/batch_100_{}_lr{}'.format(batch_size, lr))
writer = SummaryWriter(log_path)
for epoch in range(epochs):
train(epoch, net, criterion, optimizer, trainloader, writer=writer)
valid(epoch, net, testloader, writer=writer)
print("epoch%d 结束, 正在保存模型..." % (epoch))
torch.save(net.state_dict(), save_path + 'handwriting_iter_%03d.pth' % (epoch))
print("epoch%d 结束, 正在保存模型..." % epoch)
torch.save(net.state_dict(), save_path + 'handwriting_iter_%03d.pth' % epoch)