update code and readme

This commit is contained in:
JiageWang
2019-08-25 19:07:44 +08:00
parent 446aac5c71
commit 2beb4adb0f
6 changed files with 106 additions and 147 deletions

View File

@@ -1,5 +1,6 @@
# hand-writing-recognition # hand-writing-recognition
基于pytorch卷积神经网络的手写汉字识别使用HWDB数据库 基于pytorch卷积神经网络的中文手写汉字识别使用HWDB数据库
![hwdb](hwdb.jpg)
## Dependence ## Dependence
* PIL * PIL

BIN
hwdb.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 179 KiB

View File

@@ -34,7 +34,7 @@ if __name__ == '__main__':
transforms.Resize((64, 64)), transforms.Resize((64, 64)),
transforms.ToTensor(), transforms.ToTensor(),
]) ])
dataset = HWDB(transform=transform, path=r'C:\Users\Administrator\Desktop\hand-writing-recognition\data') dataset = HWDB(transform=transform, path=r'data')
print(dataset.train_size) print(dataset.train_size)
print(dataset.test_size) print(dataset.test_size)
for i in [1020, 120, 2000, 6000, 1000]: for i in [1020, 120, 2000, 6000, 1000]:
@@ -45,7 +45,6 @@ if __name__ == '__main__':
plt.show() plt.show()
train_loader, test_loader = dataset.get_loader() train_loader, test_loader = dataset.get_loader()
print(len(train_loader))
# for (img, label) in train_loader: # for (img, label) in train_loader:
# print(img) # print(img)
# print(label) # print(label)

View File

@@ -1,7 +1,5 @@
import math
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn import init
import numpy as np import numpy as np
@@ -46,7 +44,7 @@ class ConvNet(nn.Module):
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Linear(4096, num_classes), nn.Linear(4096, num_classes),
) )
# self.weight_init() self.weight_init()
def forward(self, x): def forward(self, x):
x1 = self.conv1(x) x1 = self.conv1(x)
@@ -85,56 +83,3 @@ class ConvNet(nn.Module):
m.bias.data.zero_() m.bias.data.zero_()
class ConvNet2(nn.Module):
def __init__(self, num_classes):
super(ConvNet2, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.classifier = nn.Sequential(
nn.Linear(512 * 4 * 4, 4096),
# nn.Dropout(),
nn.ReLU(inplace=True),
nn.Linear(4096, num_classes),
)
self.weight_init()
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
def weight_init(self):
for layer in self.features:
self._layer_init(layer)
for layer in self.classifier:
self._layer_init(layer)
def _layer_init(self, m):
# 使用isinstance来判断m属于什么类型
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, np.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
# m中的weightbias其实都是Variable为了能学习参数以及后向传播
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
# n = m.weight.size(1)
m.weight.data.normal_(0, 0.01)
m.bias.data.zero_()
# init.xavier_normal_(m.weight)

View File

@@ -5,6 +5,7 @@ import threading
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from concurrent.futures import ThreadPoolExecutor
# 处理单个gnt文件获取图像与标签 # 处理单个gnt文件获取图像与标签
@@ -33,20 +34,26 @@ def read_from_gnt_dir(gnt_dir):
def gnt_to_img(gnt_dir, img_dir): def gnt_to_img(gnt_dir, img_dir):
counter = 0 def save_img(label, image, counter):
for image, label in read_from_gnt_dir(gnt_dir=gnt_dir):
label = struct.pack('>H', label).decode('gb2312') label = struct.pack('>H', label).decode('gb2312')
img = Image.fromarray(image) img = Image.fromarray(image)
dir_name = os.path.join(img_dir, '%0.5d' % char_dict[label]) dir_name = os.path.join(img_dir, '%0.5d' % char_dict[label])
if not os.path.exists(dir_name): if not os.path.exists(dir_name):
os.mkdir(dir_name) os.mkdir(dir_name)
img.convert('RGB').save(dir_name + '/' + str(counter) + '.png') img.convert('RGB').save(dir_name + '/' + str(counter) + '.png')
print("train_counter=", counter) print("thread: {}, counter=".format(threading.current_thread().name), counter)
counter = 0
thread_pool = ThreadPoolExecutor(4) # 定义4个线程执行此任务
for image, label in read_from_gnt_dir(gnt_dir=gnt_dir):
thread_pool.submit(save_img, label, image, counter)
counter += 1 counter += 1
thread_pool.shutdown()
if __name__ == "__main__":
# 路径 # 路径
data_dir = './data' data_dir = r'./data'
train_gnt_dir = os.path.join(data_dir, 'HWDB1.1trn_gnt') train_gnt_dir = os.path.join(data_dir, 'HWDB1.1trn_gnt')
test_gnt_dir = os.path.join(data_dir, 'HWDB1.1tst_gnt') test_gnt_dir = os.path.join(data_dir, 'HWDB1.1tst_gnt')
train_img_dir = os.path.join(data_dir, 'train') train_img_dir = os.path.join(data_dir, 'train')
@@ -57,6 +64,7 @@ if not os.path.exists(test_img_dir):
os.mkdir(test_img_dir) os.mkdir(test_img_dir)
# 获取字符集合 # 获取字符集合
if not os.path.exists('char_dict'):
char_set = set() char_set = set()
for _, tagcode in read_from_gnt_dir(gnt_dir=test_gnt_dir): for _, tagcode in read_from_gnt_dir(gnt_dir=test_gnt_dir):
tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312') tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')
@@ -68,6 +76,9 @@ print("char_dict=", char_dict)
with open('char_dict', 'wb') as f: with open('char_dict', 'wb') as f:
pickle.dump(char_dict, f) pickle.dump(char_dict, f)
else:
with open('char_dict', 'rb') as f:
char_dict = pickle.load(f)
train_thread = threading.Thread(target=gnt_to_img, args=(train_gnt_dir, train_img_dir)).start() train_thread = threading.Thread(target=gnt_to_img, args=(train_gnt_dir, train_img_dir)).start()
test_thread = threading.Thread(target=gnt_to_img, args=(test_gnt_dir, test_img_dir)).start() test_thread = threading.Thread(target=gnt_to_img, args=(test_gnt_dir, test_img_dir)).start()

