import os
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from torchvision import transforms

from hwdb import HWDB
from model import ConvNet


def valid(epoch, net, test_loader, writer):
    print("epoch %d: starting validation..." % epoch)
    net.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            if torch.cuda.is_available():
                images, labels = images.cuda(), labels.cuda()
            outputs = net(images)
            # Take the class with the highest score as the prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print('correct number:', correct)
        print('total number:', total)
        acc = 100 * correct / total
        print('Validation accuracy after epoch %d: %.2f%%' % (epoch, acc))
        writer.add_scalar('valid_acc', acc, global_step=epoch)


def train(epoch, net, criterion, optimizer, train_loader, writer, save_iter=100):
    print("epoch %d: starting training..." % epoch)
    net.train()
    sum_loss = 0.0
    total = 0
    correct = 0
    # Iterate over the training data
    for i, (inputs, labels) in enumerate(train_loader):
        # Zero the parameter gradients
        optimizer.zero_grad()
        if torch.cuda.is_available():
            inputs = inputs.cuda()
            labels = labels.cuda()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        # Take the class with the highest score as the prediction
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        loss.backward()
        optimizer.step()

        # Every save_iter batches, log the average loss and accuracy
        sum_loss += loss.item()
        if (i + 1) % save_iter == 0:
            batch_loss = sum_loss / save_iter
            acc = 100 * correct / total
            print('epoch: %d, batch: %d, loss: %.03f, acc: %.04f'
                  % (epoch, i + 1, batch_loss, acc))
            writer.add_scalar('train_loss', batch_loss,
                              global_step=i + len(train_loader) * epoch)
            writer.add_scalar('train_acc', acc,
                              global_step=i + len(train_loader) * epoch)
            for name, layer in net.named_parameters():
                writer.add_histogram(name + '_grad', layer.grad.cpu().data.numpy(),
                                     global_step=i + len(train_loader) * epoch)
                writer.add_histogram(name + '_data', layer.cpu().data.numpy(),
                                     global_step=i + len(train_loader) * epoch)
            total = 0
            correct = 0
            sum_loss = 0.0


if __name__ == "__main__":
    data_path = r'data'
    save_path = r'checkpoints/'
    if not os.path.exists(save_path):
        os.mkdir(save_path)

    # Hyperparameters
    epochs = 20
    batch_size = 100
    lr = 0.02

    # Load the character-class dictionary
    with open('char_dict', 'rb') as f:
        class_dict = pickle.load(f)
    num_classes = len(class_dict)

    # Load the data
    transform = transforms.Compose([
        # transforms.Grayscale(),
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])
    dataset = HWDB(transform=transform, path=data_path)
    print("Training set size:", dataset.train_size)
    print("Test set size:", dataset.test_size)
    trainloader, testloader = dataset.get_loader(batch_size)

    net = ConvNet(num_classes)
    print('Network structure:\n', net)
    if torch.cuda.is_available():
        net = net.cuda()
    else:
        print('cuda not available')
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr)
    writer = SummaryWriter('logs/batch_{}_lr{}'.format(batch_size, lr))
    for epoch in range(epochs):
        train(epoch, net, criterion, optimizer, trainloader, writer=writer)
        valid(epoch, net, testloader, writer=writer)
        print("epoch %d finished, saving model..." % epoch)
        torch.save(net.state_dict(), save_path + 'handwriting_iter_%03d.pth' % epoch)
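

# ---------------------------------------------------------------------------
# Reference sketch only: model.py is not included in this snippet, so the
# class below is a hypothetical, minimal ConvNet illustrating the interface
# the training loop assumes -- a module that accepts the 3-channel 64x64
# tensors produced by the transform above and returns one logit per character
# class. The real ConvNet in model.py may be deeper or structured differently.
# ---------------------------------------------------------------------------
class ConvNetSketch(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        # Three conv blocks, each halving the spatial size: 64 -> 32 -> 16 -> 8
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        # Flattened 128 x 8 x 8 feature map -> one score per class
        self.classifier = nn.Linear(128 * 8 * 8, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)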