first commit
This commit is contained in:
115
train.py
Normal file
115
train.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.autograd import Variable
|
||||
import pickle
|
||||
import numpy as np
|
||||
from hwdb import HWDB
|
||||
from convnet import ConvNet
|
||||
|
||||
|
||||
def train(net,
|
||||
criterion,
|
||||
optimizer,
|
||||
train_loader,
|
||||
test_loarder,
|
||||
epoch=10,
|
||||
save_path='./pretrained_models/'):
|
||||
def adjust_learning_rate(optimizer, decay_rate=.9):
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = param_group['lr'] * decay_rate
|
||||
print("开始训练...")
|
||||
net.train()
|
||||
#scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
|
||||
for epoch in range(epoch):
|
||||
sum_loss = 0.0
|
||||
total = 0
|
||||
correct = 0
|
||||
if epoch/3 == 1:
|
||||
adjust_learning_rate(optimizer, 0.5)
|
||||
# 数据读取
|
||||
for i, (inputs, labels) in enumerate(train_loader):
|
||||
# 梯度清零
|
||||
optimizer.zero_grad()
|
||||
|
||||
# forward + backward
|
||||
if torch.cuda.is_available():
|
||||
# inputs, labels = Variable(inputs.cuda(0)), Variable(labels.cuda(0))
|
||||
inputs = inputs.to('cuda')
|
||||
labels = labels.to('cuda')
|
||||
#print(inputs.device)
|
||||
else:
|
||||
print('cuda not available')
|
||||
outputs = net(inputs)
|
||||
loss = criterion(outputs, labels)
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
total += labels.size(0)
|
||||
correct += (predicted == labels).sum()
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
#print(loss.item())
|
||||
# 每训练100个batch打印一次平均loss与acc
|
||||
sum_loss += loss.item()
|
||||
# if i % 100 == 99:
|
||||
if i % 100 == 99:
|
||||
loss = sum_loss/100
|
||||
print('epoch: %d, batch: %d loss: %.03f'
|
||||
% (epoch + 11, i + 1, loss), end=',')
|
||||
# 每跑完一次epoch测试一下准确率
|
||||
acc = 100 * correct / total
|
||||
print('acc:%d%%' % (acc))
|
||||
total = 0
|
||||
correct = 0
|
||||
sum_loss = 0.0
|
||||
|
||||
print("epoch%d 训练结束, 正在保存模型..."%(epoch+11))
|
||||
torch.save(net.state_dict(), save_path+'handwriting_iter_%03d.pth' % (epoch + 11))
|
||||
if epoch%3 == 0:
|
||||
with torch.no_grad():
|
||||
correct = 0
|
||||
total = 0
|
||||
for images, labels in test_loader:
|
||||
images, labels = images.to('cuda'), labels.to('cuda')
|
||||
outputs = net(images)
|
||||
# 取得分最高的那个类
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
total += labels.size(0)
|
||||
correct += (predicted == labels).sum()
|
||||
print('correct number: ',correct)
|
||||
print('totol number:', total)
|
||||
acc = 100 * correct / total
|
||||
print('第%d个epoch的识别准确率为:%d%%' % (epoch+11, acc))
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 超参数
|
||||
batch_size = 100
|
||||
|
||||
# 读取分类类别
|
||||
f = open('char_dict', 'rb')
|
||||
class_dict = pickle.load(f)
|
||||
num_classes = len(class_dict)
|
||||
|
||||
# 读取数据
|
||||
dataset = HWDB()
|
||||
print("训练集数据:", dataset.train_size)
|
||||
print("测试集数据:", dataset.test_size)
|
||||
train_loader, test_loader = dataset.get_loader(batch_size)
|
||||
|
||||
|
||||
net = ConvNet(num_classes)
|
||||
print('网络结构:\n', net)
|
||||
if torch.cuda.is_available():
|
||||
net = net.cuda(0)
|
||||
else:
|
||||
print('cuda not available')
|
||||
net.load_state_dict(torch.load('./pretrained_models/handwriting_iter_010.pth'))
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
#optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)
|
||||
optimizer = optim.RMSprop(net.parameters(), lr=0.000005, momentum=0.9, weight_decay=0.0005)
|
||||
train(net, criterion, optimizer, train_loader, test_loader)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user