diff --git a/._.DS_Store b/._.DS_Store
deleted file mode 100755
index 9ad849c..0000000
Binary files a/._.DS_Store and /dev/null differ
diff --git a/.gitignore b/.gitignore
index ea1ac33..7393194 100755
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
.vscode/
checkpoints/
+.idea/
diff --git a/.idea/workspace.xml b/.idea/workspace.xml
index 1b0bbe6..4e17f0d 100755
--- a/.idea/workspace.xml
+++ b/.idea/workspace.xml
@@ -11,63 +11,51 @@
+
+
+
+
+
+
+
+
+
+
-
+
-
-
+
+
-
-
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
-
+
-
-
+
+
+
+
+
+
+
+
+
+
+
+
@@ -75,17 +63,39 @@
+
+
+
+
+
+ 蹬
+
+
-
+
-
+
+
+
+
+
+
+
+
+
+
+
@@ -95,10 +105,11 @@
true
DEFINITION_ORDER
-
-
-
-
+
+
+
+
+
@@ -116,6 +127,7 @@
+
@@ -123,11 +135,20 @@
+
+
+
+
+
+
+
+
+
+
-
@@ -135,6 +156,11 @@
+
+
+
+
+
@@ -164,7 +190,7 @@
-
+
@@ -172,11 +198,11 @@
-
-
+
+
-
+
@@ -194,38 +220,50 @@
-
+
-
-
+
+
-
+
-
-
+
+
-
+
-
-
+
+
-
+
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -234,19 +272,30 @@
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
-
+
+
+
+
+
+
+
+
+
@@ -254,15 +303,17 @@
-
-
-
-
-
-
+
+
+
-
+
+
+
+
+
+
@@ -270,7 +321,15 @@
-
+
+
+
+
+
+
+
+
+
@@ -278,28 +337,88 @@
-
+
-
-
+
+
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
-
-
-
-
+
+
+
-
+
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/assets/0.png b/assets/0.png
new file mode 100755
index 0000000..ff67dcb
Binary files /dev/null and b/assets/0.png differ
diff --git a/assets/1.png b/assets/1.png
new file mode 100755
index 0000000..d8eeb12
Binary files /dev/null and b/assets/1.png differ
diff --git a/assets/10.png b/assets/10.png
new file mode 100755
index 0000000..80f1fb8
Binary files /dev/null and b/assets/10.png differ
diff --git a/assets/11.png b/assets/11.png
new file mode 100755
index 0000000..2b24594
Binary files /dev/null and b/assets/11.png differ
diff --git a/assets/12.png b/assets/12.png
new file mode 100755
index 0000000..243252f
Binary files /dev/null and b/assets/12.png differ
diff --git a/assets/13.png b/assets/13.png
new file mode 100755
index 0000000..f80d832
Binary files /dev/null and b/assets/13.png differ
diff --git a/assets/14.png b/assets/14.png
new file mode 100755
index 0000000..6227795
Binary files /dev/null and b/assets/14.png differ
diff --git a/assets/15.png b/assets/15.png
new file mode 100755
index 0000000..a11c080
Binary files /dev/null and b/assets/15.png differ
diff --git a/assets/16.png b/assets/16.png
new file mode 100755
index 0000000..616b17f
Binary files /dev/null and b/assets/16.png differ
diff --git a/assets/17.png b/assets/17.png
new file mode 100755
index 0000000..45be9cc
Binary files /dev/null and b/assets/17.png differ
diff --git a/assets/18.png b/assets/18.png
new file mode 100755
index 0000000..9d3ee2e
Binary files /dev/null and b/assets/18.png differ
diff --git a/assets/19.png b/assets/19.png
new file mode 100755
index 0000000..94e1ad5
Binary files /dev/null and b/assets/19.png differ
diff --git a/assets/2.png b/assets/2.png
new file mode 100755
index 0000000..77f0ff9
Binary files /dev/null and b/assets/2.png differ
diff --git a/assets/20.png b/assets/20.png
new file mode 100755
index 0000000..e479fe1
Binary files /dev/null and b/assets/20.png differ
diff --git a/assets/21.png b/assets/21.png
new file mode 100755
index 0000000..0acf583
Binary files /dev/null and b/assets/21.png differ
diff --git a/assets/22.png b/assets/22.png
new file mode 100755
index 0000000..27dbf8f
Binary files /dev/null and b/assets/22.png differ
diff --git a/assets/23.png b/assets/23.png
new file mode 100755
index 0000000..baddf0d
Binary files /dev/null and b/assets/23.png differ
diff --git a/assets/24.png b/assets/24.png
new file mode 100755
index 0000000..cb3b1b9
Binary files /dev/null and b/assets/24.png differ
diff --git a/assets/25.png b/assets/25.png
new file mode 100755
index 0000000..b01914d
Binary files /dev/null and b/assets/25.png differ
diff --git a/assets/26.png b/assets/26.png
new file mode 100755
index 0000000..35c8555
Binary files /dev/null and b/assets/26.png differ
diff --git a/assets/27.png b/assets/27.png
new file mode 100755
index 0000000..2a17c9f
Binary files /dev/null and b/assets/27.png differ
diff --git a/assets/28.png b/assets/28.png
new file mode 100755
index 0000000..fd90627
Binary files /dev/null and b/assets/28.png differ
diff --git a/assets/29.png b/assets/29.png
new file mode 100755
index 0000000..a9da274
Binary files /dev/null and b/assets/29.png differ
diff --git a/assets/3.png b/assets/3.png
new file mode 100755
index 0000000..2d380c2
Binary files /dev/null and b/assets/3.png differ
diff --git a/assets/30.png b/assets/30.png
new file mode 100755
index 0000000..32a1180
Binary files /dev/null and b/assets/30.png differ
diff --git a/assets/31.png b/assets/31.png
new file mode 100755
index 0000000..3749380
Binary files /dev/null and b/assets/31.png differ
diff --git a/assets/32.png b/assets/32.png
new file mode 100755
index 0000000..a98530d
Binary files /dev/null and b/assets/32.png differ
diff --git a/assets/33.png b/assets/33.png
new file mode 100755
index 0000000..5a230ed
Binary files /dev/null and b/assets/33.png differ
diff --git a/assets/34.png b/assets/34.png
new file mode 100755
index 0000000..a139b69
Binary files /dev/null and b/assets/34.png differ
diff --git a/assets/35.png b/assets/35.png
new file mode 100755
index 0000000..9164448
Binary files /dev/null and b/assets/35.png differ
diff --git a/assets/4.png b/assets/4.png
new file mode 100755
index 0000000..e74404c
Binary files /dev/null and b/assets/4.png differ
diff --git a/assets/5.png b/assets/5.png
new file mode 100755
index 0000000..72a6808
Binary files /dev/null and b/assets/5.png differ
diff --git a/assets/6.png b/assets/6.png
new file mode 100755
index 0000000..3eb2e30
Binary files /dev/null and b/assets/6.png differ
diff --git a/assets/7.png b/assets/7.png
new file mode 100755
index 0000000..fd6d887
Binary files /dev/null and b/assets/7.png differ
diff --git a/assets/8.png b/assets/8.png
new file mode 100755
index 0000000..c3178dd
Binary files /dev/null and b/assets/8.png differ
diff --git a/assets/9.png b/assets/9.png
new file mode 100755
index 0000000..1d88115
Binary files /dev/null and b/assets/9.png differ
diff --git a/assets/pred_佯.png b/assets/pred_佯.png
new file mode 100755
index 0000000..a5e39ae
Binary files /dev/null and b/assets/pred_佯.png differ
diff --git a/assets/pred_俺.png b/assets/pred_俺.png
new file mode 100755
index 0000000..eab7624
Binary files /dev/null and b/assets/pred_俺.png differ
diff --git a/assets/pred_傍.png b/assets/pred_傍.png
new file mode 100755
index 0000000..6390ff6
Binary files /dev/null and b/assets/pred_傍.png differ
diff --git a/assets/pred_傲.png b/assets/pred_傲.png
new file mode 100755
index 0000000..b59eb2d
Binary files /dev/null and b/assets/pred_傲.png differ
diff --git a/assets/pred_八.png b/assets/pred_八.png
new file mode 100755
index 0000000..bfb50b5
Binary files /dev/null and b/assets/pred_八.png differ
diff --git a/assets/pred_军.png b/assets/pred_军.png
new file mode 100755
index 0000000..893d2ac
Binary files /dev/null and b/assets/pred_军.png differ
diff --git a/assets/pred_凹.png b/assets/pred_凹.png
new file mode 100755
index 0000000..d20521a
Binary files /dev/null and b/assets/pred_凹.png differ
diff --git a/assets/pred_吧.png b/assets/pred_吧.png
new file mode 100755
index 0000000..e2c630c
Binary files /dev/null and b/assets/pred_吧.png differ
diff --git a/assets/pred_呐.png b/assets/pred_呐.png
new file mode 100755
index 0000000..23a1706
Binary files /dev/null and b/assets/pred_呐.png differ
diff --git a/assets/pred_奥.png b/assets/pred_奥.png
new file mode 100755
index 0000000..60a4d73
Binary files /dev/null and b/assets/pred_奥.png differ
diff --git a/assets/pred_安.png b/assets/pred_安.png
new file mode 100755
index 0000000..4f47399
Binary files /dev/null and b/assets/pred_安.png differ
diff --git a/assets/pred_宋.png b/assets/pred_宋.png
new file mode 100755
index 0000000..e3949cd
Binary files /dev/null and b/assets/pred_宋.png differ
diff --git a/assets/pred_宽.png b/assets/pred_宽.png
new file mode 100755
index 0000000..0e4530b
Binary files /dev/null and b/assets/pred_宽.png differ
diff --git a/assets/pred_巴.png b/assets/pred_巴.png
new file mode 100755
index 0000000..71139f9
Binary files /dev/null and b/assets/pred_巴.png differ
diff --git a/assets/pred_年.png b/assets/pred_年.png
new file mode 100755
index 0000000..e4d73e6
Binary files /dev/null and b/assets/pred_年.png differ
diff --git a/assets/pred_弟.png b/assets/pred_弟.png
new file mode 100755
index 0000000..2874c32
Binary files /dev/null and b/assets/pred_弟.png differ
diff --git a/assets/pred_捞.png b/assets/pred_捞.png
new file mode 100755
index 0000000..f656b0a
Binary files /dev/null and b/assets/pred_捞.png differ
diff --git a/assets/pred_换.png b/assets/pred_换.png
new file mode 100755
index 0000000..a03062e
Binary files /dev/null and b/assets/pred_换.png differ
diff --git a/assets/pred_昂.png b/assets/pred_昂.png
new file mode 100755
index 0000000..861d5aa
Binary files /dev/null and b/assets/pred_昂.png differ
diff --git a/assets/pred_晒.png b/assets/pred_晒.png
new file mode 100755
index 0000000..25551c8
Binary files /dev/null and b/assets/pred_晒.png differ
diff --git a/assets/pred_杯.png b/assets/pred_杯.png
new file mode 100755
index 0000000..bec7330
Binary files /dev/null and b/assets/pred_杯.png differ
diff --git a/assets/pred_梆.png b/assets/pred_梆.png
new file mode 100755
index 0000000..65b09f4
Binary files /dev/null and b/assets/pred_梆.png differ
diff --git a/assets/pred_氨.png b/assets/pred_氨.png
new file mode 100755
index 0000000..af2a5d3
Binary files /dev/null and b/assets/pred_氨.png differ
diff --git a/assets/pred_男.png b/assets/pred_男.png
new file mode 100755
index 0000000..29c97e1
Binary files /dev/null and b/assets/pred_男.png differ
diff --git a/assets/pred_百.png b/assets/pred_百.png
new file mode 100755
index 0000000..8861d4e
Binary files /dev/null and b/assets/pred_百.png differ
diff --git a/assets/pred_磅.png b/assets/pred_磅.png
new file mode 100755
index 0000000..a963688
Binary files /dev/null and b/assets/pred_磅.png differ
diff --git a/assets/pred_笨.png b/assets/pred_笨.png
new file mode 100755
index 0000000..0934645
Binary files /dev/null and b/assets/pred_笨.png differ
diff --git a/assets/pred_肮.png b/assets/pred_肮.png
new file mode 100755
index 0000000..db8e777
Binary files /dev/null and b/assets/pred_肮.png differ
diff --git a/assets/pred_蔼.png b/assets/pred_蔼.png
new file mode 100755
index 0000000..172b8f1
Binary files /dev/null and b/assets/pred_蔼.png differ
diff --git a/assets/pred_败.png b/assets/pred_败.png
new file mode 100755
index 0000000..b90958c
Binary files /dev/null and b/assets/pred_败.png differ
diff --git a/assets/pred_跋.png b/assets/pred_跋.png
new file mode 100755
index 0000000..293629b
Binary files /dev/null and b/assets/pred_跋.png differ
diff --git a/assets/pred_邦.png b/assets/pred_邦.png
new file mode 100755
index 0000000..a8b8798
Binary files /dev/null and b/assets/pred_邦.png differ
diff --git a/assets/pred_霸.png b/assets/pred_霸.png
new file mode 100755
index 0000000..2f87ca8
Binary files /dev/null and b/assets/pred_霸.png differ
diff --git a/assets/pred_颁.png b/assets/pred_颁.png
new file mode 100755
index 0000000..036e80b
Binary files /dev/null and b/assets/pred_颁.png differ
diff --git a/assets/pred_饶.png b/assets/pred_饶.png
new file mode 100755
index 0000000..a514a62
Binary files /dev/null and b/assets/pred_饶.png differ
diff --git a/dataset/._.DS_Store b/dataset/._.DS_Store
deleted file mode 100755
index 9ad849c..0000000
Binary files a/dataset/._.DS_Store and /dev/null differ
diff --git a/dataset/.gitignore b/dataset/.gitignore
index e24f3a0..744fc31 100755
--- a/dataset/.gitignore
+++ b/dataset/.gitignore
@@ -1,3 +1,5 @@
hwdb_raw/
*.tfrecord
casia_hwdb.pyhwdb_11.tfrecord
+HWDB1.1tst_gnt.tfrecord
+HWDB1.1trn_gnt.tfrecord
\ No newline at end of file
diff --git a/dataset/__pycache__/casia_hwdb.cpython-36.pyc b/dataset/__pycache__/casia_hwdb.cpython-36.pyc
index 6a8613f..7e17f8d 100755
Binary files a/dataset/__pycache__/casia_hwdb.cpython-36.pyc and b/dataset/__pycache__/casia_hwdb.cpython-36.pyc differ
diff --git a/dataset/casia_hwdb.py b/dataset/casia_hwdb.py
index 39bbb99..253c764 100755
--- a/dataset/casia_hwdb.py
+++ b/dataset/casia_hwdb.py
@@ -19,34 +19,6 @@ import os
this_dir = os.path.dirname(os.path.abspath(__file__))
-class CASIAHWDBGNT(object):
- """
- A .gnt file may contains many images and charactors
- """
-
- def __init__(self, f_p):
- self.f_p = f_p
-
- def get_data_iter(self):
- header_size = 10
- with open(self.f_p, 'rb') as f:
- while True:
- header = np.fromfile(f, dtype='uint8', count=header_size)
- if not header.size:
- break
- sample_size = header[0] + (header[1] << 8) + (
- header[2] << 16) + (header[3] << 24)
- tagcode = header[5] + (header[4] << 8)
- width = header[6] + (header[7] << 8)
- height = header[8] + (header[9] << 8)
- if header_size + width * height != sample_size:
- break
- image = np.fromfile(f, dtype='uint8',
- count=width * height).reshape(
- (height, width))
- yield image, tagcode
-
-
def parse_example(record):
features = tf.io.parse_single_example(record,
features={
@@ -57,34 +29,101 @@ def parse_example(record):
})
img = tf.io.decode_raw(features['image'], out_type=tf.uint8)
img = tf.cast(tf.reshape(img, (64, 64)), dtype=tf.float32)
- label = tf.cast(features['label'], tf.int32)
+ label = tf.cast(features['label'], tf.int64)
+ return {'image': img, 'label': label}
+
+
+def parse_example_v2(record):
+ """
+ latest version format
+ :param record:
+ :return:
+ """
+ features = tf.io.parse_single_example(record,
+ features={
+ 'width':
+ tf.io.FixedLenFeature([], tf.int64),
+ 'height':
+ tf.io.FixedLenFeature([], tf.int64),
+ 'label':
+ tf.io.FixedLenFeature([], tf.int64),
+ 'image':
+ tf.io.FixedLenFeature([], tf.string),
+ })
+ img = tf.io.decode_raw(features['image'], out_type=tf.uint8)
+ # we can not reshape since it stores with original size
+ w = features['width']
+ h = features['height']
+ img = tf.cast(tf.reshape(img, (w, h)), dtype=tf.float32)
+ label = tf.cast(features['label'], tf.int64)
return {'image': img, 'label': label}
def load_ds():
- input_files = ['dataset/hwdb_11.tfrecord']
+ input_files = ['dataset/HWDB1.1trn_gnt.tfrecord']
ds = tf.data.TFRecordDataset(input_files)
ds = ds.map(parse_example)
return ds
-def load_characters():
+def load_val_ds():
+ input_files = ['dataset/HWDB1.1tst_gnt.tfrecord']
+ ds = tf.data.TFRecordDataset(input_files)
+ ds = ds.map(parse_example_v2)
+ return ds
- a = open(os.path.join(this_dir, 'charactors.txt'), 'r').readlines()
+
+def load_characters():
+ a = open(os.path.join(this_dir, 'characters.txt'), 'r').readlines()
return [i.strip() for i in a]
if __name__ == "__main__":
ds = load_ds()
+ val_ds = load_val_ds()
+ val_ds = val_ds.shuffle(100)
charactors = load_characters()
- for img, label in ds.take(9):
- # start training on model...
- img = img.numpy()
- img = np.resize(img, (64, 64))
- print(img.shape)
- label = label.numpy()
- label = charactors[label]
- print(label)
- cv2.imshow('rr', img)
+
+ is_show_combine = False
+ if is_show_combine:
+ combined = np.zeros([32*10, 32*20], dtype=np.uint8)
+ i = 0
+ res = ''
+ for data in val_ds.take(200):
+ # start training on model...
+ img, label = data['image'], data['label']
+ img = img.numpy()
+ img = np.array(img, dtype=np.uint8)
+ img = cv2.resize(img, (32, 32))
+ label = label.numpy()
+ label = charactors[label]
+ print(label)
+ row = i // 20
+ col = i % 20
+ print(i, col)
+ print(row, col)
+ combined[row*32: (row+1)*32, col*32: (col+1)*32] = img
+ i += 1
+ res += label
+ cv2.imshow('rr', combined)
+ print(res)
+ cv2.imwrite('assets/combined.png', combined)
cv2.waitKey(0)
- # break
+ # break
+ else:
+ i = 0
+ for data in val_ds.take(36):
+ # start training on model...
+ img, label = data['image'], data['label']
+ img = img.numpy()
+ img = np.array(img, dtype=np.uint8)
+ print(img.shape)
+ # img = cv2.resize(img, (64, 64))
+ label = label.numpy()
+ label = charactors[label]
+ print(label)
+ cv2.imshow('rr', img)
+ cv2.imwrite('assets/{}.png'.format(i), img)
+ i += 1
+ cv2.waitKey(0)
+ # break
diff --git a/dataset/charactors.txt b/dataset/characters.txt
similarity index 100%
rename from dataset/charactors.txt
rename to dataset/characters.txt
diff --git a/dataset/convert_to_tfrecord.py b/dataset/convert_to_tfrecord.py
index 9bf9e44..9a195b4 100755
--- a/dataset/convert_to_tfrecord.py
+++ b/dataset/convert_to_tfrecord.py
@@ -1,6 +1,7 @@
"""
generates HWDB data into tfrecord
"""
+import sys
import struct
import numpy as np
import cv2
@@ -23,69 +24,83 @@ class CASIAHWDBGNT(object):
with open(self.f_p, 'rb') as f:
while True:
header = np.fromfile(f, dtype='uint8', count=header_size)
- if not header.size:
+ if not header.size:
break
- sample_size = header[0] + (header[1]<<8) + (header[2]<<16) + (header[3]<<24)
- tagcode = header[5] + (header[4]<<8)
- width = header[6] + (header[7]<<8)
- height = header[8] + (header[9]<<8)
- if header_size + width*height != sample_size:
+ sample_size = header[0] + (header[1] << 8) + (header[2] << 16) + (header[3] << 24)
+ tagcode = header[5] + (header[4] << 8)
+ width = header[6] + (header[7] << 8)
+ height = header[8] + (header[9] << 8)
+ if header_size + width * height != sample_size:
break
- image = np.fromfile(f, dtype='uint8', count=width*height).reshape((height, width))
+ image = np.fromfile(f, dtype='uint8', count=width * height).reshape((height, width))
yield image, tagcode
-def run():
- all_hwdb_gnt_files = glob.glob('./hwdb_raw/HWDB1.1trn_gnt/*.gnt')
+def run(p):
+ all_hwdb_gnt_files = glob.glob(os.path.join(p, '*.gnt'))
logging.info('got all {} gnt files.'.format(len(all_hwdb_gnt_files)))
logging.info('gathering charset...')
charset = []
- if os.path.exists('charactors.txt'):
- logging.info('found exist charactors.txt...')
- with open('charactors.txt', 'r') as f:
+ if os.path.exists('characters.txt'):
+ logging.info('found exist characters.txt...')
+ with open('characters.txt', 'r') as f:
charset = f.readlines()
charset = [i.strip() for i in charset]
else:
- for gnt in all_hwdb_gnt_files:
- hwdb = CASIAHWDBGNT(gnt)
- for img, tagcode in hwdb.get_data_iter():
- try:
- label = struct.pack('>H', tagcode).decode('gb2312')
- label = label.replace('\x00', '')
- charset.append(label)
- except Exception as e:
- continue
- charset = sorted(set(charset))
- with open('charactors.txt', 'w') as f:
- f.writelines('\n'.join(charset))
- logging.info('all got {} charactors.'.format(len(charset)))
+ if 'trn' in p:
+ for gnt in all_hwdb_gnt_files:
+ hwdb = CASIAHWDBGNT(gnt)
+ for img, tagcode in hwdb.get_data_iter():
+ try:
+ label = struct.pack('>H', tagcode).decode('gb2312')
+ label = label.replace('\x00', '')
+ charset.append(label)
+ except Exception as e:
+ continue
+ charset = sorted(set(charset))
+ with open('characters.txt', 'w') as f:
+ f.writelines('\n'.join(charset))
+ logging.info('all got {} characters.'.format(len(charset)))
logging.info('{}'.format(charset[:10]))
-
- tfrecord_f = 'casia_hwdb_1.0_1.1.tfrecord'
+
+ tfrecord_f = os.path.basename(os.path.dirname(p)) + '.tfrecord'
+ logging.info('tfrecord file saved into: {}'.format(tfrecord_f))
i = 0
with tf.io.TFRecordWriter(tfrecord_f) as tfrecord_writer:
for gnt in all_hwdb_gnt_files:
hwdb = CASIAHWDBGNT(gnt)
for img, tagcode in hwdb.get_data_iter():
try:
- img = cv2.resize(img, (64, 64))
- label = struct.pack('>H', tagcode).decode('gb2312')
+ # why do you need resize?
+ w = img.shape[0]
+ h = img.shape[1]
+ # img = cv2.resize(img, (64, 64))
+ label = struct.pack('>H', tagcode).decode('gb2312')
label = label.replace('\x00', '')
index = charset.index(label)
# save img, label as example
example = tf.train.Example(features=tf.train.Features(
feature={
- "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
- 'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()]))
- }))
+ "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
+ 'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()])),
+ 'width': tf.train.Feature(int64_list=tf.train.Int64List(value=[w])),
+ 'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[h])),
+ }))
tfrecord_writer.write(example.SerializeToString())
- if i%500:
+ if i % 5000:
logging.info('solved {} examples. {}: {}'.format(i, label, index))
i += 1
except Exception as e:
logging.error(e)
+ e.with_traceback()
continue
logging.info('done.')
+
if __name__ == "__main__":
- run()
\ No newline at end of file
+ if len(sys.argv) <= 1:
+ logging.error('send a pattern like this: {}'.format('./hwdb_raw/HWDB1.1trn_gnt/'))
+ else:
+ p = sys.argv[1]
+ logging.info('converting from: {}'.format(p))
+ run(p)
diff --git a/dataset/get_hwdb_1.0_1.1.sh b/dataset/get_hwdb_1.0_1.1.sh
index 2b4e477..a0c1289 100755
--- a/dataset/get_hwdb_1.0_1.1.sh
+++ b/dataset/get_hwdb_1.0_1.1.sh
@@ -1,2 +1,2 @@
wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1trn_gnt.zip
-wget wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1tst_gnt.zip
+wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1tst_gnt.zip
diff --git a/demo.py b/demo.py
new file mode 100755
index 0000000..f0a6ce2
--- /dev/null
+++ b/demo.py
@@ -0,0 +1,77 @@
+"""
+
+inference on a single Chinese character
+image and recognition the meaning of it
+
+"""
+from alfred.dl.tf.common import mute_tf
+mute_tf()
+import os
+import cv2
+import sys
+import numpy as np
+import tensorflow as tf
+
+from alfred.utils.log import logger as logging
+import tensorflow_datasets as tfds
+from dataset.casia_hwdb import load_ds, load_characters, load_val_ds
+from models.cnn_net import CNNNet, build_net_002, build_net_003
+import glob
+
+
+target_size = 64
+characters = load_characters()
+num_classes = len(characters)
+# use_keras_fit = False
+use_keras_fit = True
+ckpt_path = './checkpoints/cn_ocr-{epoch}.ckpt'
+
+
+def preprocess(x):
+ """
+ minus mean pixel or normalize?
+ """
+ # original is 64x64, add a channel dim
+ x['image'] = tf.expand_dims(x['image'], axis=-1)
+ x['image'] = tf.image.resize(x['image'], (target_size, target_size))
+ x['image'] = (x['image'] - 128.) / 128.
+ return x['image'], x['label']
+
+
+def get_model():
+ # init model
+ model = build_net_003((64, 64, 1), num_classes)
+ logging.info('model loaded.')
+
+ latest_ckpt = tf.train.latest_checkpoint(os.path.dirname(ckpt_path))
+ if latest_ckpt:
+ start_epoch = int(latest_ckpt.split('-')[1].split('.')[0])
+ model.load_weights(latest_ckpt)
+ logging.info('model resumed from: {} at epoch: {}'.format(latest_ckpt, start_epoch))
+ return model
+ else:
+ logging.error('can not found any checkpoints matched: {}'.format(ckpt_path))
+
+
+def predict(model, img_f):
+ ori_img = cv2.imread(img_f)
+ img = tf.expand_dims(ori_img[:, :, 0], axis=-1)
+ img = tf.image.resize(img, (target_size, target_size))
+ img = (img - 128.)/128.
+ img = tf.expand_dims(img, axis=0)
+ print(img.shape)
+ out = model(img).numpy()
+ print('predict: {}'.format(characters[np.argmax(out[0])]))
+ cv2.imwrite('assets/pred_{}.png'.format(characters[np.argmax(out[0])]), ori_img)
+
+
+if __name__ == '__main__':
+ img_files = glob.glob('assets/*.png')
+ model = get_model()
+ for img_f in img_files:
+ a = cv2.imread(img_f)
+ cv2.imshow('rr', a)
+ predict(model, img_f)
+ cv2.waitKey(0)
+
+
diff --git a/models/__pycache__/cnn_net.cpython-36.pyc b/models/__pycache__/cnn_net.cpython-36.pyc
index 9402b81..533589d 100755
Binary files a/models/__pycache__/cnn_net.cpython-36.pyc and b/models/__pycache__/cnn_net.cpython-36.pyc differ
diff --git a/models/cnn_net.py b/models/cnn_net.py
index f70460a..0828b63 100755
--- a/models/cnn_net.py
+++ b/models/cnn_net.py
@@ -1,24 +1,5 @@
-'''
-
-conv_1 = slim.conv2d(images, 64, [3, 3], 1, padding='SAME', scope='conv1')
-# (inputs,num_outputs,[卷积核个数] kernel_size,[卷积核的高度,卷积核的宽]stride=1,padding='SAME',)
-max_pool_1 = slim.max_pool2d(conv_1, [2, 2], [2, 2], padding='SAME')
-conv_2 = slim.conv2d(max_pool_1, 128, [3, 3], padding='SAME', scope='conv2')
-max_pool_2 = slim.max_pool2d(conv_2, [2, 2], [2, 2], padding='SAME')
-conv_3 = slim.conv2d(max_pool_2, 256, [3, 3], padding='SAME', scope='conv3')
-max_pool_3 = slim.max_pool2d(conv_3, [2, 2], [2, 2], padding='SAME')
-
-flatten = slim.flatten(max_pool_3)
-fc1 = slim.fully_connected(tf.nn.dropout(flatten, keep_prob), 1024, activation_fn=tf.nn.tanh, scope='fc1')
-logits = slim.fully_connected(tf.nn.dropout(fc1, keep_prob), FLAGS.charset_size, activation_fn=None, scope='fc2')
-# logits = slim.fully_connected(flatten, FLAGS.charset_size, activation_fn=None, reuse=reuse, scope='fc')
-loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
-# y表示的是实际类别,y_表示预测结果,这实际上面是把原来的神经网络输出层的softmax和cross_entrop何在一起计算,为了追求速度
-accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels), tf.float32))
-'''
-
import tensorflow as tf
from tensorflow.keras import layers
@@ -54,6 +35,23 @@ def build_net_002(input_shape, n_classes):
return model
+# this model is converge in terms of chinese characters classification
+# so simply is effective sometimes, adding a dense maybe model will be better?
+def build_net_003(input_shape, n_classes):
+ model = tf.keras.Sequential([
+ layers.Conv2D(input_shape=input_shape, filters=32, kernel_size=(3, 3), strides=(1, 1),
+ padding='same', activation='relu'),
+ layers.MaxPool2D(pool_size=(2, 2), padding='same'),
+ layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same'),
+ layers.MaxPool2D(pool_size=(2, 2), padding='same'),
+
+ layers.Flatten(),
+ # layers.Dense(1024, activation='relu'),
+ layers.Dense(n_classes, activation='softmax')
+ ])
+ return model
+
+
# some models wrapped into tf.keras.Model
class CNNNet(tf.keras.Model):
diff --git a/readme.md b/readme.md
index 2f7414f..b2c911e 100755
--- a/readme.md
+++ b/readme.md
@@ -1,9 +1,29 @@
-# TensorFlow 2.0 中文手写字识别
+# TensorFlow 2.0 中文手写字识别(汉字OCR)
+
+> 在开始之前,必须要说明的是,本教程完全基于TensorFlow2.0 接口编写,请勿与其他古老的教程混为一谈,本教程除了手把手教大家完成这个挑战性任务之外,更多的会教大家如何分析整个调参过程的思考过程,力求把人工智能算法工程师日常的工作通过这个例子毫无保留的展示给大家。另外,我们建立了一个高端算法分享平台,希望得到大家的支持:http://manaai.cn , 也欢迎大家来我们的AI社区交流: http://talk.strangeai.pro
+
+
+
+还在玩minist?fashionmnist?不如来尝试一下类别多达3000+的汉字手写识别吧!!虽然以前有一些文章教大家如何操作,但是大多比较古老,这篇文章将用全新的TensorFlow 2.0 来教大家如何搭建一个中文OCR系统!
+
+让我们来看一下,相比于简单minist识别,汉字识别具有哪些难点:
+
+- 搜索空间空前巨大,我们使用的数据集1.0版本汉字就多达3755个,如果加上1.1版本一起,总共汉字可以分为多达7599+个类别!这比10个阿拉伯字母识别难度大很多!
+- 数据集处理挑战更大,相比于mnist和fasionmnist来说,汉字手写字体识别数据集非常少,而且仅有的数据集数据预处理难度非常大,非常不直观,但是,千万别吓到,相信你看完本教程一定会收获满满!
+- 汉字识别更考验选手的建模能力,还在分类花?分类猫和狗?随便搭建的几层在搜索空间巨大的汉字手写识别里根本不work!你现在是不是想用很深的网络跃跃欲试?更深的网络在这个任务上可能根本不可行!!看完本教程我们就可以一探究竟!总之一句话,模型太简单和太复杂都不好,甚至会发散!(想亲身体验模型训练发散抓狂的可以来尝试一下!)。
+
+但是,挑战这个任务也有很多好处:
+
+- 本教程基于TensorFlow2.0,从数据预处理,图片转Tensor以及Tensor的一系列骚操作都包含在内!做完本任务相信你会对TensorFlow2.0 API有一个很深刻的认识!
+- 如果你是新手,通过这个教程你完全可以深入体会一下调参(或者说随意修改网络)的纠结性和蛋疼性!
+
+
本项目实现了基于CNN的中文手写字识别,并且采用标准的**tensorflow 2.0 api** 来构建!相比对简单的字母手写识别,本项目更能体现模型设计的精巧性和数据增强的熟练操作性,并且最终设计出来的模型可以直接应用于工业场合,比如 **票据识别**, **手写文本自动扫描** 等,相比于百度api接口或者QQ接口等,具有可优化性、免费性、本地性等优点。
-## Data
+
+## 数据准备
在开始之前,先介绍一下本项目所采用的数据信息。我们的数据全部来自于CASIA的开源中文手写字数据集,该数据集分为两部分:
@@ -11,17 +31,23 @@
- CASIA-OLHWDB:在线的HWDB,格式一样,包含了约7185个汉字以及171个英文字母、数字、标点符号等,我们不用。
其实你下载1.0的train和test差不多已经够了,可以直接运行 `dataset/get_hwdb_1.0_1.1.sh` 下载。原始数据下载链接点击[这里](http://www.nlpr.ia.ac.cn/databases/handwriting/Offline_database.html).
-由于原始数据过于复杂,我们自己写了一个数据wrapper方便读取,统一将其转换为类似于Dataframe (Pandas)的格式,这样可以将一个字的特征和label方便的显示,也可以十分方便的将手写字转换为图片,采用CNN进行处理。这是我们展示的效果:
+由于原始数据过于复杂,我们使用一个类来封装数据读取过程,这是我们展示的效果:
-
+
+
+
+
+
+
+看到这么密密麻麻的文字相信连人类都.... 开始头疼了,这些复杂的文字能够通过一个神经网络来识别出来??答案是肯定的.... 不由得感叹一下神经网络的强大。。上面的部分文字识别出来的结果是这样的:
+
+
+
+
-其对应的label为:
-```
-['!', '"', '#', '$', '%', '&', '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', '、', '。', '々', '…', '‘', '’', '“', '”']
-```
关于数据的处理部分,从服务器下载到的原始数据是 `trn_gnt.zip` 解压之后是 `gnt.alz`, 需要再次解压得到一个包含 gnt文件的文件夹。里面每一个gnt文件都包含了若干个汉字及其标注。直接处理比较麻烦,也不方便抽取出图片再进行操作,**虽然转为图片存入文件夹比较直观,但是不适合批量读取和训练**, 后面我们统一转为tfrecord进行训练。
@@ -33,8 +59,282 @@
- `cd dataset && python3 convert_to_tfrecord.py`, 请注意我们使用的是tf2.0;
- 你需要修改对应的路径,等待生成完成,大概有89万个example,如果1.0和1.1都用,那估计得double。
-
-## Model
-关于我们采用的OCR模型的构建,我们大致采用的是比较先进的MobileNetV3架构,同时设计了一个修改的过的MobileNetV3Big的更深网络。主要考虑模型的轻量型和表达能力。最终训练结果表明,我们的模型可以在中文手写字上达到约99.8%的准确率。
\ No newline at end of file
+
+## 模型构建
+
+关于我们采用的OCR模型的构建,我们构建了3个模型分别做测试,三个模型的复杂度逐渐的复杂,网络层数逐渐深入。但是到最后发现,最复杂的那个模型竟然不收敛。这个其中一个稍微简单模型的训练过程:
+
+
+
+
+
+大家可以看到,准确率可以在短时间内达到87%非常不错,测试集的准确率大概在40%,由于测试集中的样本在训练集中完全没有出现,相对训练集的准确率来讲偏低。可能原因无外乎两个,一个是模型泛化性能不强,另外一个原因是训练还不够。
+
+不过好在这个简单的模型也能达到训练集90%的准确率,it's a good start. 让我们来看一下如何快速的构建一个OCR网络模型:
+
+
+
+```python
+def build_net_003(input_shape, n_classes):
+    """Build a minimal CNN classifier: two conv/pool stages, then a
+    single softmax dense layer over the flattened feature map.
+
+    :param input_shape: input tensor shape, e.g. (64, 64, 1)
+    :param n_classes: number of output character classes
+    :return: an uncompiled tf.keras.Sequential model
+    """
+    model = tf.keras.Sequential([
+        layers.Conv2D(input_shape=input_shape, filters=32, kernel_size=(3, 3), strides=(1, 1),
+                      padding='same', activation='relu'),
+        layers.MaxPool2D(pool_size=(2, 2), padding='same'),
+        # NOTE(review): no activation given here, so Keras applies the
+        # default linear activation — confirm whether ReLU was intended.
+        layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same'),
+        layers.MaxPool2D(pool_size=(2, 2), padding='same'),
+
+        layers.Flatten(),
+        # softmax output directly over n_classes (no hidden dense layer)
+        layers.Dense(n_classes, activation='softmax')
+    ])
+    return model
+```
+
+这是我们使用keras API构建的一个模型,它足够简单,仅仅包含两个卷积层以及两个maxpool层。下面我们让大家知道,即便是再简单的模型,有时候也能发挥出巨大的用处,对于某些特定的问题可能比更深的网络更有用途。关于这部分模型构建大家只要知道这么几点:
+
+- 如果你只是构建序列模型,没有太fancy的跳跃链接,你可以直接用`keras.Sequential` 来构建你的模型;
+- Conv2D中最好指定每个参数的名字,不要省略,否则别人不知道你写的是输入的通道数还是filters。
+
+
+
+最后,在你看完本篇博客后,并准备自己动手复现这个教程的时候, 可以思考一下为什么下面这个模型就发散了呢?(仅仅稍微复杂一点):
+
+
+
+```python
+
+def build_net_002(input_shape, n_classes):
+    """Build a deeper CNN classifier: three conv/pool stages plus a
+    1024-unit hidden dense layer before the softmax output.
+
+    The surrounding text reports this variant diverged during training.
+
+    :param input_shape: input tensor shape, e.g. (64, 64, 1)
+    :param n_classes: number of output character classes
+    :return: an uncompiled tf.keras.Sequential model
+    """
+    model = tf.keras.Sequential([
+        layers.Conv2D(input_shape=input_shape, filters=64, kernel_size=(3, 3), strides=(1, 1),
+                      padding='same', activation='relu'),
+        layers.MaxPool2D(pool_size=(2, 2), padding='same'),
+        # NOTE(review): the two convs below have no activation argument,
+        # so they use the Keras default linear activation; stacking linear
+        # convs may be related to the reported divergence — confirm.
+        layers.Conv2D(filters=128, kernel_size=(3, 3), padding='same'),
+        layers.MaxPool2D(pool_size=(2, 2), padding='same'),
+        layers.Conv2D(filters=256, kernel_size=(3, 3), padding='same'),
+        layers.MaxPool2D(pool_size=(2, 2), padding='same'),
+
+        layers.Flatten(),
+        layers.Dense(1024, activation='relu'),
+        layers.Dense(n_classes, activation='softmax')
+    ])
+    return model
+```
+
+
+
+## 数据输入
+
+其实最复杂的还是数据准备过程啊。这里着重说一下,我们的数据存入tfrecords中的是image和label,也就是这么一个example:
+
+
+
+```
+ example = tf.train.Example(features=tf.train.Features(
+ feature={
+ "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
+ 'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()])),
+ 'width': tf.train.Feature(int64_list=tf.train.Int64List(value=[w])),
+ 'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[h])),
+ }))
+```
+
+然后读取的时候相应的读取即可,这里告诉大家几点坑爹的地方:
+
+- 将numpy array的bytes存入tfrecord跟将文件的bytes直接存入tfrecord解码的方式是不同的,由于我们的图片数据不是来自于本地文件,所以我们使用了一个tobytes()方法存入的是numpy array的bytes格式,它实际上并不包含维度信息,所以这就是坑爹的地方之一,如果你不同时存储width和height,你后面读取的时候便无法知道维度,**存储tfrecord顺便存储图片长宽是一个好的习惯**.
+- 关于不同的存储方式解码的方法有坑爹的地方,比如这里我们存储numpy array的bytes,通常情况下,你很难知道如何解码。。(不看本教程应该很多人不知道)
+
+最后load tfrecord也就比较直观了:
+
+
+
+```python
+def parse_example(record):
+    """Parse a serialized tf.train.Example holding a raw uint8 image and
+    an int64 label (v1 format: no width/height stored, image assumed 64x64).
+
+    :param record: a serialized Example proto from the tfrecord file
+    :return: dict with 'image' (64x64 float32 tensor) and 'label' (int64)
+    """
+    features = tf.io.parse_single_example(record,
+                                          features={
+                                              'label':
+                                              tf.io.FixedLenFeature([], tf.int64),
+                                              'image':
+                                              tf.io.FixedLenFeature([], tf.string),
+                                          })
+    # image was written with ndarray.tobytes(), so decode the raw uint8
+    # buffer and restore the fixed 64x64 shape by hand
+    img = tf.io.decode_raw(features['image'], out_type=tf.uint8)
+    img = tf.cast(tf.reshape(img, (64, 64)), dtype=tf.float32)
+    label = tf.cast(features['label'], tf.int64)
+    return {'image': img, 'label': label}
+
+
+def parse_example_v2(record):
+    """Parse a serialized tf.train.Example in the v2 format, which also
+    stores the per-image width and height so variable-size images can be
+    reshaped after decoding.
+
+    :param record: a serialized Example proto from the tfrecord file
+    :return: dict with 'image' (float32 tensor) and 'label' (int64)
+    """
+    features = tf.io.parse_single_example(record,
+                                          features={
+                                              'width':
+                                              tf.io.FixedLenFeature([], tf.int64),
+                                              'height':
+                                              tf.io.FixedLenFeature([], tf.int64),
+                                              'label':
+                                              tf.io.FixedLenFeature([], tf.int64),
+                                              'image':
+                                              tf.io.FixedLenFeature([], tf.string),
+                                          })
+    img = tf.io.decode_raw(features['image'], out_type=tf.uint8)
+    # images are stored at their original size, so the shape must be
+    # recovered from the stored width/height features
+    w = features['width']
+    h = features['height']
+    # NOTE(review): reshape uses (w, h); a row-major numpy image buffer
+    # would normally need (h, w) — confirm the dimension order the
+    # tfrecord writer used when storing width/height.
+    img = tf.cast(tf.reshape(img, (w, h)), dtype=tf.float32)
+    label = tf.cast(features['label'], tf.int64)
+    return {'image': img, 'label': label}
+
+
+def load_ds():
+    """Build the training dataset from the HWDB1.1 tfrecord file,
+    mapping each record through parse_example (v1 fixed-size format).
+
+    :return: a tf.data.Dataset yielding {'image', 'label'} dicts
+    """
+    input_files = ['dataset/HWDB1.1trn_gnt.tfrecord']
+    ds = tf.data.TFRecordDataset(input_files)
+    ds = ds.map(parse_example)
+    return ds
+```
+
+
+
+这个v2的版本就是兼容了新的存入长宽的方式,因为我第一次生成的时候就没有保存。。。最后入坑了。注意这行代码:
+
+```
+ img = tf.io.decode_raw(features['image'], out_type=tf.uint8)
+```
+
+它是对raw bytes进行解码,这个解码跟从文件读取bytes存入tfrecord的有着本质的不同。**同时注意type的变化,这里以uint8的方式解码,因为我们存储进去的就是uint8**.
+
+
+
+## 训练过程
+
+不瞒你说,我一开始写了一个很复杂的模型,训练了大概一个晚上结果准确率0.00012, 发散了。后面改成了更简单的模型才收敛。整个过程的训练pipeline:
+
+
+
+```python
+def train():
+    """Train the character classifier end to end: build the train/val
+    datasets, create the model, resume from the latest checkpoint if one
+    exists, then fit with periodic checkpointing. A KeyboardInterrupt
+    (ctrl+c) is caught so the weights are still saved.
+
+    NOTE(review): relies on module-level names not shown in this block
+    (ckpt_path, use_keras_fit, preprocess, load_characters, load_val_ds)
+    — confirm they are defined alongside this function.
+    """
+    all_characters = load_characters()
+    num_classes = len(all_characters)
+    logging.info('all characters: {}'.format(num_classes))
+    train_dataset = load_ds()
+    train_dataset = train_dataset.shuffle(100).map(preprocess).batch(32).repeat()
+
+    val_ds = load_val_ds()
+    val_ds = val_ds.shuffle(100).map(preprocess).batch(32).repeat()
+
+    # peek at two batches to sanity-check the input pipeline
+    for data in train_dataset.take(2):
+        print(data)
+
+    # init model
+    model = build_net_003((64, 64, 1), num_classes)
+    model.summary()
+    logging.info('model loaded.')
+
+    # resume from the newest checkpoint if present; the epoch number is
+    # recovered from the checkpoint filename pattern
+    start_epoch = 0
+    latest_ckpt = tf.train.latest_checkpoint(os.path.dirname(ckpt_path))
+    if latest_ckpt:
+        start_epoch = int(latest_ckpt.split('-')[1].split('.')[0])
+        model.load_weights(latest_ckpt)
+        logging.info('model resumed from: {}, start at epoch: {}'.format(latest_ckpt, start_epoch))
+    else:
+        logging.info('passing resume since weights not there. training from scratch')
+
+    if use_keras_fit:
+        model.compile(
+            optimizer=tf.keras.optimizers.Adam(),
+            loss=tf.keras.losses.SparseCategoricalCrossentropy(),
+            metrics=['accuracy'])
+        callbacks = [
+            # NOTE(review): `period` is a deprecated ModelCheckpoint
+            # argument in TF2 — `save_freq` is the replacement; confirm
+            # against the TF version in use.
+            tf.keras.callbacks.ModelCheckpoint(ckpt_path,
+                                               save_weights_only=True,
+                                               verbose=1,
+                                               period=500)
+        ]
+        try:
+            model.fit(
+                train_dataset,
+                validation_data=val_ds,
+                validation_steps=1000,
+                epochs=15000,
+                steps_per_epoch=1024,
+                callbacks=callbacks)
+        except KeyboardInterrupt:
+            # save on ctrl+c so an interrupted run is not lost
+            model.save_weights(ckpt_path.format(epoch=0))
+            logging.info('keras model saved.')
+        model.save_weights(ckpt_path.format(epoch=0))
+        model.save(os.path.join(os.path.dirname(ckpt_path), 'cn_ocr.h5'))
+
+
+
+在本系列教程开篇之际,我们就立下了几条准则,其中一条就是**handle everything**, 从这里就能看出,它是一个很稳健的训练代码,同时也很自动化:
+
+- 自动寻找之前保存的最新模型;
+- 自动保存模型;
+- 捕捉ctrl + c事件保存模型。
+- 支持断点续训练
+
+大家在以后编写训练代码的时候其实可以保持这个好的习惯。
+
+OK,整个模型训练起来之后,可以在短时间内达到95%的准确率:
+
+
+
+
+
+效果还是很不错的!
+
+
+
+## 模型测试
+
+
+
+最后模型训练完了,是时候测试一下模型效果到底咋样。我们使用了一些简单的文字来测试:
+
+
+
+
+
+这个字写的还真的。。。。具有鬼神之势。相信普通人类大部分字都能认出来,不过有些字还真的。。。。不好认。看看神经网络的表现怎么样!
+
+
+
+
+
+这是大概2000次训练的结果, 基本上能识别出来了!神经网络的认字能力还不错的! 收工!
+
+
+
+## 总结
+
+通过本教程,我们完成了使用tensorflow 2.0全新的API搭建一个中文汉字手写识别系统。模型基本能够实现我们想要的功能。要知道,这个模型可是在搜索空间多达3755的类别当中准确的找到最相似的类别!!通过本实验,我们有几点心得:
+
+- 神经网络不仅仅是在学习,它具有一定的想象力!!比如它的一些看着很像的字:拜-佯, 扮-捞,笨-苯.... 这些字如果手写出来,连人都比较难以辨认!!但是大家要知道这些字在类别上并不是相邻的!也就是说,模型具有一定的联想能力!
+- 不管问题多复杂,要敢于动手、善于动手。
+
+最后希望大家对本文点个赞,编写教程不容易。希望大家多多支持。本教程将持续为大家输出全新的tensorflow2.0教程!欢迎关注!!
+
+本文所有代码开源在:
+
+https://github.com/jinfagang/ocrcn_tf2.git
+
+
+
+记得随手star哦!!
+
+我们的AI社区:
+
+
+
+http://talk.strangeai.pro
+
+
+
+全球最大的开源AI代码平台:
+
+http://manaai.cn
+
diff --git a/sample.png b/sample.png
deleted file mode 100755
index 24184e1..0000000
Binary files a/sample.png and /dev/null differ
diff --git a/train.py b/train_simple.py
similarity index 75%
rename from train.py
rename to train_simple.py
index f589719..32cb83a 100755
--- a/train.py
+++ b/train_simple.py
@@ -1,11 +1,11 @@
-'''
-training HWDB Chinese charactors classification
-on MobileNetV2
-'''
+"""
+
+training a simple net on Chinese Characters classification dataset
+we got about 90% accuracy by simply applying a simple CNN net
+
+"""
from alfred.dl.tf.common import mute_tf
-
mute_tf()
-
import os
import sys
import numpy as np
@@ -13,8 +13,8 @@ import tensorflow as tf
from alfred.utils.log import logger as logging
import tensorflow_datasets as tfds
-from dataset.casia_hwdb import load_ds, load_characters
-from models.cnn_net import CNNNet, build_net_002
+from dataset.casia_hwdb import load_ds, load_characters, load_val_ds
+from models.cnn_net import CNNNet, build_net_002, build_net_003
@@ -29,10 +29,10 @@ def preprocess(x):
"""
minus mean pixel or normalize?
"""
+ # original is 64x64, add a channel dim
x['image'] = tf.expand_dims(x['image'], axis=-1)
x['image'] = tf.image.resize(x['image'], (target_size, target_size))
- x['image'] /= 255.
- x['image'] = 2 * x['image'] - 1
+ x['image'] = (x['image'] - 128.) / 128.
return x['image'], x['label']
@@ -41,10 +41,16 @@ def train():
num_classes = len(all_characters)
logging.info('all characters: {}'.format(num_classes))
train_dataset = load_ds()
- train_dataset = train_dataset.shuffle(100).map(preprocess).batch(4).repeat()
+ train_dataset = train_dataset.shuffle(100).map(preprocess).batch(32).repeat()
+
+ val_ds = load_val_ds()
+ val_ds = val_ds.shuffle(100).map(preprocess).batch(32).repeat()
+
+ for data in train_dataset.take(2):
+ print(data)
# init model
- model = build_net_002((64, 64, 1), num_classes)
+ model = build_net_003((64, 64, 1), num_classes)
model.summary()
logging.info('model loaded.')
@@ -62,10 +68,20 @@ def train():
optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=['accuracy'])
+ callbacks = [
+ tf.keras.callbacks.ModelCheckpoint(ckpt_path,
+ save_weights_only=True,
+ verbose=1,
+ period=500)
+ ]
try:
model.fit(
- train_dataset, epochs=50,
- steps_per_epoch=700, )
+ train_dataset,
+ validation_data=val_ds,
+ validation_steps=1000,
+ epochs=15000,
+ steps_per_epoch=1024,
+ callbacks=callbacks)
except KeyboardInterrupt:
model.save_weights(ckpt_path.format(epoch=0))
logging.info('keras model saved.')