update code and readme

This commit is contained in:
JiageWang
2019-08-26 15:25:12 +08:00
parent 2beb4adb0f
commit 43ea8da649
3 changed files with 45 additions and 32 deletions

View File

@@ -4,8 +4,10 @@ import os
import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from torchvision import transforms
from torchsummary import summary
from hwdb import HWDB
from model import ConvNet
@@ -74,42 +76,43 @@ def train(epoch, net, criterion, optimizer, train_loader, writer, save_iter=100)
if __name__ == "__main__":
data_path = r'data'
data_path = r'C:\Users\Administrator\Desktop\hand-writing-recognition\data'
log_path = r'logs/batch_100_100_lr0.02'
save_path = r'checkpoints/'
if not os.path.exists(save_path):
os.mkdir(save_path)
# 超参数
epochs = 20
batch_size = 100
lr = 0.02
lr = 0.0001
# 读取分类类别
f = open('char_dict', 'rb')
class_dict = pickle.load(f)
with open('char_dict', 'rb') as f:
class_dict = pickle.load(f)
num_classes = len(class_dict)
# 读取数据
transform = transforms.Compose([
# transforms.Grayscale(),
transforms.Resize((64, 64)),
transforms.ToTensor(),
])
dataset = HWDB(transform=transform, path=data_path)
dataset = HWDB(path=data_path, transform=transform)
print("训练集数据:", dataset.train_size)
print("测试集数据:", dataset.test_size)
trainloader, testloader = dataset.get_loader(batch_size)
net = ConvNet(num_classes)
print('网络结构:\n', net)
# net.load_state_dict(torch.load('checkpoints/handwriting_iter_004.pth'))
print('网络结构:\n')
summary(net, input_size=(3, 64, 64), device='cuda')
if torch.cuda.is_available():
net = net.cuda()
else:
print('cuda not available')
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=lr)
writer = SummaryWriter('logs/batch_100_{}_lr{}'.format(batch_size, lr))
writer = SummaryWriter(log_path)
for epoch in range(epochs):
train(epoch, net, criterion, optimizer, trainloader, writer=writer)
valid(epoch, net, testloader, writer=writer)
print("epoch%d 结束, 正在保存模型..." % (epoch))
torch.save(net.state_dict(), save_path + 'handwriting_iter_%03d.pth' % (epoch))
print("epoch%d 结束, 正在保存模型..." % epoch)
torch.save(net.state_dict(), save_path + 'handwriting_iter_%03d.pth' % epoch)