91 lines
3.4 KiB
Python
Executable File
91 lines
3.4 KiB
Python
Executable File
"""
|
|
Convert CASIA-HWDB .gnt handwriting samples into a TFRecord file.
|
|
"""
|
|
import struct
|
|
import numpy as np
|
|
import cv2
|
|
from alfred.utils.log import logger as logging
|
|
import tensorflow as tf
|
|
import glob
|
|
import os
|
|
|
|
|
|
class CASIAHWDBGNT(object):
    """Reader for a single CASIA-HWDB ``.gnt`` file.

    A .gnt file may contain many images and characters. As parsed here, it
    is a plain concatenation of samples, each made of a 10-byte header
    followed by ``width * height`` pixel bytes (one uint8 per pixel):

    * bytes 0-3: total sample size in bytes (header + pixels), little-endian
    * bytes 4-5: tag code of the character, first byte most significant
    * bytes 6-7: image width, little-endian
    * bytes 8-9: image height, little-endian
    """

    def __init__(self, f_p):
        """Remember the path of the .gnt file; nothing is opened yet."""
        self.f_p = f_p

    def get_data_iter(self):
        """Yield ``(image, tagcode)`` for every sample stored in the file.

        Yields:
            tuple: ``image`` is a ``(height, width)`` uint8 numpy array and
            ``tagcode`` is a plain Python ``int`` in ``[0, 65535]``.

        Iteration stops silently at end of file, on a truncated header, on a
        header whose size field disagrees with ``width * height``, or on
        truncated pixel data.
        """
        header_size = 10
        # Decode the header with struct instead of shifting numpy uint8
        # scalars: under NumPy >= 2 promotion rules `uint8 << 8` stays uint8
        # and silently becomes 0, which corrupts every multi-byte field.
        header_struct = struct.Struct('<I2sHH')
        with open(self.f_p, 'rb') as f:
            while True:
                header = np.fromfile(f, dtype='uint8', count=header_size)
                # A short read means EOF or a truncated trailing header.
                if header.size < header_size:
                    break
                sample_size, tag_bytes, width, height = header_struct.unpack(
                    header.tobytes())
                tagcode = int.from_bytes(tag_bytes, 'big')
                pixel_count = width * height
                if header_size + pixel_count != sample_size:
                    break
                pixels = np.fromfile(f, dtype='uint8', count=pixel_count)
                # Truncated pixel data: stop instead of failing in reshape.
                if pixels.size != pixel_count:
                    break
                yield pixels.reshape((height, width)), tagcode
|
|
|
|
|
|
def _tagcode_to_label(tagcode):
    """Decode a 16-bit tag code as GB2312 text with NUL characters removed."""
    return struct.pack('>H', tagcode).decode('gb2312').replace('\x00', '')


def run():
    """Build the character set and write all samples into one TFRecord file.

    Steps:
      1. Collect every ``.gnt`` file under ``./hwdb_raw/HWDB1.1trn_gnt``.
      2. Load the charset from ``charactors.txt`` if it exists; otherwise
         scan all files, collect the sorted set of decodable labels and
         write it to ``charactors.txt`` (one label per line).
      3. Resize every sample to 64x64 and write it, together with the index
         of its label in the charset, to ``casia_hwdb_1.0_1.1.tfrecord``.

    Samples whose label cannot be decoded or is missing from the charset are
    skipped; per-sample failures are logged and do not abort the run.
    """
    all_hwdb_gnt_files = glob.glob('./hwdb_raw/HWDB1.1trn_gnt/*.gnt')
    logging.info('got all {} gnt files.'.format(len(all_hwdb_gnt_files)))
    if not all_hwdb_gnt_files:
        # Bail out before writing an empty charactors.txt, which a later run
        # would pick up and use as the (empty) charset.
        logging.error('no .gnt files found, nothing to do.')
        return

    logging.info('gathering charset...')
    if os.path.exists('charactors.txt'):
        logging.info('found exist charactors.txt...')
        with open('charactors.txt', 'r') as f:
            charset = [line.strip() for line in f]
    else:
        labels = set()
        for gnt in all_hwdb_gnt_files:
            for _, tagcode in CASIAHWDBGNT(gnt).get_data_iter():
                try:
                    labels.add(_tagcode_to_label(tagcode))
                except (UnicodeDecodeError, struct.error):
                    # Tag code is not a valid GB2312 character: skip it.
                    continue
        charset = sorted(labels)
        with open('charactors.txt', 'w') as f:
            f.write('\n'.join(charset))
    logging.info('all got {} charactors.'.format(len(charset)))
    logging.info('{}'.format(charset[:10]))

    # One O(1) lookup per sample instead of a linear list.index() scan.
    # setdefault keeps the first occurrence, exactly like list.index().
    char_to_index = {}
    for position, char in enumerate(charset):
        char_to_index.setdefault(char, position)

    tfrecord_f = 'casia_hwdb_1.0_1.1.tfrecord'
    written = 0
    with tf.io.TFRecordWriter(tfrecord_f) as tfrecord_writer:
        for gnt in all_hwdb_gnt_files:
            hwdb = CASIAHWDBGNT(gnt)
            for img, tagcode in hwdb.get_data_iter():
                try:
                    label = _tagcode_to_label(tagcode)
                    index = char_to_index.get(label)
                    if index is None:
                        logging.error(
                            '{!r} is not in charset, skipped.'.format(label))
                        continue
                    img = cv2.resize(img, (64, 64))
                    # save img, label as example
                    example = tf.train.Example(features=tf.train.Features(
                        feature={
                            "label": tf.train.Feature(
                                int64_list=tf.train.Int64List(value=[index])),
                            'image': tf.train.Feature(
                                bytes_list=tf.train.BytesList(
                                    value=[img.tobytes()])),
                        }))
                    tfrecord_writer.write(example.SerializeToString())
                except Exception as e:
                    # Deliberate best-effort: one bad sample must not abort
                    # the whole conversion, but it is always logged.
                    logging.error(e)
                    continue
                written += 1
                # Progress report once every 500 written examples.
                if written % 500 == 0:
                    logging.info('solved {} examples. {}: {}'.format(
                        written, label, index))
    logging.info('done.')
|
|
|
|
# Script entry point: run the conversion only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    run()