diff --git a/hwdb.py b/hwdb.py index 92fd16c..0792e83 100644 --- a/hwdb.py +++ b/hwdb.py @@ -1,4 +1,5 @@ import os +import random from torch.utils.data import DataLoader import torchvision.transforms as transforms import torchvision.datasets as datasets @@ -6,7 +7,7 @@ import matplotlib.pyplot as plt class HWDB(object): - def __init__(self, transform, path='./data'): + def __init__(self, path, transform): # 预处理过程 traindir = os.path.join(path, 'train') @@ -16,6 +17,8 @@ class HWDB(object): self.testset = datasets.ImageFolder(testdir, transform) self.train_size = len(self.trainset) self.test_size = len(self.testset) + self.num_classes = len(self.trainset.classes) + self.class_to_idx = self.trainset.class_to_idx def get_sample(self, index=0): sample = self.trainset[index] @@ -30,21 +33,14 @@ class HWDB(object): if __name__ == '__main__': transform = transforms.Compose([ - # transforms.Grayscale(), transforms.Resize((64, 64)), transforms.ToTensor(), ]) - dataset = HWDB(transform=transform, path=r'data') - print(dataset.train_size) - print(dataset.test_size) - for i in [1020, 120, 2000, 6000, 1000]: - img, label = dataset.get_sample(i) - img = img[0] - print(label) - plt.imshow(img, cmap='gray') - plt.show() - - train_loader, test_loader = dataset.get_loader() - # for (img, label) in train_loader: - # print(img) - # print(label) + dataset = HWDB(path=r'data', transform=transform) + print("训练集数量:", dataset.train_size) + print("测试集数量:", dataset.test_size) + print("类别数量:", dataset.num_classes) + index = random.randint(0, dataset.train_size - 1) + img = dataset.get_sample(index)[0][0] + plt.imshow(img, cmap='gray') + plt.show() diff --git a/model.py b/model.py index d009dda..fe0e002 100644 --- a/model.py +++ b/model.py @@ -1,6 +1,8 @@ +import numpy as np + import torch.nn as nn import torch.nn.functional as F -import numpy as np +from torchsummary import summary def conv_dw(inp, oup, stride): @@ -38,11 +40,15 @@ class ConvNet(nn.Module): self.conv10 = conv_dw(128, 128, 1) # 
16x16x128 self.conv11 = conv_dw(128, 128, 1) # 16x16x128 self.conv12 = conv_dw(128, 256, 2) # 8x8x256 + self.conv13 = conv_dw(256, 256, 1) # 8x8x256 + self.conv14 = conv_dw(256, 256, 1) # 8x8x256 + self.conv15 = conv_dw(256, 512, 2) # 4x4x512 + self.conv16 = conv_dw(512, 512, 1) # 4x4x512 self.classifier = nn.Sequential( - nn.Linear(256 * 8 * 8, 4096), + nn.Linear(512*4*4, 1024), nn.Dropout(0.2), nn.ReLU(inplace=True), - nn.Linear(4096, num_classes), + nn.Linear(1024, num_classes), ) self.weight_init() @@ -61,7 +67,12 @@ class ConvNet(nn.Module): x11 = self.conv11(x10) x11 = F.relu(x10 + x11) x12 = self.conv12(x11) - x = x12.view(x12.size(0), -1) + x13 = self.conv13(x12) + x14 = self.conv14(x13) + x14 = F.relu(x13 + x14) + x15 = self.conv15(x14) + x16 = self.conv16(x15) + x = x16.view(x16.size(0), -1) x = self.classifier(x) return x @@ -83,3 +94,6 @@ class ConvNet(nn.Module): m.bias.data.zero_() +if __name__ == "__main__": + model = ConvNet(3755).cuda() + summary(model, input_size=(3, 64, 64), device='cuda') diff --git a/train.py b/train.py index 866b8f1..b926df5 100644 --- a/train.py +++ b/train.py @@ -4,8 +4,10 @@ import os import torch import torch.nn as nn import torch.optim as optim + from tensorboardX import SummaryWriter from torchvision import transforms +from torchsummary import summary from hwdb import HWDB from model import ConvNet @@ -74,42 +76,43 @@ def train(epoch, net, criterion, optimizer, train_loader, writer, save_iter=100) if __name__ == "__main__": - data_path = r'data' + data_path = r'C:\Users\Administrator\Desktop\hand-writing-recognition\data' + log_path = r'logs/batch_100_100_lr0.0001' save_path = r'checkpoints/' if not os.path.exists(save_path): os.mkdir(save_path) # 超参数 epochs = 20 batch_size = 100 - lr = 0.02 + lr = 0.0001 # 读取分类类别 - f = open('char_dict', 'rb') - class_dict = pickle.load(f) + with open('char_dict', 'rb') as f: + class_dict = pickle.load(f) num_classes = len(class_dict) # 读取数据 transform = transforms.Compose([ - # 
transforms.Grayscale(), transforms.Resize((64, 64)), transforms.ToTensor(), ]) - dataset = HWDB(transform=transform, path=data_path) + dataset = HWDB(path=data_path, transform=transform) print("训练集数据:", dataset.train_size) print("测试集数据:", dataset.test_size) trainloader, testloader = dataset.get_loader(batch_size) net = ConvNet(num_classes) - print('网络结构:\n', net) + # net.load_state_dict(torch.load('checkpoints/handwriting_iter_004.pth')) if torch.cuda.is_available(): net = net.cuda() - else: - print('cuda not available') + + print('网络结构:\n') + summary(net, input_size=(3, 64, 64), device='cuda' if torch.cuda.is_available() else 'cpu') criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=lr) - writer = SummaryWriter('logs/batch_100_{}_lr{}'.format(batch_size, lr)) + writer = SummaryWriter(log_path) for epoch in range(epochs): train(epoch, net, criterion, optimizer, trainloader, writer=writer) valid(epoch, net, testloader, writer=writer) - print("epoch%d 结束, 正在保存模型..." % (epoch)) - torch.save(net.state_dict(), save_path + 'handwriting_iter_%03d.pth' % (epoch)) + print("epoch%d 结束, 正在保存模型..." % epoch) + torch.save(net.state_dict(), save_path + 'handwriting_iter_%03d.pth' % epoch)