116 lines
3.9 KiB
Python
116 lines
3.9 KiB
Python
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

from convnet import ConvNet
from hwdb import HWDB
|
||
|
||
|
||
def train(net,
          criterion,
          optimizer,
          train_loader,
          test_loarder,
          epoch=10,
          save_path='./pretrained_models/',
          start_epoch=10):
    """Train ``net``, checkpoint it after every epoch and evaluate periodically.

    Args:
        net: Model to optimise; called as ``net(inputs)``.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer whose ``param_groups`` carry an ``'lr'`` entry.
        train_loader: Iterable yielding ``(inputs, labels)`` training batches.
        test_loarder: Iterable yielding ``(images, labels)`` evaluation
            batches. The misspelled name is kept so existing keyword callers
            keep working.
        epoch: Number of epochs to run.
        save_path: Prefix prepended verbatim to the checkpoint file name
            (so a directory must end with a path separator).
        start_epoch: Number of epochs already completed before this call.
            It only offsets the epoch number used in log lines and
            checkpoint names; the default of 10 reproduces the previous
            hard-coded numbering (first epoch is reported as 11).
    """
    def adjust_learning_rate(optimizer, decay_rate=.9):
        """Scale the learning rate of every parameter group by decay_rate."""
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * decay_rate

    # Feed batches to wherever the model lives, so data and weights always
    # share a device; fall back to CUDA-if-available for parameterless models.
    first_param = next(net.parameters(), None)
    if first_param is not None:
        device = first_param.device
    elif torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    if not torch.cuda.is_available():
        # Reported once instead of once per batch.
        print('cuda not available')

    print("开始训练...")
    net.train()
    for epoch_idx in range(epoch):
        epoch_label = start_epoch + epoch_idx + 1
        running_loss = 0.0
        seen = 0
        hits = 0

        # One-off decay: halve the learning rate when the fourth epoch of
        # this call starts (same condition as the former ``epoch/3 == 1``).
        if epoch_idx == 3:
            adjust_learning_rate(optimizer, 0.5)

        for i, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Reset gradients, then forward + backward + parameter update.
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Track running accuracy with plain Python numbers.
            predicted = outputs.detach().argmax(dim=1)
            seen += labels.size(0)
            hits += (predicted == labels).sum().item()
            running_loss += loss.item()

            # Every 100 batches report the mean loss and accuracy of that
            # window, then start a new window.
            if i % 100 == 99:
                print('epoch: %d, batch: %d loss: %.03f'
                      % (epoch_label, i + 1, running_loss / 100), end=',')
                print('acc:%d%%' % (100 * hits / seen))
                seen = 0
                hits = 0
                running_loss = 0.0

        print("epoch%d 训练结束, 正在保存模型..." % epoch_label)
        checkpoint_file = save_path + 'handwriting_iter_%03d.pth' % epoch_label
        checkpoint_dir = os.path.dirname(checkpoint_file)
        if checkpoint_dir:
            # torch.save does not create missing directories.
            os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(net.state_dict(), checkpoint_file)

        # Evaluate on the test set on every third epoch (0, 3, 6, ...).
        if epoch_idx % 3 == 0:
            correct = 0
            total = 0
            # Switch layers such as dropout / batch-norm to inference
            # behaviour while measuring accuracy.
            net.eval()
            with torch.no_grad():
                for images, labels in test_loarder:
                    images, labels = images.to(device), labels.to(device)
                    outputs = net(images)
                    # Pick the highest-scoring class.
                    predicted = outputs.argmax(dim=1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
            net.train()
            print('correct number: ', correct)
            print('totol number:', total)
            # Guard against an empty test loader.
            acc = 100 * correct / total if total else 0.0
            print('第%d个epoch的识别准确率为:%d%%' % (epoch_label, acc))
|
||
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: build the data loaders and the network, restore the
    # epoch-10 checkpoint and continue training from there.

    # Hyper-parameters
    batch_size = 100

    # Load the class dictionary; its length is the number of output classes.
    # NOTE(review): pickle.load can execute arbitrary code, so 'char_dict'
    # must be a trusted, locally produced file.
    with open('char_dict', 'rb') as f:
        class_dict = pickle.load(f)
    num_classes = len(class_dict)

    # Load the data
    dataset = HWDB()
    print("训练集数据:", dataset.train_size)
    print("测试集数据:", dataset.test_size)
    train_loader, test_loader = dataset.get_loader(batch_size)

    net = ConvNet(num_classes)
    print('网络结构:\n', net)
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
        net = net.cuda(0)
    else:
        device = torch.device('cpu')
        print('cuda not available')

    # map_location lets a checkpoint saved on a GPU be restored on a
    # CPU-only machine as well.
    net.load_state_dict(
        torch.load('./pretrained_models/handwriting_iter_010.pth',
                   map_location=device))
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.RMSprop(net.parameters(), lr=0.000005, momentum=0.9, weight_decay=0.0005)
    train(net, criterion, optimizer, train_loader, test_loader)
|
||
|
||
|