From 0d9ea44929d3b7b3489b3da20888650dcda5cfdd Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 5 Jun 2019 23:49:20 +0800 Subject: [PATCH] add --- dataset/casia_hwdb.py | 2 +- models/cnn_net.py | 29 ++++++++++++ train.py | 106 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 136 insertions(+), 1 deletion(-) create mode 100755 models/cnn_net.py diff --git a/dataset/casia_hwdb.py b/dataset/casia_hwdb.py index 175a823..08a5253 100755 --- a/dataset/casia_hwdb.py +++ b/dataset/casia_hwdb.py @@ -57,7 +57,7 @@ def parse_example(record): def load_ds(): - input_files = ['casia_hwdb_1.0_1.1.tfrecord'] + input_files = ['dataset/hwdb_11.tfrecord'] ds = tf.data.TFRecordDataset(input_files) ds = ds.map(parse_example) return ds diff --git a/models/cnn_net.py b/models/cnn_net.py new file mode 100755 index 0000000..8325011 --- /dev/null +++ b/models/cnn_net.py @@ -0,0 +1,29 @@ + +''' + + + +conv_1 = slim.conv2d(images, 64, [3, 3], 1, padding='SAME', scope='conv1') +# (inputs,num_outputs,[卷积核个数] kernel_size,[卷积核的高度,卷积核的宽]stride=1,padding='SAME',) +max_pool_1 = slim.max_pool2d(conv_1, [2, 2], [2, 2], padding='SAME') +conv_2 = slim.conv2d(max_pool_1, 128, [3, 3], padding='SAME', scope='conv2') +max_pool_2 = slim.max_pool2d(conv_2, [2, 2], [2, 2], padding='SAME') +conv_3 = slim.conv2d(max_pool_2, 256, [3, 3], padding='SAME', scope='conv3') +max_pool_3 = slim.max_pool2d(conv_3, [2, 2], [2, 2], padding='SAME') + +flatten = slim.flatten(max_pool_3) +fc1 = slim.fully_connected(tf.nn.dropout(flatten, keep_prob), 1024, activation_fn=tf.nn.tanh, scope='fc1') +logits = slim.fully_connected(tf.nn.dropout(fc1, keep_prob), FLAGS.charset_size, activation_fn=None, scope='fc2') +# logits = slim.fully_connected(flatten, FLAGS.charset_size, activation_fn=None, reuse=reuse, scope='fc') +loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)) +# y表示的是实际类别,y_表示预测结果,这实际上面是把原来的神经网络输出层的softmax和cross_entrop何在一起计算,为了追求速度 +accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels), tf.float32)) +''' + +import tensorflow as tf + + +class CNNNet(tf.keras.Model): + + def __init__(self.): + pass \ No newline at end of file diff --git a/train.py b/train.py index e69de29..c4ffb29 100755 --- a/train.py +++ b/train.py @@ -0,0 +1,106 @@ +''' +training HWDB Chinese charactors classification +on MobileNetV2 +''' +from alfred.dl.tf.common import mute_tf +mute_tf() + +import os +import sys +import numpy as np +import tensorflow as tf + +from alfred.utils.log import logger as logging +import tensorflow_datasets as tfds +from dataset.casia_hwdb import load_ds, load_charactors +from models.cnn_net import CNNNet + + +target_size = 224 +num_classes = 7356 +use_keras_fit = False +# use_keras_fit = True +ckpt_path = './checkpoints/no_finetune/flowers_mbv2_scratch-{epoch}.ckpt' + + +def preprocess(x): + """ + minus mean pixel or normalize? + """ + x['image'] = tf.image.resize(x['image'], (target_size, target_size)) + x['image'] /= 255. + x['image'] = 2*x['image'] - 1 + return x['image'], x['label'] + +def train(): + all_charactors = load_charactors() + num_classes = len(all_charactors) + # using mobilenetv2 classify tf_flowers dataset + train_dataset = load_ds() + train_dataset = train_dataset.shuffle(100).map(preprocess).batch(4).repeat() + + # init model + model = CNNNet() + + # model.summary() + # model = tf.keras.models.load_model('flowers_mobilenetv2.h5') + logging.info('model loaded.') + + start_epoch = 0 + latest_ckpt = tf.train.latest_checkpoint(os.path.dirname(ckpt_path)) + if latest_ckpt: + start_epoch = int(latest_ckpt.split('-')[1].split('.')[0]) + model.load_weights(latest_ckpt) + logging.info('model resumed from: {}, start at epoch: {}'.format(latest_ckpt, start_epoch)) + else: + logging.info('passing resume since weights not there. training from scratch') + + if use_keras_fit: + # todo: why keras fit converge faster than tf loop? + model.compile( + optimizer='adam', + loss='sparse_categorical_crossentropy', + metrics=['accuracy']) + try: + model.fit( + train_dataset, epochs=50, + steps_per_epoch=700,) + except KeyboardInterrupt: + model.save_weights(ckpt_path.format(epoch=0)) + logging.info('keras model saved.') + model.save_weights(ckpt_path.format(epoch=0)) + model.save(os.path.join(os.path.dirname(ckpt_path), 'flowers_mobilenetv2.h5')) + else: + loss_fn = tf.losses.SparseCategoricalCrossentropy() + optimizer = tf.optimizers.RMSprop() + + train_loss = tf.metrics.Mean(name='train_loss') + # the accuracy calculation has some problems, seems not right? + train_accuracy = tf.metrics.SparseCategoricalAccuracy(name='train_accuracy') + + for epoch in range(start_epoch, 120): + try: + for batch, data in enumerate(train_dataset): + # images, labels = data['image'], data['label'] + images, labels = data + with tf.GradientTape() as tape: + predictions = model(images) + loss = loss_fn(labels, predictions) + gradients = tape.gradient(loss, model.trainable_variables) + optimizer.apply_gradients(zip(gradients, model.trainable_variables)) + train_loss(loss) + train_accuracy(labels, predictions) + if batch % 10 == 0: + logging.info('Epoch: {}, iter: {}, loss: {}, train_acc: {}'.format( + epoch, batch, train_loss.result(), train_accuracy.result())) + except KeyboardInterrupt: + logging.info('interrupted.') + model.save_weights(ckpt_path.format(epoch=epoch)) + logging.info('model saved into: {}'.format(ckpt_path.format(epoch=epoch))) + exit(0) + + + +if __name__ == "__main__": + train() +