update code and readme

This commit is contained in:
JiageWang
2019-08-24 21:26:06 +08:00
parent bff40d9966
commit 90ef0ab210
5 changed files with 238 additions and 144 deletions

61
hwdb.py
View File

@@ -1,23 +1,14 @@
import os
import torch
import torch.utils.data
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
class HWDB(object):
def __init__(self, path='./data'):
# 预处理过程
transform = transforms.Compose([
transforms.Grayscale(),
transforms.Lambda(lambda x: Image.fromarray(255 - np.array(x))),
transforms.CenterCrop(64),
transforms.ToTensor(),
])
#
class HWDB(object):
def __init__(self, transform, path='./data'):
# 预处理过程
traindir = os.path.join(path, 'train')
testdir = os.path.join(path, 'test')
@@ -29,32 +20,32 @@ class HWDB(object):
def get_sample(self, index=0):
sample = self.trainset[index]
sample_img, sample_label = sample
print(sample_img.size())
return sample_img, sample_label
def get_loader(self, batch_size=100):
train_loader = torch.utils.data.DataLoader(
self.trainset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
self.testset, batch_size=batch_size, shuffle=True)
return train_loader, test_loader
trainloader = DataLoader(self.trainset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(self.testset, batch_size=batch_size, shuffle=True)
return trainloader, testloader
if __name__ == '__main__':
dataset = HWDB()
for i in [1, 10, 2000, 6000, 1000]:
img, label = dataset.get_sample(i)
img = img[0]
plt.imshow(img, cmap='gray')
plt.show()
transform = transforms.Compose([
# transforms.Grayscale(),
transforms.Resize((64, 64)),
transforms.ToTensor(),
])
dataset = HWDB(transform=transform, path=r'C:\Users\Administrator\Desktop\hand-writing-recognition\data')
print(dataset.train_size)
print(dataset.test_size)
# for i in [1020, 120, 2000, 6000, 1000]:
# img, label = dataset.get_sample(i)
# img = img[0]
# print(label)
# plt.imshow(img, cmap='gray')
# plt.show()
train_loader, test_loader = dataset.get_loader()
for (img, label) in train_loader:
print(img)
print(label)
print(len(train_loader))
# for (img, label) in train_loader:
# print(img)
# print(label)