This commit is contained in:
JiageWang
2019-08-27 10:23:57 +08:00
parent 43ea8da649
commit 2a1fc96ee2

View File

@@ -76,7 +76,7 @@ def train(epoch, net, criterion, optimizer, train_loader, writer, save_iter=100)
if __name__ == "__main__":
data_path = r'C:\Users\Administrator\Desktop\hand-writing-recognition\data'
data_path = r'data'
log_path = r'logs/batch_100_100_lr0.02'
save_path = r'checkpoints/'
if not os.path.exists(save_path):
@@ -102,12 +102,12 @@ if __name__ == "__main__":
trainloader, testloader = dataset.get_loader(batch_size)
net = ConvNet(num_classes)
if torch.cuda.is_available():
net = net.cuda()
# net.load_state_dict(torch.load('checkpoints/handwriting_iter_004.pth'))
print('网络结构:\n')
summary(net, input_size=(3, 64, 64), device='cuda')
if torch.cuda.is_available():
net = net.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=lr)
writer = SummaryWriter(log_path)