update code and readme

This commit is contained in:
JiageWang
2019-08-24 21:26:06 +08:00
parent bff40d9966
commit 90ef0ab210
5 changed files with 238 additions and 144 deletions

View File

@@ -1,2 +1,16 @@
# hand-writing-recognition
基于pytorch卷积神经网络的手写汉字识别使用HWDB数据库
## Dependence
* PIL
* numpy
* torch
* torchvision
* tensorboardX(for visulizztion)
## Usage
1. Download HWDB dataset and unzip to `data` folder
2. run `python process_gnt.py` to generate img from gnt fiel. Due to the huge dataset (897758+223991 images), it may take a lot of time. I suggest to put the data folder out of project or your pycharm will get slow.
3. run `python hwdb.py` to visualize the image.
4. run `python train.py` to start trianing.

61
hwdb.py
View File

@@ -1,23 +1,14 @@
import os
import torch
import torch.utils.data
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
class HWDB(object):
def __init__(self, path='./data'):
# 预处理过程
transform = transforms.Compose([
transforms.Grayscale(),
transforms.Lambda(lambda x: Image.fromarray(255 - np.array(x))),
transforms.CenterCrop(64),
transforms.ToTensor(),
])
#
class HWDB(object):
def __init__(self, transform, path='./data'):
# 预处理过程
traindir = os.path.join(path, 'train')
testdir = os.path.join(path, 'test')
@@ -29,32 +20,32 @@ class HWDB(object):
def get_sample(self, index=0):
sample = self.trainset[index]
sample_img, sample_label = sample
print(sample_img.size())
return sample_img, sample_label
def get_loader(self, batch_size=100):
train_loader = torch.utils.data.DataLoader(
self.trainset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
self.testset, batch_size=batch_size, shuffle=True)
return train_loader, test_loader
trainloader = DataLoader(self.trainset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(self.testset, batch_size=batch_size, shuffle=True)
return trainloader, testloader
if __name__ == '__main__':
dataset = HWDB()
for i in [1, 10, 2000, 6000, 1000]:
img, label = dataset.get_sample(i)
img = img[0]
plt.imshow(img, cmap='gray')
plt.show()
transform = transforms.Compose([
# transforms.Grayscale(),
transforms.Resize((64, 64)),
transforms.ToTensor(),
])
dataset = HWDB(transform=transform, path=r'C:\Users\Administrator\Desktop\hand-writing-recognition\data')
print(dataset.train_size)
print(dataset.test_size)
# for i in [1020, 120, 2000, 6000, 1000]:
# img, label = dataset.get_sample(i)
# img = img[0]
# print(label)
# plt.imshow(img, cmap='gray')
# plt.show()
train_loader, test_loader = dataset.get_loader()
for (img, label) in train_loader:
print(img)
print(label)
print(len(train_loader))
# for (img, label) in train_loader:
# print(img)
# print(label)

105
model.py
View File

@@ -1,32 +1,113 @@
import math
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torchvision.models.vgg import vgg16_bn
import numpy as np
def conv_dw(inp, oup, stride):
return nn.Sequential(
# dw
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
nn.BatchNorm2d(inp),
nn.ReLU6(inplace=True),
# pw-linear
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)
def conv_bn(inp, oup, stride):
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU6(inplace=True)
)
class ConvNet(nn.Module):
def __init__(self, num_classes):
super(ConvNet, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
self.conv1 = conv_bn(3, 8, 1) # 64x64x1
self.conv2 = conv_bn(8, 16, 1) # 64x64x16
self.conv3 = conv_dw(16, 32, 1) # 64x64x32
self.conv4 = conv_dw(32, 32, 2) # 32x32x32
self.conv5 = conv_dw(32, 64, 1) # 32x32x64
self.conv6 = conv_dw(64, 64, 2) # 16x16x64
self.conv7 = conv_dw(64, 128, 1) # 16x16x128
self.conv8 = conv_dw(128, 128, 1) # 16x16x128
self.conv9 = conv_dw(128, 128, 1) # 16x16x128
self.conv10 = conv_dw(128, 128, 1) # 16x16x128
self.conv11 = conv_dw(128, 128, 1) # 16x16x128
self.conv12 = conv_dw(128, 256, 2) # 8x8x256
self.classifier = nn.Sequential(
nn.Linear(256 * 8 * 8, 4096),
nn.Dropout(0.2),
nn.ReLU(inplace=True),
nn.Linear(4096, num_classes),
)
# self.weight_init()
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(x1)
x3 = self.conv3(x2)
x4 = self.conv4(x3)
x5 = self.conv5(x4)
x6 = self.conv6(x5)
x7 = self.conv7(x6)
x8 = self.conv8(x7)
x9 = self.conv9(x8)
x9 = F.relu(x8 + x9)
x10 = self.conv10(x9)
x11 = self.conv11(x10)
x11 = F.relu(x10 + x11)
x12 = self.conv12(x11)
x = x12.view(x12.size(0), -1)
x = self.classifier(x)
return x
def weight_init(self):
for layer in self.modules():
self._layer_init(layer)
def _layer_init(self, m):
# 使用isinstance来判断m属于什么类型
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, np.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
# m中的weightbias其实都是Variable为了能学习参数以及后向传播
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.01)
m.bias.data.zero_()
class ConvNet2(nn.Module):
def __init__(self, num_classes):
super(ConvNet2, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.LeakyReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.classifier = nn.Sequential(
nn.Dropout(),
nn.Linear(512*4*4, 1024),
nn.Linear(512*4*4, 4096),
# nn.Dropout(),
nn.ReLU(inplace=True),
nn.Linear(1024, num_classes),
nn.Linear(4096, num_classes),
)
self.weight_init()
@@ -53,5 +134,7 @@ class ConvNet(nn.Module):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
init.xavier_normal(m.weight)
# n = m.weight.size(1)
m.weight.data.normal_(0, 0.01)
m.bias.data.zero_()
# init.xavier_normal_(m.weight)

View File

@@ -1,66 +1,75 @@
import os
import numpy as np
import struct
import pickle
import threading
import numpy as np
from PIL import Image
data_dir = './data'
# train_data_dir = "../data/HWDB1.1trn_gnt"
train_data_dir = os.path.join(data_dir, 'HWDB1.1trn_gnt')
test_data_dir = os.path.join(data_dir, 'HWDB1.1tst_gnt')
def read_from_gnt_dir(gnt_dir=test_data_dir):
# 处理单个gnt文件获取图像与标签
def read_from_gnt_dir(gnt_dir):
def one_file(f):
header_size = 10
while True:
header = np.fromfile(f, dtype='uint8', count=header_size)
if not header.size: break
sample_size = header[0] + (header[1]<<8) + (header[2]<<16) + (header[3]<<24)
tagcode = header[5] + (header[4]<<8)
width = header[6] + (header[7]<<8)
height = header[8] + (header[9]<<8)
if header_size + width*height != sample_size:
if not header.size:
break
image = np.fromfile(f, dtype='uint8', count=width*height).reshape((height, width))
yield image, tagcode
sample_size = header[0] + (header[1] << 8) + (header[2] << 16) + (header[3] << 24)
label = header[5] + (header[4] << 8)
width = header[6] + (header[7] << 8)
height = header[8] + (header[9] << 8)
if header_size + width * height != sample_size:
break
image = np.fromfile(f, dtype='uint8', count=width * height).reshape((height, width))
yield image, label
for file_name in os.listdir(gnt_dir):
if file_name.endswith('.gnt'):
file_path = os.path.join(gnt_dir, file_name)
with open(file_path, 'rb') as f:
for image, tagcode in one_file(f):
yield image, tagcode
for image, label in one_file(f):
yield image, label
def gnt_to_img(gnt_dir, img_dir):
counter = 0
for image, label in read_from_gnt_dir(gnt_dir=gnt_dir):
label = struct.pack('>H', label).decode('gb2312')
img = Image.fromarray(image)
dir_name = os.path.join(img_dir, '%0.5d' % char_dict[label])
if not os.path.exists(dir_name):
os.mkdir(dir_name)
img.convert('RGB').save(dir_name + '/' + str(counter) + '.png')
print("train_counter=", counter)
counter += 1
# 路径
data_dir = './data'
train_gnt_dir = os.path.join(data_dir, 'HWDB1.1trn_gnt')
test_gnt_dir = os.path.join(data_dir, 'HWDB1.1tst_gnt')
train_img_dir = os.path.join(data_dir, 'train')
test_img_dir = os.path.join(data_dir, 'test')
if not os.path.exists(train_img_dir):
os.mkdir(train_img_dir)
if not os.path.exists(test_img_dir):
os.mkdir(test_img_dir)
# 获取字符集合
char_set = set()
for _, tagcode in read_from_gnt_dir(gnt_dir=test_data_dir):
for _, tagcode in read_from_gnt_dir(gnt_dir=test_gnt_dir):
tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')
char_set.add(tagcode_unicode)
char_list = list(char_set)
char_dict = dict(zip(sorted(char_list), range(len(char_list))))
print(len(char_dict))
print("char_dict=",char_dict)
import pickle
f = open('char_dict', 'wb')
pickle.dump(char_dict, f)
f.close()
train_counter = 0
test_counter = 0
for image, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir):
tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')
im = Image.fromarray(image)
dir_name = './data/train/' + '%0.5d'%char_dict[tagcode_unicode]
if not os.path.exists(dir_name):
os.mkdir(dir_name)
im.convert('RGB').save(dir_name+'/' + str(train_counter) + '.png')
print("train_counter=",train_counter)
train_counter += 1
# for image, tagcode in read_from_gnt_dir(gnt_dir=test_data_dir):
# tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')
# im = Image.fromarray(image)
# dir_name = './data/test/' + '%0.5d'%char_dict[tagcode_unicode]
# if not os.path.exists(dir_name):
# os.mkdir(dir_name)
# im.convert('RGB').save(dir_name+'/' + str(test_counter) + '.png')
# print("test_counter=",test_counter)
# test_counter += 1
print("char_dict=", char_dict)
with open('char_dict', 'wb') as f:
pickle.dump(char_dict, f)
train_thread = threading.Thread(target=gnt_to_img, args=(train_gnt_dir, train_img_dir)).start()
test_thread = threading.Thread(target=gnt_to_img, args=(test_gnt_dir, test_img_dir)).start()
train_thread.join()
test_thread.join()