109
train.py
View File

@@ -8,52 +8,11 @@ from tensorboardX import SummaryWriter
from torchvision import transforms from torchvision import transforms
from hwdb import HWDB from hwdb import HWDB
from model import ConvNet, ConvNet2 from model import ConvNet
def train(net, criterion, optimizer, train_loader, test_loarder, writer, epoch, save_path): def valid(epoch, net, test_loarder, writer):
print("开始训练...") print("epoch %d 开始验证..." % epoch)
net.train()
for epoch in range(epoch):
sum_loss = 0.0
total = 0
correct = 0
# 数据读取
for i, (inputs, labels) in enumerate(train_loader):
# 梯度清零
optimizer.zero_grad()
if torch.cuda.is_available():
inputs = inputs.cuda()
labels = labels.cuda()
outputs = net(inputs)
loss = criterion(outputs, labels)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
loss.backward()
optimizer.step()
# 每训练100个batch打印一次平均loss与acc
sum_loss += loss.item()
if i % 100 == 99:
batch_loss = sum_loss / 100
# 每跑完一次epoch测试一下准确率
acc = 100 * correct / total
print('epoch: %d, batch: %d loss: %.03f, acc: %.04f'
% (epoch, i + 1, batch_loss, acc))
writer.add_scalar('train_loss', batch_loss, global_step=i+len(train_loader)*epoch)
writer.add_scalar('train_acc', acc, global_step=i+len(train_loader)*epoch)
for name, layer in net.named_parameters():
writer.add_histogram(name+'_grad', layer.grad.cpu().data.numpy(), global_step=i+len(train_loader)*epoch)
writer.add_histogram(name+'_data', layer.cpu().data.numpy(), global_step=i+len(train_loader)*epoch)
total = 0
correct = 0
sum_loss = 0.0
print("epoch%d 训练结束, 正在保存模型..." % (epoch))
torch.save(net.state_dict(), save_path + 'handwriting_iter_%03d.pth' % (epoch))
if epoch % 3 == 0:
with torch.no_grad(): with torch.no_grad():
correct = 0 correct = 0
total = 0 total = 0
@@ -68,16 +27,59 @@ def train(net, criterion, optimizer, train_loader, test_loarder, writer, epoch,
print('totol number:', total) print('totol number:', total)
acc = 100 * correct / total acc = 100 * correct / total
print('%d个epoch的识别准确率为%d%%' % (epoch, acc)) print('%d个epoch的识别准确率为%d%%' % (epoch, acc))
writer.add_scalar('test_acc', acc, global_step=epoch) writer.add_scalar('valid_acc', acc, global_step=epoch)
def train(epoch, net, criterion, optimizer, train_loader, writer, save_iter=100):
print("epoch %d 开始训练..." % epoch)
net.train()
sum_loss = 0.0
total = 0
correct = 0
# 数据读取
for i, (inputs, labels) in enumerate(train_loader):
# 梯度清零
optimizer.zero_grad()
if torch.cuda.is_available():
inputs = inputs.cuda()
labels = labels.cuda()
outputs = net(inputs)
loss = criterion(outputs, labels)
# 取得分最高的那个类
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
loss.backward()
optimizer.step()
# 每训练100个batch打印一次平均loss与acc
sum_loss += loss.item()
if (i + 1) % save_iter == 0:
batch_loss = sum_loss / save_iter
# 每跑完一次epoch测试一下准确率
acc = 100 * correct / total
print('epoch: %d, batch: %d loss: %.03f, acc: %.04f'
% (epoch, i + 1, batch_loss, acc))
writer.add_scalar('train_loss', batch_loss, global_step=i + len(train_loader) * epoch)
writer.add_scalar('train_acc', acc, global_step=i + len(train_loader) * epoch)
for name, layer in net.named_parameters():
writer.add_histogram(name + '_grad', layer.grad.cpu().data.numpy(),
global_step=i + len(train_loader) * epoch)
writer.add_histogram(name + '_data', layer.cpu().data.numpy(),
global_step=i + len(train_loader) * epoch)
total = 0
correct = 0
sum_loss = 0.0
if __name__ == "__main__": if __name__ == "__main__":
data_path = r'C:\Users\Administrator\Desktop\hand-writing-recognition\data' data_path = r'data'
save_path = r'checkpoints' save_path = r'checkpoints/'
if not os.path.exists(save_path): if not os.path.exists(save_path):
os.mkdir(save_path) os.mkdir(save_path)
# 超参数 # 超参数
epoch = 20 epochs = 20
batch_size = 100 batch_size = 100
lr = 0.02 lr = 0.02
@@ -103,10 +105,11 @@ if __name__ == "__main__":
net = net.cuda() net = net.cuda()
else: else:
print('cuda not available') print('cuda not available')
# net.load_state_dict(torch.load('./pretrained_models/handwriting_iter_010.pth'))
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=lr) optimizer = optim.SGD(net.parameters(), lr=lr)
writer = SummaryWriter('logs/model/batch{}_lr{}'.format(batch_size, lr)) writer = SummaryWriter('logs/batch_100_{}_lr{}'.format(batch_size, lr))
# writer = SummaryWriter('logs/model_dw_res/batch{}_lr{}'.format(batch_size, lr)) for epoch in range(epochs):
train(net, criterion, optimizer, trainloader, testloader, writer=writer, epoch=100, save_path=save_path) train(epoch, net, criterion, optimizer, trainloader, writer=writer)
valid(epoch, net, testloader, writer=writer)
print("epoch%d 结束, 正在保存模型..." % (epoch))
torch.save(net.state_dict(), save_path + 'handwriting_iter_%03d.pth' % (epoch))