update code and readme
This commit is contained in:
28
hwdb.py
28
hwdb.py
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import random
|
||||
from torch.utils.data import DataLoader
|
||||
import torchvision.transforms as transforms
|
||||
import torchvision.datasets as datasets
|
||||
@@ -6,7 +7,7 @@ import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
class HWDB(object):
|
||||
def __init__(self, transform, path='./data'):
|
||||
def __init__(self,path, transform):
|
||||
# 预处理过程
|
||||
|
||||
traindir = os.path.join(path, 'train')
|
||||
@@ -16,6 +17,8 @@ class HWDB(object):
|
||||
self.testset = datasets.ImageFolder(testdir, transform)
|
||||
self.train_size = len(self.trainset)
|
||||
self.test_size = len(self.testset)
|
||||
self.num_classes = len(self.trainset.classes)
|
||||
self.class_to_idx = self.trainset.class_to_idx
|
||||
|
||||
def get_sample(self, index=0):
|
||||
sample = self.trainset[index]
|
||||
@@ -30,21 +33,14 @@ class HWDB(object):
|
||||
|
||||
if __name__ == '__main__':
|
||||
transform = transforms.Compose([
|
||||
# transforms.Grayscale(),
|
||||
transforms.Resize((64, 64)),
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
dataset = HWDB(transform=transform, path=r'data')
|
||||
print(dataset.train_size)
|
||||
print(dataset.test_size)
|
||||
for i in [1020, 120, 2000, 6000, 1000]:
|
||||
img, label = dataset.get_sample(i)
|
||||
img = img[0]
|
||||
print(label)
|
||||
plt.imshow(img, cmap='gray')
|
||||
plt.show()
|
||||
|
||||
train_loader, test_loader = dataset.get_loader()
|
||||
# for (img, label) in train_loader:
|
||||
# print(img)
|
||||
# print(label)
|
||||
dataset = HWDB(path=r'data', transform=transform)
|
||||
print("训练集数量:", dataset.train_size)
|
||||
print("测试集数量:", dataset.test_size)
|
||||
print("类别数量:", dataset.num_classes)
|
||||
index = random.randint(0, dataset.train_size)
|
||||
img = dataset.get_sample(index)[0][0]
|
||||
plt.imshow(img, cmap='gray')
|
||||
plt.show()
|
||||
|
||||
22
model.py
22
model.py
@@ -1,6 +1,8 @@
|
||||
import numpy as np
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from torchsummary import summary
|
||||
|
||||
|
||||
def conv_dw(inp, oup, stride):
|
||||
@@ -38,11 +40,15 @@ class ConvNet(nn.Module):
|
||||
self.conv10 = conv_dw(128, 128, 1) # 16x16x128
|
||||
self.conv11 = conv_dw(128, 128, 1) # 16x16x128
|
||||
self.conv12 = conv_dw(128, 256, 2) # 8x8x256
|
||||
self.conv13 = conv_dw(256, 256, 1) # 8x8x256
|
||||
self.conv14 = conv_dw(256, 256, 1) # 8x8x256
|
||||
self.conv15 = conv_dw(256, 512, 2) # 4x4x512
|
||||
self.conv16 = conv_dw(512, 512, 1) # 4x4x512
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(256 * 8 * 8, 4096),
|
||||
nn.Linear(512*4*4, 1024),
|
||||
nn.Dropout(0.2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(4096, num_classes),
|
||||
nn.Linear(1024, num_classes),
|
||||
)
|
||||
self.weight_init()
|
||||
|
||||
@@ -61,7 +67,12 @@ class ConvNet(nn.Module):
|
||||
x11 = self.conv11(x10)
|
||||
x11 = F.relu(x10 + x11)
|
||||
x12 = self.conv12(x11)
|
||||
x = x12.view(x12.size(0), -1)
|
||||
x13 = self.conv13(x12)
|
||||
x14 = self.conv14(x13)
|
||||
x14 = F.relu(x13 + x14)
|
||||
x15 = self.conv15(x14)
|
||||
x16 = self.conv16(x15)
|
||||
x = x16.view(x16.size(0), -1)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
@@ -83,3 +94,6 @@ class ConvNet(nn.Module):
|
||||
m.bias.data.zero_()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = ConvNet(3755).cuda()
|
||||
summary(model, input_size=(3, 64, 64), device='cuda')
|
||||
|
||||
27
train.py
27
train.py
@@ -4,8 +4,10 @@ import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from tensorboardX import SummaryWriter
|
||||
from torchvision import transforms
|
||||
from torchsummary import summary
|
||||
|
||||
from hwdb import HWDB
|
||||
from model import ConvNet
|
||||
@@ -74,42 +76,43 @@ def train(epoch, net, criterion, optimizer, train_loader, writer, save_iter=100)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
data_path = r'data'
|
||||
data_path = r'C:\Users\Administrator\Desktop\hand-writing-recognition\data'
|
||||
log_path = r'logs/batch_100_100_lr0.02'
|
||||
save_path = r'checkpoints/'
|
||||
if not os.path.exists(save_path):
|
||||
os.mkdir(save_path)
|
||||
# 超参数
|
||||
epochs = 20
|
||||
batch_size = 100
|
||||
lr = 0.02
|
||||
lr = 0.0001
|
||||
|
||||
# 读取分类类别
|
||||
f = open('char_dict', 'rb')
|
||||
class_dict = pickle.load(f)
|
||||
with open('char_dict', 'rb') as f:
|
||||
class_dict = pickle.load(f)
|
||||
num_classes = len(class_dict)
|
||||
|
||||
# 读取数据
|
||||
transform = transforms.Compose([
|
||||
# transforms.Grayscale(),
|
||||
transforms.Resize((64, 64)),
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
dataset = HWDB(transform=transform, path=data_path)
|
||||
dataset = HWDB(path=data_path, transform=transform)
|
||||
print("训练集数据:", dataset.train_size)
|
||||
print("测试集数据:", dataset.test_size)
|
||||
trainloader, testloader = dataset.get_loader(batch_size)
|
||||
|
||||
net = ConvNet(num_classes)
|
||||
print('网络结构:\n', net)
|
||||
# net.load_state_dict(torch.load('checkpoints/handwriting_iter_004.pth'))
|
||||
|
||||
print('网络结构:\n')
|
||||
summary(net, input_size=(3, 64, 64), device='cuda')
|
||||
if torch.cuda.is_available():
|
||||
net = net.cuda()
|
||||
else:
|
||||
print('cuda not available')
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.SGD(net.parameters(), lr=lr)
|
||||
writer = SummaryWriter('logs/batch_100_{}_lr{}'.format(batch_size, lr))
|
||||
writer = SummaryWriter(log_path)
|
||||
for epoch in range(epochs):
|
||||
train(epoch, net, criterion, optimizer, trainloader, writer=writer)
|
||||
valid(epoch, net, testloader, writer=writer)
|
||||
print("epoch%d 结束, 正在保存模型..." % (epoch))
|
||||
torch.save(net.state_dict(), save_path + 'handwriting_iter_%03d.pth' % (epoch))
|
||||
print("epoch%d 结束, 正在保存模型..." % epoch)
|
||||
torch.save(net.state_dict(), save_path + 'handwriting_iter_%03d.pth' % epoch)
|
||||
|
||||
Reference in New Issue
Block a user