first commit

This commit is contained in:
JiageWang
2018-11-21 09:54:09 +08:00
commit 9d75a26c5c
4 changed files with 298 additions and 0 deletions

60
hwdb.py Normal file
View File

@@ -0,0 +1,60 @@
import os
import torch
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
class HWDB(object):
def __init__(self, path='./data'):
# 预处理过程
transform = transforms.Compose([
transforms.Grayscale(),
transforms.Lambda(lambda x: Image.fromarray(255 - np.array(x))),
transforms.CenterCrop(64),
transforms.ToTensor(),
])
#
traindir = os.path.join(path, 'train')
testdir = os.path.join(path, 'test')
self.trainset = datasets.ImageFolder(traindir, transform)
self.testset = datasets.ImageFolder(testdir, transform)
self.train_size = len(self.trainset)
self.test_size = len(self.testset)
def get_sample(self, index=0):
sample = self.trainset[index]
sample_img, sample_label = sample
print(sample_img.size())
return sample_img, sample_label
def get_loader(self, batch_size=100):
train_loader = torch.utils.data.DataLoader(
self.trainset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
self.testset, batch_size=batch_size, shuffle=True)
return train_loader, test_loader
if __name__ == '__main__':
dataset = HWDB()
for i in [1, 10, 2000, 6000, 1000]:
img, label = dataset.get_sample(i)
img = img[0]
plt.imshow(img, cmap='gray')
plt.show()
train_loader, test_loader = dataset.get_loader()
for (img, label) in train_loader:
print(img)
print(label)