fix bug
This commit is contained in:
6
train.py
6
train.py
@@ -76,7 +76,7 @@ def train(epoch, net, criterion, optimizer, train_loader, writer, save_iter=100)
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
data_path = r'C:\Users\Administrator\Desktop\hand-writing-recognition\data'
|
data_path = r'data'
|
||||||
log_path = r'logs/batch_100_100_lr0.02'
|
log_path = r'logs/batch_100_100_lr0.02'
|
||||||
save_path = r'checkpoints/'
|
save_path = r'checkpoints/'
|
||||||
if not os.path.exists(save_path):
|
if not os.path.exists(save_path):
|
||||||
@@ -102,12 +102,12 @@ if __name__ == "__main__":
|
|||||||
trainloader, testloader = dataset.get_loader(batch_size)
|
trainloader, testloader = dataset.get_loader(batch_size)
|
||||||
|
|
||||||
net = ConvNet(num_classes)
|
net = ConvNet(num_classes)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
net = net.cuda()
|
||||||
# net.load_state_dict(torch.load('checkpoints/handwriting_iter_004.pth'))
|
# net.load_state_dict(torch.load('checkpoints/handwriting_iter_004.pth'))
|
||||||
|
|
||||||
print('网络结构:\n')
|
print('网络结构:\n')
|
||||||
summary(net, input_size=(3, 64, 64), device='cuda')
|
summary(net, input_size=(3, 64, 64), device='cuda')
|
||||||
if torch.cuda.is_available():
|
|
||||||
net = net.cuda()
|
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
optimizer = optim.SGD(net.parameters(), lr=lr)
|
optimizer = optim.SGD(net.parameters(), lr=lr)
|
||||||
writer = SummaryWriter(log_path)
|
writer = SummaryWriter(log_path)
|
||||||
|
|||||||
Reference in New Issue
Block a user