network runs
This commit is contained in:
48
train.py
48
train.py
@@ -3,6 +3,7 @@ training HWDB Chinese charactors classification
|
||||
on MobileNetV2
|
||||
'''
|
||||
from alfred.dl.tf.common import mute_tf
|
||||
|
||||
mute_tf()
|
||||
|
||||
import os
|
||||
@@ -12,40 +13,41 @@ import tensorflow as tf
|
||||
|
||||
from alfred.utils.log import logger as logging
|
||||
import tensorflow_datasets as tfds
|
||||
from dataset.casia_hwdb import load_ds, load_charactors
|
||||
from models.cnn_net import CNNNet
|
||||
from dataset.casia_hwdb import load_ds, load_characters
|
||||
from models.cnn_net import CNNNet, build_net_002
|
||||
|
||||
|
||||
target_size = 224
|
||||
|
||||
target_size = 64
|
||||
num_classes = 7356
|
||||
use_keras_fit = False
|
||||
# use_keras_fit = True
|
||||
ckpt_path = './checkpoints/no_finetune/flowers_mbv2_scratch-{epoch}.ckpt'
|
||||
# use_keras_fit = False
|
||||
use_keras_fit = True
|
||||
ckpt_path = './checkpoints/cn_ocr-{epoch}.ckpt'
|
||||
|
||||
|
||||
def preprocess(x):
|
||||
"""
|
||||
minus mean pixel or normalize?
|
||||
"""
|
||||
x['image'] = tf.expand_dims(x['image'], axis=-1)
|
||||
x['image'] = tf.image.resize(x['image'], (target_size, target_size))
|
||||
x['image'] /= 255.
|
||||
x['image'] = 2*x['image'] - 1
|
||||
x['image'] = 2 * x['image'] - 1
|
||||
return x['image'], x['label']
|
||||
|
||||
|
||||
def train():
|
||||
all_charactors = load_charactors()
|
||||
num_classes = len(all_charactors)
|
||||
# using mobilenetv2 classify tf_flowers dataset
|
||||
all_characters = load_characters()
|
||||
num_classes = len(all_characters)
|
||||
logging.info('all characters: {}'.format(num_classes))
|
||||
train_dataset = load_ds()
|
||||
train_dataset = train_dataset.shuffle(100).map(preprocess).batch(4).repeat()
|
||||
|
||||
# init model
|
||||
model = CNNNet()
|
||||
|
||||
# model.summary()
|
||||
# model = tf.keras.models.load_model('flowers_mobilenetv2.h5')
|
||||
model = build_net_002((64, 64, 1), num_classes)
|
||||
model.summary()
|
||||
logging.info('model loaded.')
|
||||
|
||||
|
||||
start_epoch = 0
|
||||
latest_ckpt = tf.train.latest_checkpoint(os.path.dirname(ckpt_path))
|
||||
if latest_ckpt:
|
||||
@@ -56,26 +58,24 @@ def train():
|
||||
logging.info('passing resume since weights not there. training from scratch')
|
||||
|
||||
if use_keras_fit:
|
||||
# todo: why keras fit converge faster than tf loop?
|
||||
model.compile(
|
||||
optimizer='adam',
|
||||
loss='sparse_categorical_crossentropy',
|
||||
optimizer=tf.keras.optimizers.Adam(),
|
||||
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
|
||||
metrics=['accuracy'])
|
||||
try:
|
||||
model.fit(
|
||||
train_dataset, epochs=50,
|
||||
steps_per_epoch=700,)
|
||||
train_dataset, epochs=50,
|
||||
steps_per_epoch=700, )
|
||||
except KeyboardInterrupt:
|
||||
model.save_weights(ckpt_path.format(epoch=0))
|
||||
logging.info('keras model saved.')
|
||||
model.save_weights(ckpt_path.format(epoch=0))
|
||||
model.save(os.path.join(os.path.dirname(ckpt_path), 'flowers_mobilenetv2.h5'))
|
||||
model.save(os.path.join(os.path.dirname(ckpt_path), 'cn_ocr.h5'))
|
||||
else:
|
||||
loss_fn = tf.losses.SparseCategoricalCrossentropy()
|
||||
optimizer = tf.optimizers.RMSprop()
|
||||
|
||||
train_loss = tf.metrics.Mean(name='train_loss')
|
||||
# the accuracy calculation has some problems, seems not right?
|
||||
train_accuracy = tf.metrics.SparseCategoricalAccuracy(name='train_accuracy')
|
||||
|
||||
for epoch in range(start_epoch, 120):
|
||||
@@ -92,7 +92,7 @@ def train():
|
||||
train_accuracy(labels, predictions)
|
||||
if batch % 10 == 0:
|
||||
logging.info('Epoch: {}, iter: {}, loss: {}, train_acc: {}'.format(
|
||||
epoch, batch, train_loss.result(), train_accuracy.result()))
|
||||
epoch, batch, train_loss.result(), train_accuracy.result()))
|
||||
except KeyboardInterrupt:
|
||||
logging.info('interrupted.')
|
||||
model.save_weights(ckpt_path.format(epoch=epoch))
|
||||
@@ -100,7 +100,5 @@ def train():
|
||||
exit(0)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user