first commit

This commit is contained in:
JiageWang
2018-11-21 09:54:09 +08:00
commit 9d75a26c5c
4 changed files with 298 additions and 0 deletions

60
hwdb.py Normal file
View File

@@ -0,0 +1,60 @@
import os
import torch
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
class HWDB(object):
def __init__(self, path='./data'):
# 预处理过程
transform = transforms.Compose([
transforms.Grayscale(),
transforms.Lambda(lambda x: Image.fromarray(255 - np.array(x))),
transforms.CenterCrop(64),
transforms.ToTensor(),
])
#
traindir = os.path.join(path, 'train')
testdir = os.path.join(path, 'test')
self.trainset = datasets.ImageFolder(traindir, transform)
self.testset = datasets.ImageFolder(testdir, transform)
self.train_size = len(self.trainset)
self.test_size = len(self.testset)
def get_sample(self, index=0):
sample = self.trainset[index]
sample_img, sample_label = sample
print(sample_img.size())
return sample_img, sample_label
def get_loader(self, batch_size=100):
train_loader = torch.utils.data.DataLoader(
self.trainset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
self.testset, batch_size=batch_size, shuffle=True)
return train_loader, test_loader
if __name__ == '__main__':
dataset = HWDB()
for i in [1, 10, 2000, 6000, 1000]:
img, label = dataset.get_sample(i)
img = img[0]
plt.imshow(img, cmap='gray')
plt.show()
train_loader, test_loader = dataset.get_loader()
for (img, label) in train_loader:
print(img)
print(label)

57
model.py Normal file
View File

@@ -0,0 +1,57 @@
import torch.nn as nn
from torch.nn import init
from torchvision.models.vgg import vgg16_bn
import numpy as np
class ConvNet(nn.Module):
def __init__(self, num_classes):
super(ConvNet, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.classifier = nn.Sequential(
nn.Dropout(),
nn.Linear(512*4*4, 1024),
nn.ReLU(inplace=True),
nn.Linear(1024, num_classes),
)
self.weight_init()
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
def weight_init(self):
for layer in self.features:
self._layer_init(layer)
for layer in self.classifier:
self._layer_init(layer)
def _layer_init(self, m):
# 使用isinstance来判断m属于什么类型
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, np.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
# m中的weightbias其实都是Variable为了能学习参数以及后向传播
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
init.xavier_normal(m.weight)

66
process_gnt.py Normal file
View File

@@ -0,0 +1,66 @@
import os
import numpy as np
import struct
from PIL import Image
data_dir = './data'
# train_data_dir = "../data/HWDB1.1trn_gnt"
train_data_dir = os.path.join(data_dir, 'HWDB1.1trn_gnt')
test_data_dir = os.path.join(data_dir, 'HWDB1.1tst_gnt')
def read_from_gnt_dir(gnt_dir=test_data_dir):
def one_file(f):
header_size = 10
while True:
header = np.fromfile(f, dtype='uint8', count=header_size)
if not header.size: break
sample_size = header[0] + (header[1]<<8) + (header[2]<<16) + (header[3]<<24)
tagcode = header[5] + (header[4]<<8)
width = header[6] + (header[7]<<8)
height = header[8] + (header[9]<<8)
if header_size + width*height != sample_size:
break
image = np.fromfile(f, dtype='uint8', count=width*height).reshape((height, width))
yield image, tagcode
for file_name in os.listdir(gnt_dir):
if file_name.endswith('.gnt'):
file_path = os.path.join(gnt_dir, file_name)
with open(file_path, 'rb') as f:
for image, tagcode in one_file(f):
yield image, tagcode
char_set = set()
for _, tagcode in read_from_gnt_dir(gnt_dir=test_data_dir):
tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')
char_set.add(tagcode_unicode)
char_list = list(char_set)
char_dict = dict(zip(sorted(char_list), range(len(char_list))))
print(len(char_dict))
print("char_dict=",char_dict)
import pickle
f = open('char_dict', 'wb')
pickle.dump(char_dict, f)
f.close()
train_counter = 0
test_counter = 0
for image, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir):
tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')
im = Image.fromarray(image)
dir_name = './data/train/' + '%0.5d'%char_dict[tagcode_unicode]
if not os.path.exists(dir_name):
os.mkdir(dir_name)
im.convert('RGB').save(dir_name+'/' + str(train_counter) + '.png')
print("train_counter=",train_counter)
train_counter += 1
# for image, tagcode in read_from_gnt_dir(gnt_dir=test_data_dir):
# tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')
# im = Image.fromarray(image)
# dir_name = './data/test/' + '%0.5d'%char_dict[tagcode_unicode]
# if not os.path.exists(dir_name):
# os.mkdir(dir_name)
# im.convert('RGB').save(dir_name+'/' + str(test_counter) + '.png')
# print("test_counter=",test_counter)
# test_counter += 1

115
train.py Normal file
View File

@@ -0,0 +1,115 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import pickle
import numpy as np
from hwdb import HWDB
from convnet import ConvNet
def train(net,
criterion,
optimizer,
train_loader,
test_loarder,
epoch=10,
save_path='./pretrained_models/'):
def adjust_learning_rate(optimizer, decay_rate=.9):
for param_group in optimizer.param_groups:
param_group['lr'] = param_group['lr'] * decay_rate
print("开始训练...")
net.train()
#scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
for epoch in range(epoch):
sum_loss = 0.0
total = 0
correct = 0
if epoch/3 == 1:
adjust_learning_rate(optimizer, 0.5)
# 数据读取
for i, (inputs, labels) in enumerate(train_loader):
# 梯度清零
optimizer.zero_grad()
# forward + backward
if torch.cuda.is_available():
# inputs, labels = Variable(inputs.cuda(0)), Variable(labels.cuda(0))
inputs = inputs.to('cuda')
labels = labels.to('cuda')
#print(inputs.device)
else:
print('cuda not available')
outputs = net(inputs)
loss = criterion(outputs, labels)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum()
loss.backward()
optimizer.step()
#print(loss.item())
# 每训练100个batch打印一次平均loss与acc
sum_loss += loss.item()
# if i % 100 == 99:
if i % 100 == 99:
loss = sum_loss/100
print('epoch: %d, batch: %d loss: %.03f'
% (epoch + 11, i + 1, loss), end=',')
# 每跑完一次epoch测试一下准确率
acc = 100 * correct / total
print('acc%d%%' % (acc))
total = 0
correct = 0
sum_loss = 0.0
print("epoch%d 训练结束, 正在保存模型..."%(epoch+11))
torch.save(net.state_dict(), save_path+'handwriting_iter_%03d.pth' % (epoch + 11))
if epoch%3 == 0:
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images, labels = images.to('cuda'), labels.to('cuda')
outputs = net(images)
# 取得分最高的那个类
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum()
print('correct number: ',correct)
print('totol number:', total)
acc = 100 * correct / total
print('%d个epoch的识别准确率为%d%%' % (epoch+11, acc))
if __name__ == "__main__":
# 超参数
batch_size = 100
# 读取分类类别
f = open('char_dict', 'rb')
class_dict = pickle.load(f)
num_classes = len(class_dict)
# 读取数据
dataset = HWDB()
print("训练集数据:", dataset.train_size)
print("测试集数据:", dataset.test_size)
train_loader, test_loader = dataset.get_loader(batch_size)
net = ConvNet(num_classes)
print('网络结构:\n', net)
if torch.cuda.is_available():
net = net.cuda(0)
else:
print('cuda not available')
net.load_state_dict(torch.load('./pretrained_models/handwriting_iter_010.pth'))
criterion = nn.CrossEntropyLoss()
#optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer = optim.RMSprop(net.parameters(), lr=0.000005, momentum=0.9, weight_decay=0.0005)
train(net, criterion, optimizer, train_loader, test_loader)