add
This commit is contained in:
@@ -57,7 +57,7 @@ def parse_example(record):
|
|||||||
|
|
||||||
|
|
||||||
def load_ds():
|
def load_ds():
|
||||||
input_files = ['casia_hwdb_1.0_1.1.tfrecord']
|
input_files = ['dataset/hwdb_11.tfrecord']
|
||||||
ds = tf.data.TFRecordDataset(input_files)
|
ds = tf.data.TFRecordDataset(input_files)
|
||||||
ds = ds.map(parse_example)
|
ds = ds.map(parse_example)
|
||||||
return ds
|
return ds
|
||||||
|
|||||||
29
models/cnn_net.py
Executable file
29
models/cnn_net.py
Executable file
@@ -0,0 +1,29 @@
|
|||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
conv_1 = slim.conv2d(images, 64, [3, 3], 1, padding='SAME', scope='conv1')
|
||||||
|
# (inputs,num_outputs,[卷积核个数] kernel_size,[卷积核的高度,卷积核的宽]stride=1,padding='SAME',)
|
||||||
|
max_pool_1 = slim.max_pool2d(conv_1, [2, 2], [2, 2], padding='SAME')
|
||||||
|
conv_2 = slim.conv2d(max_pool_1, 128, [3, 3], padding='SAME', scope='conv2')
|
||||||
|
max_pool_2 = slim.max_pool2d(conv_2, [2, 2], [2, 2], padding='SAME')
|
||||||
|
conv_3 = slim.conv2d(max_pool_2, 256, [3, 3], padding='SAME', scope='conv3')
|
||||||
|
max_pool_3 = slim.max_pool2d(conv_3, [2, 2], [2, 2], padding='SAME')
|
||||||
|
|
||||||
|
flatten = slim.flatten(max_pool_3)
|
||||||
|
fc1 = slim.fully_connected(tf.nn.dropout(flatten, keep_prob), 1024, activation_fn=tf.nn.tanh, scope='fc1')
|
||||||
|
logits = slim.fully_connected(tf.nn.dropout(fc1, keep_prob), FLAGS.charset_size, activation_fn=None, scope='fc2')
|
||||||
|
# logits = slim.fully_connected(flatten, FLAGS.charset_size, activation_fn=None, reuse=reuse, scope='fc')
|
||||||
|
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
|
||||||
|
# y表示的是实际类别,y_表示预测结果,这实际上面是把原来的神经网络输出层的softmax和cross_entrop何在一起计算,为了追求速度
|
||||||
|
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels), tf.float32))
|
||||||
|
'''
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
class CNNNet(tf.keras.Model):
|
||||||
|
|
||||||
|
def __init__(self.):
|
||||||
|
pass
|
||||||
106
train.py
106
train.py
@@ -0,0 +1,106 @@
|
|||||||
|
'''
|
||||||
|
training HWDB Chinese charactors classification
|
||||||
|
on MobileNetV2
|
||||||
|
'''
|
||||||
|
from alfred.dl.tf.common import mute_tf
|
||||||
|
mute_tf()
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from alfred.utils.log import logger as logging
|
||||||
|
import tensorflow_datasets as tfds
|
||||||
|
from dataset.casia_hwdb import load_ds, load_charactors
|
||||||
|
from models.cnn_net import CNNNet
|
||||||
|
|
||||||
|
|
||||||
|
target_size = 224
|
||||||
|
num_classes = 7356
|
||||||
|
use_keras_fit = False
|
||||||
|
# use_keras_fit = True
|
||||||
|
ckpt_path = './checkpoints/no_finetune/flowers_mbv2_scratch-{epoch}.ckpt'
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess(x):
|
||||||
|
"""
|
||||||
|
minus mean pixel or normalize?
|
||||||
|
"""
|
||||||
|
x['image'] = tf.image.resize(x['image'], (target_size, target_size))
|
||||||
|
x['image'] /= 255.
|
||||||
|
x['image'] = 2*x['image'] - 1
|
||||||
|
return x['image'], x['label']
|
||||||
|
|
||||||
|
def train():
|
||||||
|
all_charactors = load_charactors()
|
||||||
|
num_classes = len(all_charactors)
|
||||||
|
# using mobilenetv2 classify tf_flowers dataset
|
||||||
|
train_dataset = load_ds()
|
||||||
|
train_dataset = train_dataset.shuffle(100).map(preprocess).batch(4).repeat()
|
||||||
|
|
||||||
|
# init model
|
||||||
|
model = CNNNet()
|
||||||
|
|
||||||
|
# model.summary()
|
||||||
|
# model = tf.keras.models.load_model('flowers_mobilenetv2.h5')
|
||||||
|
logging.info('model loaded.')
|
||||||
|
|
||||||
|
start_epoch = 0
|
||||||
|
latest_ckpt = tf.train.latest_checkpoint(os.path.dirname(ckpt_path))
|
||||||
|
if latest_ckpt:
|
||||||
|
start_epoch = int(latest_ckpt.split('-')[1].split('.')[0])
|
||||||
|
model.load_weights(latest_ckpt)
|
||||||
|
logging.info('model resumed from: {}, start at epoch: {}'.format(latest_ckpt, start_epoch))
|
||||||
|
else:
|
||||||
|
logging.info('passing resume since weights not there. training from scratch')
|
||||||
|
|
||||||
|
if use_keras_fit:
|
||||||
|
# todo: why keras fit converge faster than tf loop?
|
||||||
|
model.compile(
|
||||||
|
optimizer='adam',
|
||||||
|
loss='sparse_categorical_crossentropy',
|
||||||
|
metrics=['accuracy'])
|
||||||
|
try:
|
||||||
|
model.fit(
|
||||||
|
train_dataset, epochs=50,
|
||||||
|
steps_per_epoch=700,)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
model.save_weights(ckpt_path.format(epoch=0))
|
||||||
|
logging.info('keras model saved.')
|
||||||
|
model.save_weights(ckpt_path.format(epoch=0))
|
||||||
|
model.save(os.path.join(os.path.dirname(ckpt_path), 'flowers_mobilenetv2.h5'))
|
||||||
|
else:
|
||||||
|
loss_fn = tf.losses.SparseCategoricalCrossentropy()
|
||||||
|
optimizer = tf.optimizers.RMSprop()
|
||||||
|
|
||||||
|
train_loss = tf.metrics.Mean(name='train_loss')
|
||||||
|
# the accuracy calculation has some problems, seems not right?
|
||||||
|
train_accuracy = tf.metrics.SparseCategoricalAccuracy(name='train_accuracy')
|
||||||
|
|
||||||
|
for epoch in range(start_epoch, 120):
|
||||||
|
try:
|
||||||
|
for batch, data in enumerate(train_dataset):
|
||||||
|
# images, labels = data['image'], data['label']
|
||||||
|
images, labels = data
|
||||||
|
with tf.GradientTape() as tape:
|
||||||
|
predictions = model(images)
|
||||||
|
loss = loss_fn(labels, predictions)
|
||||||
|
gradients = tape.gradient(loss, model.trainable_variables)
|
||||||
|
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
|
||||||
|
train_loss(loss)
|
||||||
|
train_accuracy(labels, predictions)
|
||||||
|
if batch % 10 == 0:
|
||||||
|
logging.info('Epoch: {}, iter: {}, loss: {}, train_acc: {}'.format(
|
||||||
|
epoch, batch, train_loss.result(), train_accuracy.result()))
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logging.info('interrupted.')
|
||||||
|
model.save_weights(ckpt_path.format(epoch=epoch))
|
||||||
|
logging.info('model saved into: {}'.format(ckpt_path.format(epoch=epoch)))
|
||||||
|
exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
train()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user