113 lines
4.2 KiB
Python
113 lines
4.2 KiB
Python
import pickle
|
||
import os
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.optim as optim
|
||
from tensorboardX import SummaryWriter
|
||
from torchvision import transforms
|
||
|
||
from hwdb import HWDB
|
||
from model import ConvNet, ConvNet2
|
||
|
||
|
||
def _evaluate(net, loader, use_cuda):
    """Count correct top-1 predictions of ``net`` over ``loader``.

    The network is switched to evaluation mode for the duration of the pass
    and restored to whatever mode it was in before the call.

    Args:
        net: The model to evaluate.
        loader: Iterable yielding ``(images, labels)`` batches.
        use_cuda: When true, each batch is moved to the GPU before the
            forward pass.

    Returns:
        A ``(correct, total)`` tuple of sample counts.
    """
    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in loader:
                if use_cuda:
                    images, labels = images.cuda(), labels.cuda()
                outputs = net(images)
                # Pick the class with the highest score.
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Always hand the network back in its previous mode, even on error.
        net.train(was_training)
    return correct, total


def train(net, criterion, optimizer, train_loader, test_loarder, writer, epoch, save_path):
    """Train ``net``, log progress to ``writer`` and checkpoint every epoch.

    Every 100 training batches the running mean loss and accuracy are printed
    and written to ``writer`` together with per-parameter gradient and weight
    histograms. After each epoch the model's ``state_dict`` is saved inside
    ``save_path``; every third epoch (0, 3, 6, ...) the model is also
    evaluated on ``test_loarder``.

    Args:
        net: The model to optimise.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
        train_loader: Sized iterable of ``(inputs, labels)`` training batches.
        test_loarder: Iterable of ``(images, labels)`` evaluation batches.
            (The misspelled name is kept so keyword callers keep working.)
        writer: Logger exposing ``add_scalar`` and ``add_histogram``.
        epoch: Number of epochs to run.
        save_path: Directory that receives the checkpoint files.
    """
    print("开始训练...")
    # Batches are moved to the GPU whenever one is available, for both the
    # training and the evaluation pass.
    use_cuda = torch.cuda.is_available()
    log_interval = 100

    net.train()
    # A separate loop variable keeps the ``epoch`` argument (the epoch count)
    # from being overwritten.
    for epoch_idx in range(epoch):
        sum_loss = 0.0
        total = 0
        correct = 0
        for i, (inputs, labels) in enumerate(train_loader):
            # Reset gradients accumulated by the previous step.
            optimizer.zero_grad()
            if use_cuda:
                inputs = inputs.cuda()
                labels = labels.cuda()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            _, predicted = torch.max(outputs.detach(), 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            loss.backward()
            optimizer.step()

            # Report the mean loss and accuracy once per ``log_interval`` batches.
            sum_loss += loss.item()
            if i % log_interval == log_interval - 1:
                batch_loss = sum_loss / log_interval
                acc = 100 * correct / total
                step = i + len(train_loader) * epoch_idx
                print('epoch: %d, batch: %d loss: %.03f, acc: %.04f'
                      % (epoch_idx, i + 1, batch_loss, acc))
                writer.add_scalar('train_loss', batch_loss, global_step=step)
                writer.add_scalar('train_acc', acc, global_step=step)
                for name, layer in net.named_parameters():
                    # Frozen or unused parameters have no gradient to plot.
                    if layer.grad is not None:
                        writer.add_histogram(name+'_grad', layer.grad.detach().cpu().numpy(), global_step=step)
                    writer.add_histogram(name+'_data', layer.detach().cpu().numpy(), global_step=step)
                # Start a fresh window for the next report.
                total = 0
                correct = 0
                sum_loss = 0.0

        print("epoch%d 训练结束, 正在保存模型..." % (epoch_idx))
        # os.path.join places the file inside ``save_path`` whether or not the
        # caller passed a trailing separator.
        torch.save(net.state_dict(),
                   os.path.join(save_path, 'handwriting_iter_%03d.pth' % (epoch_idx)))

        if epoch_idx % 3 == 0:
            correct, total = _evaluate(net, test_loarder, use_cuda)
            print('correct number: ', correct)
            print('totol number:', total)
            # An empty test loader reports 0 instead of dividing by zero.
            acc = 100 * correct / total if total else 0.0
            print('第%d个epoch的识别准确率为:%d%%' % (epoch_idx, acc))
            writer.add_scalar('test_acc', acc, global_step=epoch_idx)
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: build the dataset, model, loss, optimizer and
    # TensorBoard writer, then run the training loop.
    data_path = r'C:\Users\Administrator\Desktop\hand-writing-recognition\data'
    save_path = r'checkpoints'
    # makedirs(exist_ok=True) avoids the check-then-create race and also
    # works if save_path is later changed to a nested directory.
    os.makedirs(save_path, exist_ok=True)

    # Hyper-parameters.
    # The script previously declared 20 here but always passed 100 to
    # train(); the variable now carries the value that was actually used.
    epoch = 100
    batch_size = 100
    lr = 0.02

    # Load the class dictionary; its length is the number of output classes.
    # NOTE(review): pickle.load can execute arbitrary code, so 'char_dict'
    # must come from a trusted source.
    with open('char_dict', 'rb') as f:
        class_dict = pickle.load(f)
    num_classes = len(class_dict)

    # Build the dataset: images are resized to 64x64 and converted to tensors.
    transform = transforms.Compose([
        # transforms.Grayscale(),
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])
    dataset = HWDB(transform=transform, path=data_path)
    print("训练集数据:", dataset.train_size)
    print("测试集数据:", dataset.test_size)
    trainloader, testloader = dataset.get_loader(batch_size)

    net = ConvNet(num_classes)
    print('网络结构:\n', net)
    if torch.cuda.is_available():
        net = net.cuda()
    else:
        print('cuda not available')
    # Optional warm start from a saved checkpoint:
    # net.load_state_dict(torch.load('./pretrained_models/handwriting_iter_010.pth'))
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr)
    # The log directory name encodes the batch size and learning rate.
    writer = SummaryWriter('logs/model/batch{}_lr{}'.format(batch_size, lr))
    # writer = SummaryWriter('logs/model_dw_res/batch{}_lr{}'.format(batch_size, lr))
    train(net, criterion, optimizer, trainloader, testloader, writer=writer, epoch=epoch, save_path=save_path)
|
||
|