commit aa1c91780a3f8b6e5e6c8074a16a3dc9d92bb969 Author: xxxxxx Date: Tue May 28 00:06:06 2019 +0800 add diff --git a/dataset/casia_hwdb.py b/dataset/casia_hwdb.py new file mode 100644 index 0000000..a52add1 --- /dev/null +++ b/dataset/casia_hwdb.py @@ -0,0 +1,87 @@ +""" + +this is a wrapper handle CASIA_HWDB dataset +since original data is complicated +we using this class to get .png and label from raw +.gnt data + +""" +import struct +import numpy as np +import cv2 + + +class CASIAHWDBGNT(object): + """ + A .gnt file may contains many images and charactors + """ + + def __init__(self, f_p): + self.f_p = f_p + + def get_data_iter(self): + header_size = 10 + with open(self.f_p, 'rb') as f: + while True: + header = np.fromfile(f, dtype='uint8', count=header_size) + if not header.size: + break + sample_size = header[0] + (header[1]<<8) + (header[2]<<16) + (header[3]<<24) + tagcode = header[5] + (header[4]<<8) + width = header[6] + (header[7]<<8) + height = header[8] + (header[9]<<8) + if header_size + width*height != sample_size: + break + image = np.fromfile(f, dtype='uint8', count=width*height).reshape((height, width)) + yield image, tagcode + + +def resize_padding_or_crop(target_size, ori_img, padding_value=255): + if len(ori_img.shape) == 3: + res = np.zeros([ori_img.shape[0], target_size, target_size]) + else: + res = np.ones([target_size, target_size])*padding_value + end_x = target_size + end_y = target_size + start_x = 0 + start_y = 0 + if ori_img.shape[0] < target_size: + end_x = int((target_size + ori_img.shape[0])/2) + if ori_img.shape[1] < target_size: + end_y = int((target_size + ori_img.shape[1])/2) + if ori_img.shape[0] < target_size: + start_x = int((target_size - ori_img.shape[0])/2) + if ori_img.shape[1] < target_size: + start_y = int((target_size - ori_img.shape[1])/2) + res[start_x:end_x, start_y:end_y] = ori_img + return np.array(res, dtype=np.uint8) + +if __name__ == "__main__": + gnt = CASIAHWDBGNT('samples/1001-f.gnt') + + full_img = np.zeros([800, 800], dtype=np.uint8) + charset = [] + i = 0 + for img, tagcode in gnt.get_data_iter(): + cv2.imshow('rr', img) + + try: + label = struct.pack('>H', tagcode).decode('gb2312') + cv2.waitKey(0) + print(label) + # img_padded = resize_padding_or_crop(80, img) + # col_idx = i%10 + # row_idx = i//10 + # full_img[row_idx*80:(row_idx+1)*80, col_idx*80:(col_idx+1)*80] = img_padded + # charset.append(label.replace('\x00', '')) + # if i >= 99: + # cv2.imshow('rrr', full_img) + # cv2.imwrite('sample.png', full_img) + # cv2.waitKey(0) + # print(charset) + # break + # i += 1 + except Exception as e: + # print(e.with_traceback(0)) + print('decode error') + continue diff --git a/readme.md b/readme.md new file mode 100644 index 0000000..0be960b --- /dev/null +++ b/readme.md @@ -0,0 +1,28 @@ +# TensorFlow 2.0 中文手写字识别 + +本项目实现了基于CNN的中文手写字识别,并且采用标准的**tensorflow 2.0 api** 来构建!相比对简单的字母手写识别,本项目更能体现模型设计的精巧性和数据增强的熟练操作性,并且最终设计出来的模型可以直接应用于工业场合,比如 **票据识别**, **手写文本自动扫描** 等,相比于百度api接口或者QQ接口等,具有可优化性、免费性、本地性等优点。 + + +## Data + +在开始之前,先介绍一下本项目所采用的数据信息。我们的数据全部来自于CASIA的开源中文手写字数据集,该数据集分为两部分: + +- CASIA-HWDB:新版本的HWDB,我们仅仅使用1.0-1.2,这是单字的数据集,2.0-2.2是整张文本的数据集,我们暂时不用,单字里面包含了约7185个汉字以及171个英文字母、数字、标点符号等; +- CASIA-OLHWDB:老版本的HWDB,格式一样,包含了约7185个汉字以及171个英文字母、数字、标点符号等。 + +原始数据下载链接点击[这里](http://www.nlpr.ia.ac.cn/databases/handwriting/Offline_database.html). +由于原始数据过于复杂,我们自己写了一个数据wrapper方便读取,统一将其转换为类似于Dataframe (Pandas)的格式,这样可以将一个字的特征和label方便的显示,也可以十分方便的将手写字转换为图片,采用CNN进行处理。这是我们展示的效果: + +

+ +

+ +其对应的label为: + +``` +['!', '"', '#', '$', '%', '&', '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', '、', '。', '々', '…', '‘', '’', '“', '”'] +``` + +## Model + +关于我们采用的OCR模型的构建,我们大致采用的是比较先进的MobileNetV3架构,同时设计了一个修改的过的MobileNetV3Big的更深网络。主要考虑模型的轻量型和表达能力。最终训练结果表明,我们的模型可以在中文手写字上达到约99.8%的准确率。 \ No newline at end of file diff --git a/samples/.gitignore b/samples/.gitignore new file mode 100644 index 0000000..e5645d7 --- /dev/null +++ b/samples/.gitignore @@ -0,0 +1 @@ +1001-f.gnt \ No newline at end of file diff --git a/samples/001-f.gnt b/samples/001-f.gnt new file mode 100644 index 0000000..b4ead5f Binary files /dev/null and b/samples/001-f.gnt differ diff --git a/samples/sample.png b/samples/sample.png new file mode 100644 index 0000000..6a0f1e9 Binary files /dev/null and b/samples/sample.png differ diff --git a/tests.py b/tests.py new file mode 100644 index 0000000..e69de29