update code and readme
This commit is contained in:
27
train.py
27
train.py
@@ -4,8 +4,10 @@ import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from tensorboardX import SummaryWriter
|
||||
from torchvision import transforms
|
||||
from torchsummary import summary
|
||||
|
||||
from hwdb import HWDB
|
||||
from model import ConvNet
|
||||
@@ -74,42 +76,43 @@ def train(epoch, net, criterion, optimizer, train_loader, writer, save_iter=100)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
data_path = r'data'
|
||||
data_path = r'C:\Users\Administrator\Desktop\hand-writing-recognition\data'
|
||||
log_path = r'logs/batch_100_100_lr0.02'
|
||||
save_path = r'checkpoints/'
|
||||
if not os.path.exists(save_path):
|
||||
os.mkdir(save_path)
|
||||
# 超参数
|
||||
epochs = 20
|
||||
batch_size = 100
|
||||
lr = 0.02
|
||||
lr = 0.0001
|
||||
|
||||
# 读取分类类别
|
||||
f = open('char_dict', 'rb')
|
||||
class_dict = pickle.load(f)
|
||||
with open('char_dict', 'rb') as f:
|
||||
class_dict = pickle.load(f)
|
||||
num_classes = len(class_dict)
|
||||
|
||||
# 读取数据
|
||||
transform = transforms.Compose([
|
||||
# transforms.Grayscale(),
|
||||
transforms.Resize((64, 64)),
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
dataset = HWDB(transform=transform, path=data_path)
|
||||
dataset = HWDB(path=data_path, transform=transform)
|
||||
print("训练集数据:", dataset.train_size)
|
||||
print("测试集数据:", dataset.test_size)
|
||||
trainloader, testloader = dataset.get_loader(batch_size)
|
||||
|
||||
net = ConvNet(num_classes)
|
||||
print('网络结构:\n', net)
|
||||
# net.load_state_dict(torch.load('checkpoints/handwriting_iter_004.pth'))
|
||||
|
||||
print('网络结构:\n')
|
||||
summary(net, input_size=(3, 64, 64), device='cuda')
|
||||
if torch.cuda.is_available():
|
||||
net = net.cuda()
|
||||
else:
|
||||
print('cuda not available')
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.SGD(net.parameters(), lr=lr)
|
||||
writer = SummaryWriter('logs/batch_100_{}_lr{}'.format(batch_size, lr))
|
||||
writer = SummaryWriter(log_path)
|
||||
for epoch in range(epochs):
|
||||
train(epoch, net, criterion, optimizer, trainloader, writer=writer)
|
||||
valid(epoch, net, testloader, writer=writer)
|
||||
print("epoch%d 结束, 正在保存模型..." % (epoch))
|
||||
torch.save(net.state_dict(), save_path + 'handwriting_iter_%03d.pth' % (epoch))
|
||||
print("epoch%d 结束, 正在保存模型..." % epoch)
|
||||
torch.save(net.state_dict(), save_path + 'handwriting_iter_%03d.pth' % epoch)
|
||||
|
||||
Reference in New Issue
Block a user