Files
ocrcn_tf2/dataset/convert_to_tfrecord.py
2019-06-05 00:13:32 +08:00

91 lines
3.4 KiB
Python
Executable File

"""
Convert CASIA-HWDB .gnt files into a TFRecord file.
"""
import struct
import numpy as np
import cv2
from alfred.utils.log import logger as logging
import tensorflow as tf
import glob
import os
class CASIAHWDBGNT(object):
    """Reader for a single CASIA-HWDB ``.gnt`` file.

    A ``.gnt`` file stores many character samples back to back. Each sample
    is a 10-byte header followed by a uint8 bitmap:

    * bytes 0-3: total sample size in bytes (little-endian), header included
    * bytes 4-5: tag code (big-endian byte order)
    * bytes 6-7: bitmap width (little-endian)
    * bytes 8-9: bitmap height (little-endian)
    * then ``width * height`` uint8 pixels, row-major
    """

    # '<' = little-endian with no padding; the tag code is kept as raw bytes
    # because it is stored big-endian, unlike the other fields.
    _HEADER = struct.Struct('<I2sHH')

    def __init__(self, f_p):
        self.f_p = f_p

    def get_data_iter(self):
        """Yield ``(image, tagcode)`` for every complete sample in the file.

        ``image`` is a writable ``(height, width)`` uint8 array and
        ``tagcode`` is the 2-byte character code as a Python ``int``.
        Iteration stops at end of file, at a truncated header or bitmap, or
        at a header whose sample size disagrees with its width and height.
        """
        header_size = self._HEADER.size
        with open(self.f_p, 'rb') as f:
            while True:
                header = f.read(header_size)
                if len(header) < header_size:
                    break  # end of file, or a truncated trailing header
                # struct yields Python ints; bit-shifting numpy uint8 scalars
                # overflows under NumPy 2 (NEP 50) promotion rules.
                sample_size, tag_bytes, width, height = self._HEADER.unpack(header)
                tagcode = (tag_bytes[0] << 8) | tag_bytes[1]
                n_pixels = width * height
                if header_size + n_pixels != sample_size:
                    break  # inconsistent header: stop reading this file
                pixels = f.read(n_pixels)
                if len(pixels) < n_pixels:
                    break  # truncated bitmap at end of file
                # frombuffer over bytes is read-only; copy so callers may mutate it.
                image = np.frombuffer(pixels, dtype=np.uint8).reshape((height, width)).copy()
                yield image, tagcode
def _tagcode_to_label(tagcode):
    """Decode a GNT tag code into the character it names.

    The code is packed as two big-endian bytes and decoded as GB2312; for
    codes below 0x100 the leading zero byte decodes to a NUL character,
    which is removed. Raises UnicodeDecodeError for invalid GB2312 bytes.
    """
    return struct.pack('>H', tagcode).decode('gb2312').replace('\x00', '')


def _load_or_build_charset(gnt_files, charset_f):
    """Return the ordered list of characters used as label indices.

    If ``charset_f`` exists it is read back, one character per line.
    Otherwise every sample in ``gnt_files`` is decoded, the distinct
    characters are sorted, and the list is written to ``charset_f`` so later
    runs reuse the same indices.
    """
    if os.path.exists(charset_f):
        logging.info('found exist {}...'.format(charset_f))
        # Explicit UTF-8 so the CJK charset round-trips regardless of locale.
        with open(charset_f, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f]
    labels = set()
    for gnt in gnt_files:
        for _, tagcode in CASIAHWDBGNT(gnt).get_data_iter():
            try:
                labels.add(_tagcode_to_label(tagcode))
            except UnicodeDecodeError:
                continue  # not a valid GB2312 code; leave it out of the charset
    charset = sorted(labels)
    with open(charset_f, 'w', encoding='utf-8') as f:
        f.write('\n'.join(charset))
    return charset


def run(gnt_pattern='./hwdb_raw/HWDB1.1trn_gnt/*.gnt',
        charset_f='charactors.txt',
        tfrecord_f='casia_hwdb_1.0_1.1.tfrecord',
        image_size=64):
    """Convert CASIA-HWDB .gnt files into a single TFRecord file.

    Each record has an int64 ``label`` (index into the charset) and an
    ``image`` holding the raw bytes of the sample resized to
    ``image_size`` x ``image_size``.

    Args:
        gnt_pattern: glob pattern selecting the input .gnt files.
        charset_f: charset cache file; read if present, created otherwise.
        tfrecord_f: output TFRecord path.
        image_size: side length samples are resized to before serialising.
    """
    # Sort so record order does not depend on filesystem listing order.
    all_hwdb_gnt_files = sorted(glob.glob(gnt_pattern))
    logging.info('got all {} gnt files.'.format(len(all_hwdb_gnt_files)))
    logging.info('gathering charset...')
    charset = _load_or_build_charset(all_hwdb_gnt_files, charset_f)
    logging.info('all got {} charactors.'.format(len(charset)))
    logging.info('{}'.format(charset[:10]))

    # O(1) lookup instead of list.index per sample; setdefault keeps the first
    # position of any duplicated entry, matching list.index semantics.
    char_to_index = {}
    for idx, char in enumerate(charset):
        char_to_index.setdefault(char, idx)

    i = 0
    with tf.io.TFRecordWriter(tfrecord_f) as tfrecord_writer:
        for gnt in all_hwdb_gnt_files:
            for img, tagcode in CASIAHWDBGNT(gnt).get_data_iter():
                try:
                    label = _tagcode_to_label(tagcode)
                    index = char_to_index[label]
                    img = cv2.resize(img, (image_size, image_size))
                except (UnicodeDecodeError, KeyError, cv2.error) as e:
                    # Best-effort: skip undecodable codes, characters missing
                    # from the charset, and bitmaps cv2 cannot resize.
                    logging.error('skipping sample in {}: {!r}'.format(gnt, e))
                    continue
                # save img, label as example
                example = tf.train.Example(features=tf.train.Features(
                    feature={
                        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                        'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()])),
                    }))
                tfrecord_writer.write(example.SerializeToString())
                # Progress log on every 500th written example.
                if i % 500 == 0:
                    logging.info('solved {} examples. {}: {}'.format(i, label, index))
                i += 1
    logging.info('done.')


if __name__ == "__main__":
    run()