update code and readme
train.py
@@ -1,92 +1,85 @@
+import pickle
+import os
+
 import torch
 import torch.nn as nn
 import torch.optim as optim
-from torch.autograd import Variable
-import pickle
 import numpy as np
+from tensorboardX import SummaryWriter
+from torchvision import transforms
 
 from hwdb import HWDB
-from convnet import ConvNet
+from model import ConvNet, ConvNet2
 
 
-def train(net,
-          criterion,
-          optimizer,
-          train_loader,
-          test_loarder,
-          epoch=10,
-          save_path='./pretrained_models/'):
+def adjust_learning_rate(optimizer, decay_rate=.9):
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = param_group['lr'] * decay_rate
+
+
+def train(net, criterion, optimizer, train_loader, test_loarder, writer, epoch, save_path):
     print("Starting training...")
     net.train()
     # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
     for epoch in range(epoch):
         sum_loss = 0.0
         total = 0
         correct = 0
+        if epoch / 3 == 1:
+            adjust_learning_rate(optimizer, 0.5)
         # read the data
         for i, (inputs, labels) in enumerate(train_loader):
             # zero the gradients
             optimizer.zero_grad()
 
             # forward + backward
             if torch.cuda.is_available():
-                # inputs, labels = Variable(inputs.cuda(0)), Variable(labels.cuda(0))
-                inputs = inputs.to('cuda')
-                labels = labels.to('cuda')
-                # print(inputs.device)
-            else:
-                print('cuda not available')
+                inputs = inputs.cuda()
+                labels = labels.cuda()
             outputs = net(inputs)
             loss = criterion(outputs, labels)
             _, predicted = torch.max(outputs.data, 1)
             total += labels.size(0)
-            correct += (predicted == labels).sum()
+            correct += (predicted == labels).sum().item()
 
             loss.backward()
             optimizer.step()
 
-            # print(loss.item())
             # print the average loss and acc every 100 training batches
             sum_loss += loss.item()
-            # if i % 100 == 99:
             if i % 100 == 99:
-                loss = sum_loss / 100
-                print('epoch: %d, batch: %d loss: %.03f'
-                      % (epoch + 11, i + 1, loss), end=',')
+                batch_loss = sum_loss / 100
                 acc = 100 * correct / total
-                print('acc:%d%%' % (acc))
+                print('epoch: %d, batch: %d loss: %.03f, acc: %.04f'
+                      % (epoch, i + 1, batch_loss, acc))
+                writer.add_scalar('train_loss', batch_loss, global_step=i + len(train_loader) * epoch)
+                writer.add_scalar('train_acc', acc, global_step=i + len(train_loader) * epoch)
+                for name, layer in net.named_parameters():
+                    writer.add_histogram(name + '_grad', layer.grad.cpu().data.numpy(), global_step=i + len(train_loader) * epoch)
+                    writer.add_histogram(name + '_data', layer.cpu().data.numpy(), global_step=i + len(train_loader) * epoch)
                 total = 0
                 correct = 0
                 sum_loss = 0.0
 
-        print("epoch %d finished, saving the model..." % (epoch + 11))
-        torch.save(net.state_dict(), save_path + 'handwriting_iter_%03d.pth' % (epoch + 11))
+        print("epoch %d finished, saving the model..." % (epoch))
+        torch.save(net.state_dict(), save_path + 'handwriting_iter_%03d.pth' % (epoch))
         if epoch % 3 == 0:
             # test the accuracy once the epoch is done
             with torch.no_grad():
                 correct = 0
                 total = 0
-                for images, labels in test_loader:
-                    images, labels = images.to('cuda'), labels.to('cuda')
+                for images, labels in test_loarder:
+                    images, labels = images.cuda(), labels.cuda()
                     outputs = net(images)
                     # take the class with the highest score
                     _, predicted = torch.max(outputs.data, 1)
                     total += labels.size(0)
-                    correct += (predicted == labels).sum()
+                    correct += (predicted == labels).sum().item()
                     print('correct number: ', correct)
                 print('total number:', total)
                 acc = 100 * correct / total
-                print('recognition accuracy after epoch %d: %d%%' % (epoch + 11, acc))
+                print('recognition accuracy after epoch %d: %d%%' % (epoch, acc))
+                writer.add_scalar('test_acc', acc, global_step=epoch)
 
 
 if __name__ == "__main__":
     data_path = r'C:\Users\Administrator\Desktop\hand-writing-recognition\data'
     save_path = r'checkpoints'
     if not os.path.exists(save_path):
         os.mkdir(save_path)
     # hyperparameters
     epoch = 20
     batch_size = 100
     lr = 0.02
 
     # load the class dictionary
     f = open('char_dict', 'rb')
@@ -94,22 +87,26 @@ if __name__ == "__main__":
     num_classes = len(class_dict)
 
     # load the dataset
-    dataset = HWDB()
+    transform = transforms.Compose([
+        # transforms.Grayscale(),
+        transforms.Resize((64, 64)),
+        transforms.ToTensor(),
+    ])
+    dataset = HWDB(transform=transform, path=data_path)
     print("training set size:", dataset.train_size)
     print("test set size:", dataset.test_size)
-    train_loader, test_loader = dataset.get_loader(batch_size)
+    trainloader, testloader = dataset.get_loader(batch_size)
 
 
-    net = ConvNet(num_classes)
+    net = ConvNet2(num_classes)
     print('network structure:\n', net)
     if torch.cuda.is_available():
-        net = net.cuda(0)
+        net = net.cuda()
     else:
         print('cuda not available')
-    net.load_state_dict(torch.load('./pretrained_models/handwriting_iter_010.pth'))
+    # net.load_state_dict(torch.load('./pretrained_models/handwriting_iter_010.pth'))
     criterion = nn.CrossEntropyLoss()
-    # optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)
-    optimizer = optim.RMSprop(net.parameters(), lr=0.000005, momentum=0.9, weight_decay=0.0005)
-    train(net, criterion, optimizer, train_loader, test_loader)
+    optimizer = optim.SGD(net.parameters(), lr=lr)
+    writer = SummaryWriter('logs/model/batch{}_lr{}'.format(batch_size, lr))
+    # writer = SummaryWriter('logs/model_dw_res/batch{}_lr{}'.format(batch_size, lr))
+    train(net, criterion, optimizer, trainloader, testloader, writer=writer, epoch=100, save_path=save_path)
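A note on the learning-rate handling added above: adjust_learning_rate just multiplies every param group's lr by decay_rate, and since epoch / 3 is true division in Python 3, the guard fires only at epoch 3. The commented-out StepLR scheduler does the same arithmetic on a fixed schedule. A minimal sketch of that built-in alternative (the Linear model and the numbers are placeholders, not from this repo):

import torch
import torch.optim as optim

# Stand-in model; any nn.Module behaves the same for scheduling purposes.
model = torch.nn.Linear(64 * 64, 10)
optimizer = optim.SGD(model.parameters(), lr=0.02)
# StepLR multiplies every param group's lr by gamma once per step_size epochs,
# the same update adjust_learning_rate(optimizer, 0.5) performs by hand.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)

for epoch in range(9):
    # ... one epoch of training (forward, backward, optimizer.step()) ...
    scheduler.step()  # lr becomes 0.01 after epoch 2, 0.005 after epoch 5, ...
    print(epoch, optimizer.param_groups[0]['lr'])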
|
||||
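On the new TensorBoard logging: global_step=i + len(train_loader) * epoch flattens the (epoch, batch) pair into one monotonically increasing counter, so the train_loss and train_acc curves run continuously across epochs instead of restarting at each one. A minimal sketch of the same pattern in isolation (batches_per_epoch and the loss values are placeholders):

from tensorboardX import SummaryWriter

writer = SummaryWriter('logs/model/batch100_lr0.02')
batches_per_epoch = 500  # stands in for len(train_loader)
for epoch in range(2):
    for i in range(batches_per_epoch):
        step = i + batches_per_epoch * epoch  # same arithmetic as in train()
        if i % 100 == 99:
            writer.add_scalar('train_loss', 1.0 / (step + 1), global_step=step)
writer.close()
# Inspect the curves with:  tensorboard --logdir logs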