diff --git a/.DS_Store b/.DS_Store
new file mode 100755
index 0000000..bb3993e
Binary files /dev/null and b/.DS_Store differ
diff --git a/._.DS_Store b/._.DS_Store
new file mode 100755
index 0000000..9ad849c
Binary files /dev/null and b/._.DS_Store differ
diff --git a/.gitignore b/.gitignore
new file mode 100755
index 0000000..1d74e21
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+.vscode/
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100755
index 0000000..65531ca
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100755
index 0000000..d875ec3
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/ocrcn_tf2.iml b/.idea/ocrcn_tf2.iml
new file mode 100755
index 0000000..6711606
--- /dev/null
+++ b/.idea/ocrcn_tf2.iml
@@ -0,0 +1,11 @@
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100755
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/workspace.xml b/.idea/workspace.xml
new file mode 100755
index 0000000..bdd3ca6
--- /dev/null
+++ b/.idea/workspace.xml
@@ -0,0 +1,251 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ true
+ DEFINITION_ORDER
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 1559385971852
+
+
+ 1559385971852
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/dataset/.DS_Store b/dataset/.DS_Store
new file mode 100755
index 0000000..44503b4
Binary files /dev/null and b/dataset/.DS_Store differ
diff --git a/dataset/._.DS_Store b/dataset/._.DS_Store
new file mode 100755
index 0000000..9ad849c
Binary files /dev/null and b/dataset/._.DS_Store differ
diff --git a/dataset/.gitignore b/dataset/.gitignore
old mode 100644
new mode 100755
index 60704fd..e24f3a0
--- a/dataset/.gitignore
+++ b/dataset/.gitignore
@@ -1,2 +1,3 @@
hwdb_raw/
-*.tfrecord
\ No newline at end of file
+*.tfrecord
+hwdb_11.tfrecord
diff --git a/dataset/casia_hwdb.py b/dataset/casia_hwdb.py
old mode 100644
new mode 100755
index bef2bd8..175a823
--- a/dataset/casia_hwdb.py
+++ b/dataset/casia_hwdb.py
@@ -6,9 +6,13 @@ we using this class to get .png and label from raw
.gnt data
"""
+from alfred.dl.tf.common import mute_tf
+
+mute_tf()
import struct
import numpy as np
import cv2
+import tensorflow as tf
class CASIAHWDBGNT(object):
@@ -24,61 +28,57 @@ class CASIAHWDBGNT(object):
with open(self.f_p, 'rb') as f:
while True:
header = np.fromfile(f, dtype='uint8', count=header_size)
- if not header.size:
+ if not header.size:
break
- sample_size = header[0] + (header[1]<<8) + (header[2]<<16) + (header[3]<<24)
- tagcode = header[5] + (header[4]<<8)
- width = header[6] + (header[7]<<8)
- height = header[8] + (header[9]<<8)
- if header_size + width*height != sample_size:
+ sample_size = header[0] + (header[1] << 8) + (
+ header[2] << 16) + (header[3] << 24)
+ tagcode = header[5] + (header[4] << 8)
+ width = header[6] + (header[7] << 8)
+ height = header[8] + (header[9] << 8)
+ if header_size + width * height != sample_size:
break
- image = np.fromfile(f, dtype='uint8', count=width*height).reshape((height, width))
+ image = np.fromfile(f, dtype='uint8',
+ count=width * height).reshape(
+ (height, width))
yield image, tagcode
-def resize_padding_or_crop(target_size, ori_img, padding_value=255):
- if len(ori_img.shape) == 3:
- res = np.zeros([ori_img.shape[0], target_size, target_size])
- else:
- res = np.ones([target_size, target_size])*padding_value
- end_x = target_size
- end_y = target_size
- start_x = 0
- start_y = 0
- if ori_img.shape[0] < target_size:
- end_x = int((target_size + ori_img.shape[0])/2)
- if ori_img.shape[1] < target_size:
- end_y = int((target_size + ori_img.shape[1])/2)
- if ori_img.shape[0] < target_size:
- start_x = int((target_size - ori_img.shape[0])/2)
- if ori_img.shape[1] < target_size:
- start_y = int((target_size - ori_img.shape[1])/2)
- res[start_x:end_x, start_y:end_y] = ori_img
- return np.array(res, dtype=np.uint8)
+def parse_example(record):
+ features = tf.io.parse_single_example(record,
+ features={
+ 'label':
+ tf.io.FixedLenFeature([], tf.int64),
+ 'image':
+ tf.io.FixedLenFeature([], tf.string),
+ })
+ img = tf.io.decode_raw(features['image'], out_type=tf.uint8)
+ label = tf.cast(features['label'], tf.int32)
+ return img, label
+
+
+def load_ds():
+ input_files = ['casia_hwdb_1.0_1.1.tfrecord']
+ ds = tf.data.TFRecordDataset(input_files)
+ ds = ds.map(parse_example)
+ return ds
+
+
+def load_charactors():
+ a = open('charactors.txt', 'r').readlines()
+ return [i.strip() for i in a]
+
if __name__ == "__main__":
- gnt = CASIAHWDBGNT('samples/1001-f.gnt')
-
- full_img = np.zeros([900, 900], dtype=np.uint8)
- charset = []
- i = 0
- for img, tagcode in gnt.get_data_iter():
- # cv2.imshow('rr', img)
- try:
- label = struct.pack('>H', tagcode).decode('gb2312')
- img_padded = resize_padding_or_crop(90, img)
- col_idx = i%10
- row_idx = i//10
- full_img[row_idx*90:(row_idx+1)*90, col_idx*90:(col_idx+1)*90] = img_padded
- charset.append(label.replace('\x00', ''))
- if i >= 99:
- cv2.imshow('rrr', full_img)
- cv2.imwrite('sample.png', full_img)
- cv2.waitKey(0)
- print(charset)
- break
- i += 1
- except Exception as e:
- # print(e.with_traceback(0))
- print('decode error')
- continue
+ ds = load_ds()
+ charactors = load_charactors()
+ for img, label in ds.take(9):
+ # start training on model...
+ img = img.numpy()
+ img = np.resize(img, (64, 64))
+ print(img.shape)
+ label = label.numpy()
+ label = charactors[label]
+ print(label)
+ cv2.imshow('rr', img)
+ cv2.waitKey(0)
+ # break
diff --git a/dataset/charactors.txt b/dataset/charactors.txt
old mode 100644
new mode 100755
diff --git a/dataset/convert_to_tfrecord.py b/dataset/convert_to_tfrecord.py
old mode 100644
new mode 100755
index cd446e3..9bf9e44
--- a/dataset/convert_to_tfrecord.py
+++ b/dataset/convert_to_tfrecord.py
@@ -7,6 +7,7 @@ import cv2
from alfred.utils.log import logger as logging
import tensorflow as tf
import glob
+import os
class CASIAHWDBGNT(object):
@@ -39,20 +40,27 @@ def run():
logging.info('got all {} gnt files.'.format(len(all_hwdb_gnt_files)))
logging.info('gathering charset...')
charset = []
- for gnt in all_hwdb_gnt_files:
- hwdb = CASIAHWDBGNT(gnt)
- for img, tagcode in hwdb.get_data_iter():
- try:
- label = struct.pack('>H', tagcode).decode('gb2312')
- label = label.replace('\x00', '')
- charset.append(label)
- except Exception as e:
- continue
- charset = sorted(set(charset))
+ if os.path.exists('charactors.txt'):
+ logging.info('found exist charactors.txt...')
+ with open('charactors.txt', 'r') as f:
+ charset = f.readlines()
+ charset = [i.strip() for i in charset]
+ else:
+ for gnt in all_hwdb_gnt_files:
+ hwdb = CASIAHWDBGNT(gnt)
+ for img, tagcode in hwdb.get_data_iter():
+ try:
+ label = struct.pack('>H', tagcode).decode('gb2312')
+ label = label.replace('\x00', '')
+ charset.append(label)
+ except Exception as e:
+ continue
+ charset = sorted(set(charset))
+ with open('charactors.txt', 'w') as f:
+ f.writelines('\n'.join(charset))
logging.info('all got {} charactors.'.format(len(charset)))
- with open('charactors.txt', 'w') as f:
- f.writelines('\n'.join(charset))
-
+ logging.info('{}'.format(charset[:10]))
+
tfrecord_f = 'casia_hwdb_1.0_1.1.tfrecord'
i = 0
with tf.io.TFRecordWriter(tfrecord_f) as tfrecord_writer:
@@ -60,7 +68,7 @@ def run():
hwdb = CASIAHWDBGNT(gnt)
for img, tagcode in hwdb.get_data_iter():
try:
- img = cv.resize(img, (64, 64))
+ img = cv2.resize(img, (64, 64))
label = struct.pack('>H', tagcode).decode('gb2312')
label = label.replace('\x00', '')
index = charset.index(label)
@@ -68,11 +76,11 @@ def run():
example = tf.train.Example(features=tf.train.Features(
feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
- 'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img]))
+ 'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()]))
}))
tfrecord_writer.write(example.SerializeToString())
if i%500:
- logging.info('solved {} examples.'.format(i))
+ logging.info('solved {} examples. {}: {}'.format(i, label, index))
i += 1
except Exception as e:
logging.error(e)
diff --git a/dataset/casia_hwdb_1.0_1.1.tfrecord b/dataset/dataset_hwdb.py
old mode 100644
new mode 100755
similarity index 100%
rename from dataset/casia_hwdb_1.0_1.1.tfrecord
rename to dataset/dataset_hwdb.py
diff --git a/dataset/get_hwdb_1.0_1.1.sh b/dataset/get_hwdb_1.0_1.1.sh
old mode 100644
new mode 100755
diff --git a/readme.md b/readme.md
old mode 100644
new mode 100755
index 8624018..2f7414f
--- a/readme.md
+++ b/readme.md
@@ -27,7 +27,13 @@
**更新**:
实际上,由于单个汉字图片其实很小,差不多也就最大80x80的大小,这个大小不适合转成图片保存到本地,因此我们将hwdb原始的二进制保存为tfrecord。同时也方便后面训练,可以直接从tfrecord读取图片进行训练。
+
+在我们存储完成的时候大概处理了89万个汉字,总共汉字的空间是3755个汉字。由于我们暂时仅仅使用了1.0,所以还有大概3000个汉字没有加入进来,但是处理是一样。使用本仓库来生成你的tfrecord步骤如下:
+
+- `cd dataset && python3 convert_to_tfrecord.py`, 请注意我们使用的是tf2.0;
+- 你需要修改对应的路径,等待生成完成,大概有89万个example,如果1.0和1.1都用,那估计得double。
+
## Model
diff --git a/sample.png b/sample.png
old mode 100644
new mode 100755
diff --git a/samples/.gitignore b/samples/.gitignore
old mode 100644
new mode 100755
diff --git a/samples/001-f.gnt b/samples/001-f.gnt
old mode 100644
new mode 100755
diff --git a/samples/sample.png b/samples/sample.png
old mode 100644
new mode 100755
diff --git a/tests.py b/tests.py
old mode 100644
new mode 100755
diff --git a/train.py b/train.py
new file mode 100755
index 0000000..e69de29