network runs

This commit is contained in:
Your Name
2019-06-06 23:47:37 +08:00
parent 0d9ea44929
commit c8df372f63
8 changed files with 212 additions and 118 deletions

1
.gitignore vendored
View File

@@ -1 +1,2 @@
.vscode/
checkpoints/

227
.idea/workspace.xml generated
View File

@@ -1,20 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ChangeListManager">
<list default="true" id="c86d3061-c2c8-42bb-882f-53f6373c7f88" name="Default" comment="">
<change beforePath="$PROJECT_DIR$/dataset/.gitignore" afterPath="$PROJECT_DIR$/dataset/.gitignore" />
<change beforePath="$PROJECT_DIR$/dataset/casia_hwdb.py" afterPath="$PROJECT_DIR$/dataset/casia_hwdb.py" />
<change beforePath="$PROJECT_DIR$/dataset/casia_hwdb_1.0_1.1.tfrecord" afterPath="$PROJECT_DIR$/dataset/casia_hwdb_1.0_1.1.tfrecord" />
<change beforePath="$PROJECT_DIR$/dataset/charactors.txt" afterPath="$PROJECT_DIR$/dataset/charactors.txt" />
<change beforePath="$PROJECT_DIR$/dataset/convert_to_tfrecord.py" afterPath="$PROJECT_DIR$/dataset/convert_to_tfrecord.py" />
<change beforePath="$PROJECT_DIR$/dataset/get_hwdb_1.0_1.1.sh" afterPath="$PROJECT_DIR$/dataset/get_hwdb_1.0_1.1.sh" />
<change beforePath="$PROJECT_DIR$/readme.md" afterPath="$PROJECT_DIR$/readme.md" />
<change beforePath="$PROJECT_DIR$/sample.png" afterPath="$PROJECT_DIR$/sample.png" />
<change beforePath="$PROJECT_DIR$/samples/.gitignore" afterPath="$PROJECT_DIR$/samples/.gitignore" />
<change beforePath="$PROJECT_DIR$/samples/001-f.gnt" afterPath="$PROJECT_DIR$/samples/001-f.gnt" />
<change beforePath="$PROJECT_DIR$/samples/sample.png" afterPath="$PROJECT_DIR$/samples/sample.png" />
<change beforePath="$PROJECT_DIR$/tests.py" afterPath="$PROJECT_DIR$/tests.py" />
</list>
<list default="true" id="c86d3061-c2c8-42bb-882f-53f6373c7f88" name="Default" comment="" />
<option name="EXCLUDED_CONVERTED_TO_IGNORED" value="true" />
<option name="TRACKING_ENABLED" value="true" />
<option name="SHOW_DIALOG" value="false" />
@@ -23,29 +10,51 @@
<option name="LAST_RESOLUTION" value="IGNORE" />
</component>
<component name="FileEditorManager">
<leaf>
<file leaf-file-name="dataset_hwdb.py" pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/dataset/dataset_hwdb.py">
<leaf SIDE_TABS_SIZE_LIMIT_KEY="300">
<file leaf-file-name="casia_hwdb.py" pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/../../dataset/casia_hwdb.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="0">
<caret line="0" column="0" lean-forward="false" selection-start-line="0" selection-start-column="0" selection-end-line="0" selection-end-column="0" />
<state relative-caret-position="418">
<caret line="54" column="35" lean-forward="false" selection-start-line="54" selection-start-column="35" selection-end-line="54" selection-end-column="35" />
<folding />
</state>
</provider>
</entry>
</file>
<file leaf-file-name="casia_hwdb.py" pinned="false" current-in-tab="true">
<entry file="file://$PROJECT_DIR$/dataset/casia_hwdb.py">
<file leaf-file-name="train.py" pinned="false" current-in-tab="true">
<entry file="file://$PROJECT_DIR$/../../train.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="715">
<caret line="70" column="26" lean-forward="true" selection-start-line="70" selection-start-column="26" selection-end-line="70" selection-end-column="26" />
<state relative-caret-position="401">
<caret line="86" column="51" lean-forward="true" selection-start-line="86" selection-start-column="51" selection-end-line="86" selection-end-column="51" />
<folding />
</state>
</provider>
</entry>
</file>
<file leaf-file-name=".gitignore" pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/../../.gitignore">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="17">
<caret line="1" column="12" lean-forward="false" selection-start-line="1" selection-start-column="12" selection-end-line="1" selection-end-column="12" />
<folding />
</state>
</provider>
</entry>
</file>
<file leaf-file-name="cnn_net.py" pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/../../models/cnn_net.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="646">
<caret line="38" column="0" lean-forward="false" selection-start-line="38" selection-start-column="0" selection-end-line="38" selection-end-column="0" />
<folding>
<element signature="e#1153#1176#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
</file>
<file leaf-file-name="tests.py" pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/tests.py">
<entry file="file://$PROJECT_DIR$/../../tests.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="0">
<caret line="0" column="0" lean-forward="false" selection-start-line="0" selection-start-column="0" selection-end-line="0" selection-end-column="0" />
@@ -55,9 +64,9 @@
</entry>
</file>
<file leaf-file-name="convert_to_tfrecord.py" pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/dataset/convert_to_tfrecord.py">
<entry file="file://$PROJECT_DIR$/../../dataset/convert_to_tfrecord.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="-301">
<state relative-caret-position="0">
<caret line="0" column="0" lean-forward="false" selection-start-line="0" selection-start-column="0" selection-end-line="0" selection-end-column="0" />
<folding />
</state>
@@ -67,12 +76,16 @@
</leaf>
</component>
<component name="Git.Settings">
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$/../.." />
</component>
<component name="IdeDocumentHistory">
<option name="CHANGED_PATHS">
<list>
<option value="$PROJECT_DIR$/dataset/casia_hwdb.py" />
<option value="$PROJECT_DIR$/../../.gitignore" />
<option value="$PROJECT_DIR$/../../models/cnn_net.py" />
<option value="$PROJECT_DIR$/../../dataset/casia_hwdb.py" />
<option value="$PROJECT_DIR$/../../train.py" />
</list>
</option>
</component>
@@ -82,11 +95,10 @@
<detection-done>true</detection-done>
<sorting>DEFINITION_ORDER</sorting>
</component>
<component name="ProjectFrameBounds">
<option name="x" value="933" />
<option name="y" value="28" />
<option name="width" value="1538" />
<option name="height" value="1412" />
<component name="ProjectFrameBounds" extendedState="4">
<option name="y" value="25" />
<option name="width" value="1573" />
<option name="height" value="1415" />
</component>
<component name="ProjectView">
<navigator currentView="ProjectPane" proportions="" version="1">
@@ -103,24 +115,19 @@
<foldersAlwaysOnTop value="true" />
</navigator>
<panes>
<pane id="Scratches" />
<pane id="Scope" />
<pane id="ProjectPane">
<subPane>
<expand>
<path>
<item name="ocrcn_tf2" type="b2602c69:ProjectViewProjectNode" />
<item name="." type="b2602c69:ProjectViewProjectNode" />
<item name="ocrcn_tf2" type="462c0819:PsiDirectoryNode" />
</path>
<path>
<item name="ocrcn_tf2" type="b2602c69:ProjectViewProjectNode" />
<item name="ocrcn_tf2" type="462c0819:PsiDirectoryNode" />
<item name="dataset" type="462c0819:PsiDirectoryNode" />
</path>
</expand>
<select />
</subPane>
</pane>
<pane id="Scope" />
<pane id="Scratches" />
</panes>
</component>
<component name="PropertiesComponent">
@@ -157,22 +164,23 @@
<servers />
</component>
<component name="ToolWindowManager">
<frame x="933" y="28" width="1538" height="1412" extended-state="0" />
<frame x="0" y="25" width="1573" height="1415" extended-state="4" />
<editor active="true" />
<layout>
<window_info id="TODO" active="false" anchor="bottom" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="-1" side_tool="false" content_ui="tabs" />
<window_info id="Event Log" active="false" anchor="bottom" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="-1" side_tool="true" content_ui="tabs" />
<window_info id="Version Control" active="false" anchor="bottom" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="-1" side_tool="false" content_ui="tabs" />
<window_info id="Python Console" active="false" anchor="bottom" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="-1" side_tool="false" content_ui="tabs" />
<window_info id="Run" active="false" anchor="bottom" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="-1" side_tool="false" content_ui="tabs" />
<window_info id="Terminal" active="false" anchor="bottom" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="-1" side_tool="false" content_ui="tabs" />
<window_info id="Project" active="false" anchor="left" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="true" show_stripe_button="true" weight="0.22386059" sideWeight="0.5" order="-1" side_tool="false" content_ui="combo" />
<window_info id="Docker" active="false" anchor="bottom" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="false" weight="0.33" sideWeight="0.5" order="-1" side_tool="false" content_ui="tabs" />
<window_info id="Database" active="false" anchor="right" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="-1" side_tool="false" content_ui="tabs" />
<window_info id="SciView" active="false" anchor="right" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="-1" side_tool="false" content_ui="tabs" />
<window_info id="Structure" active="false" anchor="left" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="-1" side_tool="true" content_ui="tabs" />
<window_info id="Favorites" active="false" anchor="left" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="-1" side_tool="true" content_ui="tabs" />
<window_info id="Debug" active="false" anchor="bottom" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="-1" side_tool="false" content_ui="tabs" />
<window_info id="TODO" active="false" anchor="bottom" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="0" side_tool="false" content_ui="tabs" />
<window_info id="Event Log" active="false" anchor="bottom" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="0" side_tool="true" content_ui="tabs" />
<window_info id="Run" active="false" anchor="bottom" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="0" side_tool="false" content_ui="tabs" />
<window_info id="Version Control" active="false" anchor="bottom" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="0" side_tool="false" content_ui="tabs" />
<window_info id="Python Console" active="false" anchor="bottom" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="0" side_tool="false" content_ui="tabs" />
<window_info id="Terminal" active="false" anchor="bottom" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="true" show_stripe_button="true" weight="0.32943925" sideWeight="0.5" order="0" side_tool="false" content_ui="tabs" />
<window_info id="Project" active="false" anchor="left" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="true" show_stripe_button="true" weight="0.14800262" sideWeight="0.5" order="0" side_tool="false" content_ui="combo" />
<window_info id="Docker" active="false" anchor="bottom" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="false" weight="0.33" sideWeight="0.5" order="0" side_tool="false" content_ui="tabs" />
<window_info id="Database" active="false" anchor="right" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="0" side_tool="false" content_ui="tabs" />
<window_info id="Find" active="false" anchor="bottom" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.32943925" sideWeight="0.5" order="-1" side_tool="false" content_ui="tabs" />
<window_info id="SciView" active="false" anchor="right" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="0" side_tool="false" content_ui="tabs" />
<window_info id="Structure" active="false" anchor="left" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="0" side_tool="true" content_ui="tabs" />
<window_info id="Debug" active="false" anchor="bottom" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="0" side_tool="false" content_ui="tabs" />
<window_info id="Favorites" active="false" anchor="left" auto_hide="false" internal_type="DOCKED" type="DOCKED" visible="false" show_stripe_button="true" weight="0.33" sideWeight="0.5" order="0" side_tool="true" content_ui="tabs" />
</layout>
</component>
<component name="TypeScriptGeneratedFilesManager">
@@ -186,6 +194,38 @@
<watches-manager />
</component>
<component name="editorHistoryManager">
<entry file="file://$PROJECT_DIR$/../../dataset/dataset_hwdb.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="0">
<caret line="0" column="0" lean-forward="false" selection-start-line="0" selection-start-column="0" selection-end-line="0" selection-end-column="0" />
<folding />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/../../tests.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="0">
<caret line="0" column="0" lean-forward="false" selection-start-line="0" selection-start-column="0" selection-end-line="0" selection-end-column="0" />
<folding />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/../../dataset/convert_to_tfrecord.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="0">
<caret line="0" column="0" lean-forward="false" selection-start-line="0" selection-start-column="0" selection-end-line="0" selection-end-column="0" />
<folding />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/../../dataset/casia_hwdb.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="1190">
<caret line="70" column="26" lean-forward="true" selection-start-line="70" selection-start-column="26" selection-end-line="70" selection-end-column="26" />
<folding />
</state>
</provider>
</entry>
<entry file="file:///usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="-7666">
@@ -194,44 +234,19 @@
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/dataset/charactors.txt">
<entry file="file://$PROJECT_DIR$/../../dataset/charactors.txt">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="0">
<caret line="0" column="0" lean-forward="false" selection-start-line="0" selection-start-column="0" selection-end-line="0" selection-end-column="0" />
<folding />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/dataset/convert_to_tfrecord.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="-301">
<caret line="0" column="0" lean-forward="false" selection-start-line="0" selection-start-column="0" selection-end-line="0" selection-end-column="0" />
<folding />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/dataset/dataset_hwdb.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="0">
<caret line="0" column="0" lean-forward="false" selection-start-line="0" selection-start-column="0" selection-end-line="0" selection-end-column="0" />
<folding />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/train.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="0">
<caret line="0" column="0" lean-forward="false" selection-start-line="0" selection-start-column="0" selection-end-line="0" selection-end-column="0" />
<folding />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/sample.png">
<entry file="file://$PROJECT_DIR$/../../sample.png">
<provider selected="true" editor-type-id="images">
<state />
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/tests.py">
<entry file="file://$PROJECT_DIR$/../../tests.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="0">
<caret line="0" column="0" lean-forward="false" selection-start-line="0" selection-start-column="0" selection-end-line="0" selection-end-column="0" />
@@ -239,10 +254,52 @@
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/dataset/casia_hwdb.py">
<entry file="file://$PROJECT_DIR$/../../.gitignore">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="715">
<caret line="70" column="26" lean-forward="true" selection-start-line="70" selection-start-column="26" selection-end-line="70" selection-end-column="26" />
<state relative-caret-position="17">
<caret line="1" column="12" lean-forward="false" selection-start-line="1" selection-start-column="12" selection-end-line="1" selection-end-column="12" />
<folding />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/../../dataset/dataset_hwdb.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="0">
<caret line="0" column="0" lean-forward="false" selection-start-line="0" selection-start-column="0" selection-end-line="0" selection-end-column="0" />
<folding />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/../../dataset/convert_to_tfrecord.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="0">
<caret line="0" column="0" lean-forward="false" selection-start-line="0" selection-start-column="0" selection-end-line="0" selection-end-column="0" />
<folding />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/../../dataset/casia_hwdb.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="418">
<caret line="54" column="35" lean-forward="false" selection-start-line="54" selection-start-column="35" selection-end-line="54" selection-end-column="35" />
<folding />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/../../models/cnn_net.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="646">
<caret line="38" column="0" lean-forward="false" selection-start-line="38" selection-start-column="0" selection-end-line="38" selection-end-column="0" />
<folding>
<element signature="e#1153#1176#0" expanded="true" />
</folding>
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/../../train.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="401">
<caret line="86" column="51" lean-forward="true" selection-start-line="86" selection-start-column="51" selection-end-line="86" selection-end-column="51" />
<folding />
</state>
</provider>

Binary file not shown.

View File

@@ -7,13 +7,17 @@ we using this class to get .png and label from raw
"""
from alfred.dl.tf.common import mute_tf
mute_tf()
import struct
import numpy as np
import cv2
import tensorflow as tf
import os
this_dir = os.path.dirname(os.path.abspath(__file__))
class CASIAHWDBGNT(object):
"""
@@ -52,8 +56,9 @@ def parse_example(record):
tf.io.FixedLenFeature([], tf.string),
})
img = tf.io.decode_raw(features['image'], out_type=tf.uint8)
img = tf.cast(tf.reshape(img, (64, 64)), dtype=tf.float32)
label = tf.cast(features['label'], tf.int32)
return img, label
return {'image': img, 'label': label}
def load_ds():
@@ -63,14 +68,15 @@ def load_ds():
return ds
def load_charactors():
a = open('charactors.txt', 'r').readlines()
def load_characters():
a = open(os.path.join(this_dir, 'charactors.txt'), 'r').readlines()
return [i.strip() for i in a]
if __name__ == "__main__":
ds = load_ds()
charactors = load_charactors()
charactors = load_characters()
for img, label in ds.take(9):
# start training on model...
img = img.numpy()

Binary file not shown.

View File

@@ -1,4 +1,3 @@
'''
@@ -21,9 +20,42 @@ accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels), tf.flo
'''
import tensorflow as tf
from tensorflow.keras import layers
# some simple models
def build_net_001(input_shape, n_classes):
assert len(input_shape) == 3, 'only support 3 channels'
model = tf.keras.Sequential()
model.add(tf.keras.layers.Conv2D(
input_shape=input_shape, filters=32, kernel_size=(3, 3), strides=(1, 1),
padding='valid', activation='relu'))
model.add(tf.keras.layers.MaxPool2D(pool_size=(2, 2)))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(32, activation='relu'))
model.add(tf.keras.layers.Dense(n_classes, activation='softmax'))
return model
def build_net_002(input_shape, n_classes):
model = tf.keras.Sequential([
layers.Conv2D(input_shape=input_shape, filters=64, kernel_size=(3, 3), strides=(1, 1),
padding='same', activation='relu'),
layers.MaxPool2D(pool_size=(2, 2), padding='same'),
layers.Conv2D(filters=128, kernel_size=(3, 3), padding='same'),
layers.MaxPool2D(pool_size=(2, 2), padding='same'),
layers.Conv2D(filters=256, kernel_size=(3, 3), padding='same'),
layers.MaxPool2D(pool_size=(2, 2), padding='same'),
layers.Flatten(),
layers.Dense(1024, activation='relu'),
layers.Dense(n_classes, activation='softmax')
])
return model
# some models wrapped into tf.keras.Model
class CNNNet(tf.keras.Model):
def __init__(self.):
pass
def __init__(self):
pass

