update code and readme
README.md
@@ -1,5 +1,6 @@
 # hand-writing-recognition
-Handwritten Chinese character recognition based on a PyTorch convolutional neural network, using the HWDB database
+Chinese handwritten character recognition based on a PyTorch convolutional neural network, using the HWDB database
 
 ## Dependence
+* PIL
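(Note: besides PIL, the code touched by this commit also imports torch, torchvision, numpy, and tensorboardX, so the dependency list above is not exhaustive.)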
3
hwdb.py
@@ -34,7 +34,7 @@ if __name__ == '__main__':
         transforms.Resize((64, 64)),
         transforms.ToTensor(),
     ])
-    dataset = HWDB(transform=transform, path=r'C:\Users\Administrator\Desktop\hand-writing-recognition\data')
+    dataset = HWDB(transform=transform, path=r'data')
     print(dataset.train_size)
     print(dataset.test_size)
     for i in [1020, 120, 2000, 6000, 1000]:
@@ -45,7 +45,6 @@ if __name__ == '__main__':
         plt.show()
 
     train_loader, test_loader = dataset.get_loader()
     print(len(train_loader))
     # for (img, label) in train_loader:
     #     print(img)
     #     print(label)
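For orientation, a minimal sketch of how this entry point is driven after the path change. It assumes the HWDB class in hwdb.py exposes train_size, test_size, and get_loader() as used above, and that the script is launched from the repository root so the relative 'data' path resolves:

from torchvision import transforms
from hwdb import HWDB

transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])
# relative path: run from the repository root
dataset = HWDB(transform=transform, path=r'data')
print(dataset.train_size, dataset.test_size)
train_loader, test_loader = dataset.get_loader()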
57
model.py
@@ -1,7 +1,5 @@
 import math
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.nn import init
-import numpy as np
 
 
@@ -46,7 +44,7 @@ class ConvNet(nn.Module):
             nn.ReLU(inplace=True),
             nn.Linear(4096, num_classes),
         )
-        # self.weight_init()
+        self.weight_init()
 
     def forward(self, x):
         x1 = self.conv1(x)
@@ -85,56 +83,3 @@ class ConvNet(nn.Module):
             m.bias.data.zero_()
 
 
-class ConvNet2(nn.Module):
-    def __init__(self, num_classes):
-        super(ConvNet2, self).__init__()
-        self.features = nn.Sequential(
-            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
-            nn.LeakyReLU(inplace=True),
-            nn.MaxPool2d(kernel_size=2, stride=2),
-            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
-            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
-            nn.LeakyReLU(inplace=True),
-            nn.MaxPool2d(kernel_size=2, stride=2),
-            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
-            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
-            nn.LeakyReLU(inplace=True),
-            nn.MaxPool2d(kernel_size=2, stride=2),
-            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
-            nn.LeakyReLU(inplace=True),
-            nn.MaxPool2d(kernel_size=2, stride=2),
-        )
-        self.classifier = nn.Sequential(
-            nn.Linear(512 * 4 * 4, 4096),
-            # nn.Dropout(),
-            nn.ReLU(inplace=True),
-            nn.Linear(4096, num_classes),
-        )
-        self.weight_init()
-
-    def forward(self, x):
-        x = self.features(x)
-        x = x.view(x.size(0), -1)
-        x = self.classifier(x)
-        return x
-
-    def weight_init(self):
-        for layer in self.features:
-            self._layer_init(layer)
-        for layer in self.classifier:
-            self._layer_init(layer)
-
-    def _layer_init(self, m):
-        # use isinstance to check which layer type m is
-        if isinstance(m, nn.Conv2d):
-            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-            m.weight.data.normal_(0, np.sqrt(2. / n))
-        elif isinstance(m, nn.BatchNorm2d):
-            # weight and bias are Variables, so the parameters can be learned and backpropagated
-            m.weight.data.fill_(1)
-            m.bias.data.zero_()
-        elif isinstance(m, nn.Linear):
-            # n = m.weight.size(1)
-            m.weight.data.normal_(0, 0.01)
-            m.bias.data.zero_()
-            # init.xavier_normal_(m.weight)
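The hand-rolled _layer_init removed above is He (Kaiming) initialization: conv weights are drawn from N(0, sqrt(2/n)) with n = kernel_w * kernel_h * out_channels, i.e. fan-out mode. A minimal sketch of the same rule via PyTorch's built-in initializers (an equivalence note, not code from this repo):

import torch.nn as nn

def he_init(m):
    if isinstance(m, nn.Conv2d):
        # same as m.weight.data.normal_(0, np.sqrt(2. / n)) with n = k*k*out_channels
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        nn.init.zeros_(m.bias)

The usual idiom is net.apply(he_init), which walks every submodule, instead of iterating self.features and self.classifier by hand.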
@@ -5,6 +5,7 @@ import threading
 
 import numpy as np
 from PIL import Image
+from concurrent.futures import ThreadPoolExecutor
 
 
 # process a single .gnt file to extract images and labels
@@ -33,43 +34,53 @@ def read_from_gnt_dir(gnt_dir):
 
 
 def gnt_to_img(gnt_dir, img_dir):
-    counter = 0
-    for image, label in read_from_gnt_dir(gnt_dir=gnt_dir):
+    def save_img(label, image, counter):
         label = struct.pack('>H', label).decode('gb2312')
         img = Image.fromarray(image)
         dir_name = os.path.join(img_dir, '%0.5d' % char_dict[label])
         if not os.path.exists(dir_name):
             os.mkdir(dir_name)
         img.convert('RGB').save(dir_name + '/' + str(counter) + '.png')
-        print("train_counter=", counter)
+        print("thread: {}, counter=".format(threading.current_thread().name), counter)
+
+    counter = 0
+    thread_pool = ThreadPoolExecutor(4)  # run the conversion on 4 worker threads
+    for image, label in read_from_gnt_dir(gnt_dir=gnt_dir):
+        thread_pool.submit(save_img, label, image, counter)
         counter += 1
+    thread_pool.shutdown()
 
 
-# paths
-data_dir = './data'
-train_gnt_dir = os.path.join(data_dir, 'HWDB1.1trn_gnt')
-test_gnt_dir = os.path.join(data_dir, 'HWDB1.1tst_gnt')
-train_img_dir = os.path.join(data_dir, 'train')
-test_img_dir = os.path.join(data_dir, 'test')
-if not os.path.exists(train_img_dir):
-    os.mkdir(train_img_dir)
-if not os.path.exists(test_img_dir):
-    os.mkdir(test_img_dir)
-
-# build the character-to-index mapping
-char_set = set()
-for _, tagcode in read_from_gnt_dir(gnt_dir=test_gnt_dir):
-    tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')
-    char_set.add(tagcode_unicode)
-char_list = list(char_set)
-char_dict = dict(zip(sorted(char_list), range(len(char_list))))
-print(len(char_dict))
-print("char_dict=", char_dict)
-
-with open('char_dict', 'wb') as f:
-    pickle.dump(char_dict, f)
-
-train_thread = threading.Thread(target=gnt_to_img, args=(train_gnt_dir, train_img_dir)).start()
-test_thread = threading.Thread(target=gnt_to_img, args=(test_gnt_dir, test_img_dir)).start()
-train_thread.join()
-test_thread.join()
+if __name__ == "__main__":
+    # paths
+    data_dir = r'./data'
+    train_gnt_dir = os.path.join(data_dir, 'HWDB1.1trn_gnt')
+    test_gnt_dir = os.path.join(data_dir, 'HWDB1.1tst_gnt')
+    train_img_dir = os.path.join(data_dir, 'train')
+    test_img_dir = os.path.join(data_dir, 'test')
+    if not os.path.exists(train_img_dir):
+        os.mkdir(train_img_dir)
+    if not os.path.exists(test_img_dir):
+        os.mkdir(test_img_dir)
+
+    # build the character-to-index mapping, cached in the 'char_dict' pickle
+    if not os.path.exists('char_dict'):
+        char_set = set()
+        for _, tagcode in read_from_gnt_dir(gnt_dir=test_gnt_dir):
+            tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')
+            char_set.add(tagcode_unicode)
+        char_list = list(char_set)
+        char_dict = dict(zip(sorted(char_list), range(len(char_list))))
+        print(len(char_dict))
+        print("char_dict=", char_dict)
+
+        with open('char_dict', 'wb') as f:
+            pickle.dump(char_dict, f)
+    else:
+        with open('char_dict', 'rb') as f:
+            char_dict = pickle.load(f)
+
+    train_thread = threading.Thread(target=gnt_to_img, args=(train_gnt_dir, train_img_dir)).start()
+    test_thread = threading.Thread(target=gnt_to_img, args=(test_gnt_dir, test_img_dir)).start()
+    train_thread.join()
+    test_thread.join()
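One caveat in the code above: threading.Thread(...).start() returns None, so the train_thread.join() / test_thread.join() calls would raise AttributeError at runtime. A corrected sketch with the same names:

train_thread = threading.Thread(target=gnt_to_img, args=(train_gnt_dir, train_img_dir))
test_thread = threading.Thread(target=gnt_to_img, args=(test_gnt_dir, test_img_dir))
train_thread.start()
test_thread.start()
# wait for both gnt-to-image conversions to finish
train_thread.join()
test_thread.join()

Relatedly, with four pool workers writing images concurrently, the exists-check followed by os.mkdir in save_img can race; os.makedirs(dir_name, exist_ok=True) sidesteps that.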
123
train.py
@@ -8,76 +8,78 @@ from tensorboardX import SummaryWriter
 from torchvision import transforms
 
 from hwdb import HWDB
-from model import ConvNet, ConvNet2
+from model import ConvNet
 
 
-def train(net, criterion, optimizer, train_loader, test_loarder, writer, epoch, save_path):
-    print("starting training...")
-    net.train()
-    for epoch in range(epoch):
-        sum_loss = 0.0
-        total = 0
-        # load the data
-        for i, (inputs, labels) in enumerate(train_loader):
-            # zero the gradients
-            optimizer.zero_grad()
-            if torch.cuda.is_available():
-                inputs = inputs.cuda()
-                labels = labels.cuda()
-            outputs = net(inputs)
-            loss = criterion(outputs, labels)
-            loss.backward()
-            optimizer.step()
-
-            # print the average loss and acc every 100 batches
-            sum_loss += loss.item()
-            if i % 100 == 99:
-                batch_loss = sum_loss / 100
-                # check the accuracy after each epoch
-                acc = 100 * correct / total
-                print('epoch: %d, batch: %d loss: %.03f, acc: %.04f'
-                      % (epoch, i + 1, batch_loss, acc))
-                writer.add_scalar('train_loss', batch_loss, global_step=i+len(train_loader)*epoch)
-                writer.add_scalar('train_acc', acc, global_step=i+len(train_loader)*epoch)
-                for name, layer in net.named_parameters():
-                    writer.add_histogram(name+'_grad', layer.grad.cpu().data.numpy(), global_step=i+len(train_loader)*epoch)
-                    writer.add_histogram(name+'_data', layer.cpu().data.numpy(), global_step=i+len(train_loader)*epoch)
-                total = 0
-                correct = 0
-                sum_loss = 0.0
-
-        print("epoch %d: training finished, saving the model..." % (epoch))
-        torch.save(net.state_dict(), save_path + 'handwriting_iter_%03d.pth' % (epoch))
-        if epoch % 3 == 0:
-            with torch.no_grad():
-                correct = 0
-                total = 0
-                for images, labels in test_loarder:
-                    images, labels = images.cuda(), labels.cuda()
-                    outputs = net(images)
-                    # take the class with the highest score
-                    _, predicted = torch.max(outputs.data, 1)
-                    total += labels.size(0)
-                    correct += (predicted == labels).sum().item()
-                print('correct number: ', correct)
-                print('total number:', total)
-                acc = 100 * correct / total
-                print('accuracy after epoch %d: %d%%' % (epoch, acc))
-                writer.add_scalar('test_acc', acc, global_step=epoch)
+def valid(epoch, net, test_loarder, writer):
+    print("epoch %d: starting validation..." % epoch)
+    with torch.no_grad():
+        correct = 0
+        total = 0
+        for images, labels in test_loarder:
+            images, labels = images.cuda(), labels.cuda()
+            outputs = net(images)
+            # take the class with the highest score
+            _, predicted = torch.max(outputs.data, 1)
+            total += labels.size(0)
+            correct += (predicted == labels).sum().item()
+        print('correct number: ', correct)
+        print('total number:', total)
+        acc = 100 * correct / total
+        print('accuracy after epoch %d: %d%%' % (epoch, acc))
+        writer.add_scalar('valid_acc', acc, global_step=epoch)
+
+
+def train(epoch, net, criterion, optimizer, train_loader, writer, save_iter=100):
+    print("epoch %d: starting training..." % epoch)
+    net.train()
+    sum_loss = 0.0
+    total = 0
+    correct = 0
+    # load the data
+    for i, (inputs, labels) in enumerate(train_loader):
+        # zero the gradients
+        optimizer.zero_grad()
+        if torch.cuda.is_available():
+            inputs = inputs.cuda()
+            labels = labels.cuda()
+        outputs = net(inputs)
+        loss = criterion(outputs, labels)
+        # take the class with the highest score
+        _, predicted = torch.max(outputs.data, 1)
+        total += labels.size(0)
+        correct += (predicted == labels).sum().item()
+        loss.backward()
+        optimizer.step()
+
+        # print the average loss and acc every save_iter batches
+        sum_loss += loss.item()
+        if (i + 1) % save_iter == 0:
+            batch_loss = sum_loss / save_iter
+            acc = 100 * correct / total
+            print('epoch: %d, batch: %d loss: %.03f, acc: %.04f'
+                  % (epoch, i + 1, batch_loss, acc))
+            writer.add_scalar('train_loss', batch_loss, global_step=i + len(train_loader) * epoch)
+            writer.add_scalar('train_acc', acc, global_step=i + len(train_loader) * epoch)
+            for name, layer in net.named_parameters():
+                writer.add_histogram(name + '_grad', layer.grad.cpu().data.numpy(),
+                                     global_step=i + len(train_loader) * epoch)
+                writer.add_histogram(name + '_data', layer.cpu().data.numpy(),
+                                     global_step=i + len(train_loader) * epoch)
+            total = 0
+            correct = 0
+            sum_loss = 0.0
 
 
 if __name__ == "__main__":
-    data_path = r'C:\Users\Administrator\Desktop\hand-writing-recognition\data'
-    save_path = r'checkpoints'
+    data_path = r'data'
+    save_path = r'checkpoints/'
     if not os.path.exists(save_path):
         os.mkdir(save_path)
     # hyperparameters
-    epoch = 20
+    epochs = 20
     batch_size = 100
     lr = 0.02
 
@@ -103,10 +105,11 @@ if __name__ == "__main__":
         net = net.cuda()
     else:
        print('cuda not available')
-    # net.load_state_dict(torch.load('./pretrained_models/handwriting_iter_010.pth'))
     criterion = nn.CrossEntropyLoss()
     optimizer = optim.SGD(net.parameters(), lr=lr)
-    writer = SummaryWriter('logs/model/batch{}_lr{}'.format(batch_size, lr))
-    # writer = SummaryWriter('logs/model_dw_res/batch{}_lr{}'.format(batch_size, lr))
-    train(net, criterion, optimizer, trainloader, testloader, writer=writer, epoch=100, save_path=save_path)
+    writer = SummaryWriter('logs/batch_100_{}_lr{}'.format(batch_size, lr))
+    for epoch in range(epochs):
+        train(epoch, net, criterion, optimizer, trainloader, writer=writer)
+        valid(epoch, net, testloader, writer=writer)
+        print("epoch %d finished, saving the model..." % (epoch))
+        torch.save(net.state_dict(), save_path + 'handwriting_iter_%03d.pth' % (epoch))
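The change from r'checkpoints' to r'checkpoints/' is load-bearing: the save call concatenates strings, so the old value would have produced files named 'checkpointshandwriting_iter_000.pth' next to the script instead of inside the directory. A sketch that removes the dependence on the trailing slash (save_path, net, and epoch as defined in the script above):

import os

# join the directory and filename instead of relying on a trailing slash
save_file = os.path.join(save_path, 'handwriting_iter_%03d.pth' % epoch)
torch.save(net.state_dict(), save_file)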