diff --git a/train.py b/train.py index b926df5..c278588 100644 --- a/train.py +++ b/train.py @@ -76,7 +76,7 @@ def train(epoch, net, criterion, optimizer, train_loader, writer, save_iter=100) if __name__ == "__main__": - data_path = r'C:\Users\Administrator\Desktop\hand-writing-recognition\data' + data_path = r'data' log_path = r'logs/batch_100_100_lr0.02' save_path = r'checkpoints/' if not os.path.exists(save_path): @@ -102,12 +102,12 @@ if __name__ == "__main__": trainloader, testloader = dataset.get_loader(batch_size) net = ConvNet(num_classes) + if torch.cuda.is_available(): + net = net.cuda() # net.load_state_dict(torch.load('checkpoints/handwriting_iter_004.pth')) print('网络结构:\n') summary(net, input_size=(3, 64, 64), device='cuda') - if torch.cuda.is_available(): - net = net.cuda() criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=lr) writer = SummaryWriter(log_path)