View File

@@ -3,6 +3,7 @@ training HWDB Chinese charactors classification
on MobileNetV2
'''
from alfred.dl.tf.common import mute_tf
mute_tf()
import os
@@ -12,40 +13,41 @@ import tensorflow as tf
from alfred.utils.log import logger as logging
import tensorflow_datasets as tfds
from dataset.casia_hwdb import load_ds, load_charactors
from models.cnn_net import CNNNet
from dataset.casia_hwdb import load_ds, load_characters
from models.cnn_net import CNNNet, build_net_002
target_size = 224
target_size = 64
num_classes = 7356
use_keras_fit = False
# use_keras_fit = True
ckpt_path = './checkpoints/no_finetune/flowers_mbv2_scratch-{epoch}.ckpt'
# use_keras_fit = False
use_keras_fit = True
ckpt_path = './checkpoints/cn_ocr-{epoch}.ckpt'
def preprocess(x):
"""
minus mean pixel or normalize?
"""
x['image'] = tf.expand_dims(x['image'], axis=-1)
x['image'] = tf.image.resize(x['image'], (target_size, target_size))
x['image'] /= 255.
x['image'] = 2*x['image'] - 1
x['image'] = 2 * x['image'] - 1
return x['image'], x['label']
def train():
all_charactors = load_charactors()
num_classes = len(all_charactors)
# using mobilenetv2 classify tf_flowers dataset
all_characters = load_characters()
num_classes = len(all_characters)
logging.info('all characters: {}'.format(num_classes))
train_dataset = load_ds()
train_dataset = train_dataset.shuffle(100).map(preprocess).batch(4).repeat()
# init model
model = CNNNet()
# model.summary()
# model = tf.keras.models.load_model('flowers_mobilenetv2.h5')
model = build_net_002((64, 64, 1), num_classes)
model.summary()
logging.info('model loaded.')
start_epoch = 0
latest_ckpt = tf.train.latest_checkpoint(os.path.dirname(ckpt_path))
if latest_ckpt:
@@ -56,26 +58,24 @@ def train():
logging.info('passing resume since weights not there. training from scratch')
if use_keras_fit:
# todo: why keras fit converge faster than tf loop?
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=['accuracy'])
try:
model.fit(
train_dataset, epochs=50,
steps_per_epoch=700,)
train_dataset, epochs=50,
steps_per_epoch=700, )
except KeyboardInterrupt:
model.save_weights(ckpt_path.format(epoch=0))
logging.info('keras model saved.')
model.save_weights(ckpt_path.format(epoch=0))
model.save(os.path.join(os.path.dirname(ckpt_path), 'flowers_mobilenetv2.h5'))
model.save(os.path.join(os.path.dirname(ckpt_path), 'cn_ocr.h5'))
else:
loss_fn = tf.losses.SparseCategoricalCrossentropy()
optimizer = tf.optimizers.RMSprop()
train_loss = tf.metrics.Mean(name='train_loss')
# the accuracy calculation has some problems, seems not right?
train_accuracy = tf.metrics.SparseCategoricalAccuracy(name='train_accuracy')
for epoch in range(start_epoch, 120):
@@ -92,7 +92,7 @@ def train():
train_accuracy(labels, predictions)
if batch % 10 == 0:
logging.info('Epoch: {}, iter: {}, loss: {}, train_acc: {}'.format(
epoch, batch, train_loss.result(), train_accuracy.result()))
epoch, batch, train_loss.result(), train_accuracy.result()))
except KeyboardInterrupt:
logging.info('interrupted.')
model.save_weights(ckpt_path.format(epoch=epoch))
@@ -100,7 +100,5 @@ def train():
exit(0)
if __name__ == "__main__":
train()