Files
hand-writing-recognition/hwdb.py
2018-11-21 09:54:09 +08:00

61 lines
1.7 KiB
Python

import os
import torch
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
def _invert_grayscale(img):
    """Return an inverted copy of an 8-bit grayscale PIL image.

    Every pixel value ``v`` becomes ``255 - v``, so dark strokes on a light
    background become bright strokes on a dark background.

    This is a module-level function rather than a ``lambda`` inside the
    transform pipeline so that the pipeline can be pickled. ``DataLoader``
    pickles the dataset (and therefore its transform) when it starts worker
    processes with the ``spawn`` start method.
    """
    return Image.fromarray(255 - np.array(img))


class HWDB(object):
    """Train/test image datasets for handwritten character recognition.

    Expects an ``ImageFolder`` directory layout under ``path``::

        path/train/<class_name>/<image files>
        path/test/<class_name>/<image files>

    Every image goes through the same pipeline:
    1. convert to grayscale;
    2. invert, so strokes are bright on a dark background;
    3. center-crop to ``image_size`` x ``image_size``;
    4. convert to a tensor with ``transforms.ToTensor()``.
    """

    def __init__(self, path='./data', image_size=64):
        """Build the training and test datasets.

        Args:
            path: Root directory containing ``train`` and ``test``
                sub-directories in ``ImageFolder`` layout.
            image_size: Side length of the square center crop. The default
                of 64 matches the previously hard-coded value.
        """
        # Preprocessing pipeline shared by the training and test sets.
        transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.Lambda(_invert_grayscale),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ])
        traindir = os.path.join(path, 'train')
        testdir = os.path.join(path, 'test')
        self.trainset = datasets.ImageFolder(traindir, transform)
        self.testset = datasets.ImageFolder(testdir, transform)
        self.train_size = len(self.trainset)
        self.test_size = len(self.testset)

    def get_sample(self, index=0):
        """Return the ``(image, label)`` pair at ``index`` in the training set.

        Also prints the image tensor's size as a quick sanity check.
        """
        sample_img, sample_label = self.trainset[index]
        print(sample_img.size())
        return sample_img, sample_label

    def get_loader(self, batch_size=100, num_workers=0, shuffle_test=False):
        """Return ``(train_loader, test_loader)`` for the two datasets.

        Args:
            batch_size: Number of samples per batch for both loaders.
            num_workers: Number of worker processes used for loading
                (0 loads in the main process, the DataLoader default).
            shuffle_test: Whether to shuffle the test set as well. Off by
                default so evaluation visits samples in a deterministic
                order. The training set is always shuffled.
        """
        train_loader = torch.utils.data.DataLoader(
            self.trainset, batch_size=batch_size, shuffle=True,
            num_workers=num_workers)
        test_loader = torch.utils.data.DataLoader(
            self.testset, batch_size=batch_size, shuffle=shuffle_test,
            num_workers=num_workers)
        return train_loader, test_loader
if __name__ == '__main__':
    # Smoke test: display a few training images, then print a single batch.
    dataset = HWDB()
    for i in [1, 10, 2000, 6000, 1000]:
        # Skip indices past the end of the training set. This lets the demo
        # run on a small data subset instead of raising IndexError.
        if i >= dataset.train_size:
            continue
        img, label = dataset.get_sample(i)
        # Drop the single channel dimension so imshow receives a 2-D array.
        plt.imshow(img[0].numpy(), cmap='gray')
        plt.show()
    train_loader, test_loader = dataset.get_loader()
    for img, label in train_loader:
        print(img)
        print(label)
        # One batch is enough to check the pipeline. Without this break the
        # loop would print every batch of the entire training set.
        break