update code and readme
This commit is contained in:
61
hwdb.py
61
hwdb.py
@@ -1,23 +1,14 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from torch.utils.data import DataLoader
|
||||
import torchvision.transforms as transforms
|
||||
import torchvision.datasets as datasets
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
class HWDB(object):
|
||||
def __init__(self, path='./data'):
|
||||
# 预处理过程
|
||||
transform = transforms.Compose([
|
||||
transforms.Grayscale(),
|
||||
transforms.Lambda(lambda x: Image.fromarray(255 - np.array(x))),
|
||||
transforms.CenterCrop(64),
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
|
||||
#
|
||||
class HWDB(object):
|
||||
def __init__(self, transform, path='./data'):
|
||||
# 预处理过程
|
||||
|
||||
traindir = os.path.join(path, 'train')
|
||||
testdir = os.path.join(path, 'test')
|
||||
|
||||
@@ -29,32 +20,32 @@ class HWDB(object):
|
||||
def get_sample(self, index=0):
|
||||
sample = self.trainset[index]
|
||||
sample_img, sample_label = sample
|
||||
print(sample_img.size())
|
||||
return sample_img, sample_label
|
||||
|
||||
def get_loader(self, batch_size=100):
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
self.trainset, batch_size=batch_size, shuffle=True)
|
||||
test_loader = torch.utils.data.DataLoader(
|
||||
self.testset, batch_size=batch_size, shuffle=True)
|
||||
return train_loader, test_loader
|
||||
trainloader = DataLoader(self.trainset, batch_size=batch_size, shuffle=True)
|
||||
testloader = DataLoader(self.testset, batch_size=batch_size, shuffle=True)
|
||||
return trainloader, testloader
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
dataset = HWDB()
|
||||
for i in [1, 10, 2000, 6000, 1000]:
|
||||
img, label = dataset.get_sample(i)
|
||||
img = img[0]
|
||||
plt.imshow(img, cmap='gray')
|
||||
plt.show()
|
||||
transform = transforms.Compose([
|
||||
# transforms.Grayscale(),
|
||||
transforms.Resize((64, 64)),
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
dataset = HWDB(transform=transform, path=r'C:\Users\Administrator\Desktop\hand-writing-recognition\data')
|
||||
print(dataset.train_size)
|
||||
print(dataset.test_size)
|
||||
# for i in [1020, 120, 2000, 6000, 1000]:
|
||||
# img, label = dataset.get_sample(i)
|
||||
# img = img[0]
|
||||
# print(label)
|
||||
# plt.imshow(img, cmap='gray')
|
||||
# plt.show()
|
||||
|
||||
train_loader, test_loader = dataset.get_loader()
|
||||
for (img, label) in train_loader:
|
||||
print(img)
|
||||
print(label)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
print(len(train_loader))
|
||||
# for (img, label) in train_loader:
|
||||
# print(img)
|
||||
# print(label)
|
||||
|
||||
Reference in New Issue
Block a user