View File

@@ -1,92 +1,85 @@
import pickle
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import pickle
import numpy as np
from tensorboardX import SummaryWriter
from torchvision import transforms
from hwdb import HWDB
from convnet import ConvNet
from model import ConvNet, ConvNet2
def train(net,
criterion,
optimizer,
train_loader,
test_loarder,
epoch=10,
save_path='./pretrained_models/'):
def adjust_learning_rate(optimizer, decay_rate=.9):
for param_group in optimizer.param_groups:
param_group['lr'] = param_group['lr'] * decay_rate
def train(net, criterion, optimizer, train_loader, test_loarder, writer, epoch, save_path):
print("开始训练...")
net.train()
#scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
for epoch in range(epoch):
sum_loss = 0.0
total = 0
correct = 0
if epoch/3 == 1:
adjust_learning_rate(optimizer, 0.5)
# 数据读取
for i, (inputs, labels) in enumerate(train_loader):
# 梯度清零
optimizer.zero_grad()
# forward + backward
if torch.cuda.is_available():
# inputs, labels = Variable(inputs.cuda(0)), Variable(labels.cuda(0))
inputs = inputs.to('cuda')
labels = labels.to('cuda')
#print(inputs.device)
else:
print('cuda not available')
inputs = inputs.cuda()
labels = labels.cuda()
outputs = net(inputs)
loss = criterion(outputs, labels)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum()
correct += (predicted == labels).sum().item()
loss.backward()
optimizer.step()
#print(loss.item())
# 每训练100个batch打印一次平均loss与acc
sum_loss += loss.item()
# if i % 100 == 99:
if i % 100 == 99:
loss = sum_loss/100
print('epoch: %d, batch: %d loss: %.03f'
% (epoch + 11, i + 1, loss), end=',')
batch_loss = sum_loss / 100
# 每跑完一次epoch测试一下准确率
acc = 100 * correct / total
print('acc%d%%' % (acc))
print('epoch: %d, batch: %d loss: %.03f, acc: %.04f'
% (epoch, i + 1, batch_loss, acc))
writer.add_scalar('train_loss', batch_loss, global_step=i+len(train_loader)*epoch)
writer.add_scalar('train_acc', acc, global_step=i+len(train_loader)*epoch)
for name, layer in net.named_parameters():
writer.add_histogram(name+'_grad', layer.grad.cpu().data.numpy(), global_step=i+len(train_loader)*epoch)
writer.add_histogram(name+'_data', layer.cpu().data.numpy(), global_step=i+len(train_loader)*epoch)
total = 0
correct = 0
sum_loss = 0.0
print("epoch%d 训练结束, 正在保存模型..."%(epoch+11))
torch.save(net.state_dict(), save_path+'handwriting_iter_%03d.pth' % (epoch + 11))
if epoch%3 == 0:
print("epoch%d 训练结束, 正在保存模型..." % (epoch))
torch.save(net.state_dict(), save_path + 'handwriting_iter_%03d.pth' % (epoch))
if epoch % 3 == 0:
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images, labels = images.to('cuda'), labels.to('cuda')
for images, labels in test_loarder:
images, labels = images.cuda(), labels.cuda()
outputs = net(images)
# 取得分最高的那个类
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum()
print('correct number: ',correct)
correct += (predicted == labels).sum().item()
print('correct number: ', correct)
print('totol number:', total)
acc = 100 * correct / total
print('%d个epoch的识别准确率为%d%%' % (epoch+11, acc))
print('%d个epoch的识别准确率为%d%%' % (epoch, acc))
writer.add_scalar('test_acc', acc, global_step=epoch)
if __name__ == "__main__":
data_path = r'C:\Users\Administrator\Desktop\hand-writing-recognition\data'
save_path = r'checkpoints'
if not os.path.exists(save_path):
os.mkdir(save_path)
# 超参数
epoch = 20
batch_size = 100
lr = 0.02
# 读取分类类别
f = open('char_dict', 'rb')
@@ -94,22 +87,26 @@ if __name__ == "__main__":
num_classes = len(class_dict)
# 读取数据
dataset = HWDB()
transform = transforms.Compose([
# transforms.Grayscale(),
transforms.Resize((64, 64)),
transforms.ToTensor(),
])
dataset = HWDB(transform=transform, path=data_path)
print("训练集数据:", dataset.train_size)
print("测试集数据:", dataset.test_size)
train_loader, test_loader = dataset.get_loader(batch_size)
trainloader, testloader = dataset.get_loader(batch_size)
net = ConvNet(num_classes)
net = ConvNet2(num_classes)
print('网络结构:\n', net)
if torch.cuda.is_available():
net = net.cuda(0)
net = net.cuda()
else:
print('cuda not available')
net.load_state_dict(torch.load('./pretrained_models/handwriting_iter_010.pth'))
# net.load_state_dict(torch.load('./pretrained_models/handwriting_iter_010.pth'))
criterion = nn.CrossEntropyLoss()
#optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer = optim.RMSprop(net.parameters(), lr=0.000005, momentum=0.9, weight_decay=0.0005)
train(net, criterion, optimizer, train_loader, test_loader)
optimizer = optim.SGD(net.parameters(), lr=lr)
writer = SummaryWriter('logs/model/batch{}_lr{}'.format(batch_size, lr))
# writer = SummaryWriter('logs/model_dw_res/batch{}_lr{}'.format(batch_size, lr))
train(net, criterion, optimizer, trainloader, testloader, writer=writer, epoch=100, save_path=save_path)