61 lines
1.7 KiB
Python
61 lines
1.7 KiB
Python
import os
|
|
import torch
|
|
import torch.utils.data
|
|
import torchvision.transforms as transforms
|
|
import torchvision.datasets as datasets
|
|
from PIL import Image
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
class HWDB(object):
    """Train/test splits of an image dataset stored as ``ImageFolder`` dirs.

    Expects ``<path>/train/<class_name>/...`` and ``<path>/test/<class_name>/...``.
    Every image is converted to single-channel grayscale, pixel-inverted
    (v -> 255 - v), center-cropped to 64x64 and converted with ``ToTensor``.
    """

    def __init__(self, path='./data'):
        """Build both ``ImageFolder`` datasets under *path*.

        Raises:
            ValueError: if a class in the test split maps to a different
                label index than the same class in the train split.
        """
        # Preprocessing pipeline applied to every image.
        transform = transforms.Compose([
            transforms.Grayscale(),
            # A named function rather than a lambda so the transform can be
            # pickled when DataLoader runs worker processes.
            transforms.Lambda(self._invert),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
        ])

        traindir = os.path.join(path, 'train')
        testdir = os.path.join(path, 'test')

        self.trainset = datasets.ImageFolder(traindir, transform)
        self.testset = datasets.ImageFolder(testdir, transform)
        # ImageFolder numbers classes per directory (sorted folder names), so a
        # class folder missing from one split would silently shift labels.
        self._check_class_alignment(self.trainset, self.testset)

        self.train_size = len(self.trainset)
        self.test_size = len(self.testset)

    @staticmethod
    def _invert(img):
        """Return a new PIL image with every pixel value v replaced by 255 - v."""
        return Image.fromarray(255 - np.array(img))

    @staticmethod
    def _check_class_alignment(trainset, testset):
        """Raise ValueError unless every test class has the same index as in train."""
        train_map = trainset.class_to_idx
        mismatched = [name for name, idx in testset.class_to_idx.items()
                      if train_map.get(name) != idx]
        if mismatched:
            raise ValueError(
                'train/test class folders disagree: {} test class(es) have a '
                'different or missing train label, e.g. {!r}'.format(
                    len(mismatched), mismatched[:5]))

    def get_sample(self, index=0):
        """Return the ``(image, label)`` pair at *index* of the training set.

        The image tensor's size is printed as a side effect.
        """
        sample_img, sample_label = self.trainset[index]
        print(sample_img.size())
        return sample_img, sample_label

    def get_loader(self, batch_size=100, num_workers=0):
        """Build DataLoaders for the train and test splits.

        Args:
            batch_size: samples per batch for both loaders.
            num_workers: worker processes per loader; 0 (the default) loads
                in the main process, as before.

        Returns:
            ``(train_loader, test_loader)``, both with ``shuffle=True``.
        """
        train_loader = torch.utils.data.DataLoader(
            self.trainset, batch_size=batch_size, shuffle=True,
            num_workers=num_workers)
        # NOTE(review): the test loader is shuffled too; this does not change
        # aggregate metrics but makes per-sample evaluation order random.
        test_loader = torch.utils.data.DataLoader(
            self.testset, batch_size=batch_size, shuffle=True,
            num_workers=num_workers)
        return train_loader, test_loader
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: show a few training images, then one batch.
    dataset = HWDB()

    for i in [1, 10, 2000, 6000, 1000]:
        # Skip indices a smaller local copy of the dataset does not have,
        # instead of crashing with IndexError.
        if not 0 <= i < dataset.train_size:
            print('index %d out of range (train_size=%d), skipped'
                  % (i, dataset.train_size))
            continue
        img, label = dataset.get_sample(i)
        # ToTensor yields a (C, H, W) tensor; show its first channel.
        plt.imshow(img[0], cmap='gray')
        plt.title('label %d' % label)
        plt.show()

    train_loader, test_loader = dataset.get_loader()
    # Print only the first batch rather than dumping the whole training set.
    for img, label in train_loader:
        print(img)
        print(label)
        break
|
|
|
|
|
|
|
|
|
|
|
|
|