add dataset
This commit is contained in:
1
dataset/.gitignore
vendored
Normal file
1
dataset/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
hwdb_raw/
|
||||||
@@ -59,28 +59,25 @@ def resize_padding_or_crop(target_size, ori_img, padding_value=255):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
gnt = CASIAHWDBGNT('samples/1001-f.gnt')
|
gnt = CASIAHWDBGNT('samples/1001-f.gnt')
|
||||||
|
|
||||||
full_img = np.zeros([800, 800], dtype=np.uint8)
|
full_img = np.zeros([900, 900], dtype=np.uint8)
|
||||||
charset = []
|
charset = []
|
||||||
i = 0
|
i = 0
|
||||||
for img, tagcode in gnt.get_data_iter():
|
for img, tagcode in gnt.get_data_iter():
|
||||||
cv2.imshow('rr', img)
|
# cv2.imshow('rr', img)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
label = struct.pack('>H', tagcode).decode('gb2312')
|
label = struct.pack('>H', tagcode).decode('gb2312')
|
||||||
cv2.waitKey(0)
|
img_padded = resize_padding_or_crop(90, img)
|
||||||
print(label)
|
col_idx = i%10
|
||||||
# img_padded = resize_padding_or_crop(80, img)
|
row_idx = i//10
|
||||||
# col_idx = i%10
|
full_img[row_idx*90:(row_idx+1)*90, col_idx*90:(col_idx+1)*90] = img_padded
|
||||||
# row_idx = i//10
|
charset.append(label.replace('\x00', ''))
|
||||||
# full_img[row_idx*80:(row_idx+1)*80, col_idx*80:(col_idx+1)*80] = img_padded
|
if i >= 99:
|
||||||
# charset.append(label.replace('\x00', ''))
|
cv2.imshow('rrr', full_img)
|
||||||
# if i >= 99:
|
cv2.imwrite('sample.png', full_img)
|
||||||
# cv2.imshow('rrr', full_img)
|
cv2.waitKey(0)
|
||||||
# cv2.imwrite('sample.png', full_img)
|
print(charset)
|
||||||
# cv2.waitKey(0)
|
break
|
||||||
# print(charset)
|
i += 1
|
||||||
# break
|
|
||||||
# i += 1
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# print(e.with_traceback(0))
|
# print(e.with_traceback(0))
|
||||||
print('decode error')
|
print('decode error')
|
||||||
|
|||||||
0
dataset/casia_hwdb_1.0_1.1.tfrecord
Normal file
0
dataset/casia_hwdb_1.0_1.1.tfrecord
Normal file
3755
dataset/charactors.txt
Normal file
3755
dataset/charactors.txt
Normal file
File diff suppressed because it is too large
Load Diff
83
dataset/convert_to_tfrecord.py
Normal file
83
dataset/convert_to_tfrecord.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
"""
|
||||||
|
generates HWDB data into tfrecord
|
||||||
|
"""
|
||||||
|
import struct
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
from alfred.utils.log import logger as logging
|
||||||
|
import tensorflow as tf
|
||||||
|
import glob
|
||||||
|
|
||||||
|
|
||||||
|
class CASIAHWDBGNT(object):
|
||||||
|
"""
|
||||||
|
A .gnt file may contains many images and charactors
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, f_p):
|
||||||
|
self.f_p = f_p
|
||||||
|
|
||||||
|
def get_data_iter(self):
|
||||||
|
header_size = 10
|
||||||
|
with open(self.f_p, 'rb') as f:
|
||||||
|
while True:
|
||||||
|
header = np.fromfile(f, dtype='uint8', count=header_size)
|
||||||
|
if not header.size:
|
||||||
|
break
|
||||||
|
sample_size = header[0] + (header[1]<<8) + (header[2]<<16) + (header[3]<<24)
|
||||||
|
tagcode = header[5] + (header[4]<<8)
|
||||||
|
width = header[6] + (header[7]<<8)
|
||||||
|
height = header[8] + (header[9]<<8)
|
||||||
|
if header_size + width*height != sample_size:
|
||||||
|
break
|
||||||
|
image = np.fromfile(f, dtype='uint8', count=width*height).reshape((height, width))
|
||||||
|
yield image, tagcode
|
||||||
|
|
||||||
|
|
||||||
|
def run():
|
||||||
|
all_hwdb_gnt_files = glob.glob('./hwdb_raw/HWDB1.1trn_gnt/*.gnt')
|
||||||
|
logging.info('got all {} gnt files.'.format(len(all_hwdb_gnt_files)))
|
||||||
|
logging.info('gathering charset...')
|
||||||
|
charset = []
|
||||||
|
for gnt in all_hwdb_gnt_files:
|
||||||
|
hwdb = CASIAHWDBGNT(gnt)
|
||||||
|
for img, tagcode in hwdb.get_data_iter():
|
||||||
|
try:
|
||||||
|
label = struct.pack('>H', tagcode).decode('gb2312')
|
||||||
|
label = label.replace('\x00', '')
|
||||||
|
charset.append(label)
|
||||||
|
except Exception as e:
|
||||||
|
continue
|
||||||
|
charset = sorted(set(charset))
|
||||||
|
logging.info('all got {} charactors.'.format(len(charset)))
|
||||||
|
with open('charactors.txt', 'w') as f:
|
||||||
|
f.writelines('\n'.join(charset))
|
||||||
|
|
||||||
|
tfrecord_f = 'casia_hwdb_1.0_1.1.tfrecord'
|
||||||
|
i = 0
|
||||||
|
with tf.io.TFRecordWriter(tfrecord_f) as tfrecord_writer:
|
||||||
|
for gnt in all_hwdb_gnt_files:
|
||||||
|
hwdb = CASIAHWDBGNT(gnt)
|
||||||
|
for img, tagcode in hwdb.get_data_iter():
|
||||||
|
try:
|
||||||
|
img = cv.resize(img, (64, 64))
|
||||||
|
label = struct.pack('>H', tagcode).decode('gb2312')
|
||||||
|
label = label.replace('\x00', '')
|
||||||
|
index = charset.index(label)
|
||||||
|
# save img, label as example
|
||||||
|
example = tf.train.Example(features=tf.train.Features(
|
||||||
|
feature={
|
||||||
|
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
|
||||||
|
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img]))
|
||||||
|
}))
|
||||||
|
tfrecord_writer.write(example.SerializeToString())
|
||||||
|
if i%500:
|
||||||
|
logging.info('solved {} examples.'.format(i))
|
||||||
|
i += 1
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(e)
|
||||||
|
continue
|
||||||
|
logging.info('done.')
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
2
dataset/get_hwdb_1.0_1.1.sh
Normal file
2
dataset/get_hwdb_1.0_1.1.sh
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1trn_gnt.zip
|
||||||
|
wget wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1tst_gnt.zip
|
||||||
12
readme.md
12
readme.md
@@ -7,10 +7,10 @@
|
|||||||
|
|
||||||
在开始之前,先介绍一下本项目所采用的数据信息。我们的数据全部来自于CASIA的开源中文手写字数据集,该数据集分为两部分:
|
在开始之前,先介绍一下本项目所采用的数据信息。我们的数据全部来自于CASIA的开源中文手写字数据集,该数据集分为两部分:
|
||||||
|
|
||||||
- CASIA-HWDB:新版本的HWDB,我们仅仅使用1.0-1.2,这是单字的数据集,2.0-2.2是整张文本的数据集,我们暂时不用,单字里面包含了约7185个汉字以及171个英文字母、数字、标点符号等;
|
- CASIA-HWDB:离线的HWDB,我们仅仅使用1.0-1.2,这是单字的数据集,2.0-2.2是整张文本的数据集,我们暂时不用,单字里面包含了约7185个汉字以及171个英文字母、数字、标点符号等;
|
||||||
- CASIA-OLHWDB:老版本的HWDB,格式一样,包含了约7185个汉字以及171个英文字母、数字、标点符号等。
|
- CASIA-OLHWDB:在线的HWDB,格式一样,包含了约7185个汉字以及171个英文字母、数字、标点符号等,我们不用。
|
||||||
|
|
||||||
原始数据下载链接点击[这里](http://www.nlpr.ia.ac.cn/databases/handwriting/Offline_database.html).
|
其实你下载1.0的train和test差不多已经够了,可以直接运行 `dataset/get_hwdb_1.0_1.1.sh` 下载。原始数据下载链接点击[这里](http://www.nlpr.ia.ac.cn/databases/handwriting/Offline_database.html).
|
||||||
由于原始数据过于复杂,我们自己写了一个数据wrapper方便读取,统一将其转换为类似于Dataframe (Pandas)的格式,这样可以将一个字的特征和label方便的显示,也可以十分方便的将手写字转换为图片,采用CNN进行处理。这是我们展示的效果:
|
由于原始数据过于复杂,我们自己写了一个数据wrapper方便读取,统一将其转换为类似于Dataframe (Pandas)的格式,这样可以将一个字的特征和label方便的显示,也可以十分方便的将手写字转换为图片,采用CNN进行处理。这是我们展示的效果:
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
@@ -23,6 +23,12 @@
|
|||||||
['!', '"', '#', '$', '%', '&', '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', '、', '。', '々', '…', '‘', '’', '“', '”']
|
['!', '"', '#', '$', '%', '&', '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', '、', '。', '々', '…', '‘', '’', '“', '”']
|
||||||
```
|
```
|
||||||
|
|
||||||
|
关于数据的处理部分,从服务器下载到的原始数据是 `trn_gnt.zip` 解压之后是 `gnt.alz`, 需要再次解压得到一个包含 gnt文件的文件夹。里面每一个gnt文件都包含了若干个汉字及其标注。直接处理比较麻烦,也不方便抽取出图片再进行操作,**虽然转为图片存入文件夹比较直观,但是不适合批量读取和训练**, 后面我们统一转为tfrecord进行训练。
|
||||||
|
|
||||||
|
**更新**:
|
||||||
|
实际上,由于单个汉字图片其实很小,差不多也就最大80x80的大小,这个大小不适合转成图片保存到本地,因此我们将hwdb原始的二进制保存为tfrecord。同时也方便后面训练,可以直接从tfrecord读取图片进行训练。
|
||||||
|
|
||||||
|
|
||||||
## Model
|
## Model
|
||||||
|
|
||||||
关于我们采用的OCR模型的构建,我们大致采用的是比较先进的MobileNetV3架构,同时设计了一个修改的过的MobileNetV3Big的更深网络。主要考虑模型的轻量型和表达能力。最终训练结果表明,我们的模型可以在中文手写字上达到约99.8%的准确率。
|
关于我们采用的OCR模型的构建,我们大致采用的是比较先进的MobileNetV3架构,同时设计了一个修改的过的MobileNetV3Big的更深网络。主要考虑模型的轻量型和表达能力。最终训练结果表明,我们的模型可以在中文手写字上达到约99.8%的准确率。
|
||||||
BIN
sample.png
Normal file
BIN
sample.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 71 KiB |
Reference in New Issue
Block a user