fix bug
This commit is contained in:
6
train.py
6
train.py
@@ -76,7 +76,7 @@ def train(epoch, net, criterion, optimizer, train_loader, writer, save_iter=100)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
data_path = r'C:\Users\Administrator\Desktop\hand-writing-recognition\data'
|
||||
data_path = r'data'
|
||||
log_path = r'logs/batch_100_100_lr0.02'
|
||||
save_path = r'checkpoints/'
|
||||
if not os.path.exists(save_path):
|
||||
@@ -102,12 +102,12 @@ if __name__ == "__main__":
|
||||
trainloader, testloader = dataset.get_loader(batch_size)
|
||||
|
||||
net = ConvNet(num_classes)
|
||||
if torch.cuda.is_available():
|
||||
net = net.cuda()
|
||||
# net.load_state_dict(torch.load('checkpoints/handwriting_iter_004.pth'))
|
||||
|
||||
print('网络结构:\n')
|
||||
summary(net, input_size=(3, 64, 64), device='cuda')
|
||||
if torch.cuda.is_available():
|
||||
net = net.cuda()
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.SGD(net.parameters(), lr=lr)
|
||||
writer = SummaryWriter(log_path)
|
||||
|
||||
Reference in New Issue
Block a user