From 9f5404552363abb99ba84d9d3f7ce1978d180aee Mon Sep 17 00:00:00 2001 From: JiageWang <1076050774@qq.com> Date: Tue, 27 Aug 2019 10:56:10 +0800 Subject: [PATCH] fix bug --- train.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/train.py b/train.py index c278588..5a47a51 100644 --- a/train.py +++ b/train.py @@ -76,15 +76,16 @@ def train(epoch, net, criterion, optimizer, train_loader, writer, save_iter=100) if __name__ == "__main__": - data_path = r'data' - log_path = r'logs/batch_100_100_lr0.02' - save_path = r'checkpoints/' - if not os.path.exists(save_path): - os.mkdir(save_path) # 超参数 epochs = 20 batch_size = 100 - lr = 0.0001 + lr = 0.01 + + data_path = r'data' + log_path = r'logs/batch_{}_lr_{}'.format(batch_size, lr) + save_path = r'checkpoints/' + if not os.path.exists(save_path): + os.mkdir(save_path) # 读取分类类别 with open('char_dict', 'rb') as f: