This commit is contained in:
Your Name
2019-06-07 19:04:00 +08:00
parent c8df372f63
commit aed8c10d71
87 changed files with 782 additions and 215 deletions

View File

@@ -1,6 +1,7 @@
"""
Converts CASIA HWDB .gnt files into a TFRecord file.
"""
import sys
import struct
import numpy as np
import cv2
@@ -23,69 +24,83 @@ class CASIAHWDBGNT(object):
with open(self.f_p, 'rb') as f:
while True:
header = np.fromfile(f, dtype='uint8', count=header_size)
if not header.size:
if not header.size:
break
sample_size = header[0] + (header[1]<<8) + (header[2]<<16) + (header[3]<<24)
tagcode = header[5] + (header[4]<<8)
width = header[6] + (header[7]<<8)
height = header[8] + (header[9]<<8)
if header_size + width*height != sample_size:
sample_size = header[0] + (header[1] << 8) + (header[2] << 16) + (header[3] << 24)
tagcode = header[5] + (header[4] << 8)
width = header[6] + (header[7] << 8)
height = header[8] + (header[9] << 8)
if header_size + width * height != sample_size:
break
image = np.fromfile(f, dtype='uint8', count=width*height).reshape((height, width))
image = np.fromfile(f, dtype='uint8', count=width * height).reshape((height, width))
yield image, tagcode
def _tagcode_to_label(tagcode):
    """Decode a GNT tag code into its character label.

    The tag code is packed as a big-endian unsigned 16-bit value, decoded as
    GB2312, and any NUL characters are removed from the result.

    Raises:
        struct.error: if ``tagcode`` cannot be packed as an unsigned short.
        UnicodeDecodeError: if the packed bytes are not valid GB2312.
    """
    return struct.pack('>H', tagcode).decode('gb2312').replace('\x00', '')


def _load_or_gather_charset(p, gnt_files):
    """Return the list of labels used to map characters to class indices.

    ``characters.txt`` in the current directory is reused when it exists.
    Otherwise the charset is gathered from ``gnt_files`` and written to
    ``characters.txt`` -- but only when ``p`` contains ``'trn'``. For any
    other path an empty list is returned.
    """
    if os.path.exists('characters.txt'):
        logging.info('found exist characters.txt...')
        with open('characters.txt', 'r') as f:
            return [line.strip() for line in f]

    if 'trn' not in p:
        return []

    labels = set()
    for gnt in gnt_files:
        for _, tagcode in CASIAHWDBGNT(gnt).get_data_iter():
            try:
                labels.add(_tagcode_to_label(tagcode))
            except (struct.error, UnicodeDecodeError):
                # Best effort: samples whose tag code cannot be decoded
                # simply do not contribute to the charset.
                continue
    charset = sorted(labels)
    with open('characters.txt', 'w') as f:
        f.write('\n'.join(charset))
    return charset


def run(p):
    """Convert every ``*.gnt`` file in directory ``p`` into one TFRecord file.

    The output file is written to the current directory and named after the
    last component of ``p`` (e.g. ``./hwdb_raw/HWDB1.1trn_gnt/`` produces
    ``HWDB1.1trn_gnt.tfrecord``), with or without a trailing slash.

    Each record holds four features: ``label`` (index of the character in the
    charset), ``image`` (raw bytes of the unresized image array), ``width``
    and ``height``.

    Samples that fail to convert are logged with their traceback and skipped.

    Args:
        p: path of the directory that contains the ``.gnt`` files.
    """
    all_hwdb_gnt_files = glob.glob(os.path.join(p, '*.gnt'))
    logging.info('got all {} gnt files.'.format(len(all_hwdb_gnt_files)))
    logging.info('gathering charset...')
    charset = _load_or_gather_charset(p, all_hwdb_gnt_files)
    logging.info('all got {} characters.'.format(len(charset)))
    logging.info('{}'.format(charset[:10]))

    if not charset:
        # Without a charset no label can be mapped to an index, so every
        # sample would fail; report it once instead of once per sample.
        logging.error(
            'empty charset: no characters.txt found and none could be '
            'gathered from {}; nothing to convert.'.format(p))
        return

    # Build the label -> index table once instead of scanning the list for
    # every sample. setdefault keeps the first index for duplicated labels,
    # which is what list.index() would return.
    label_to_index = {}
    for idx, ch in enumerate(charset):
        label_to_index.setdefault(ch, idx)

    # abspath drops any trailing separator, so the name comes from the data
    # directory itself whether or not the caller wrote a trailing slash.
    tfrecord_f = os.path.basename(os.path.abspath(p)) + '.tfrecord'
    logging.info('tfrecord file saved into: {}'.format(tfrecord_f))

    i = 0
    with tf.io.TFRecordWriter(tfrecord_f) as tfrecord_writer:
        for gnt in all_hwdb_gnt_files:
            for img, tagcode in CASIAHWDBGNT(gnt).get_data_iter():
                try:
                    label = _tagcode_to_label(tagcode)
                    index = label_to_index[label]
                    # NOTE(review): 'width' stores img.shape[0] and 'height'
                    # stores img.shape[1]. For a NumPy image array shape[0]
                    # is the row count, so the names look swapped; the
                    # layout is kept unchanged so that existing readers of
                    # these records keep working -- confirm against them.
                    w = img.shape[0]
                    h = img.shape[1]
                    example = tf.train.Example(features=tf.train.Features(
                        feature={
                            'label': tf.train.Feature(
                                int64_list=tf.train.Int64List(value=[index])),
                            'image': tf.train.Feature(
                                bytes_list=tf.train.BytesList(value=[img.tobytes()])),
                            'width': tf.train.Feature(
                                int64_list=tf.train.Int64List(value=[w])),
                            'height': tf.train.Feature(
                                int64_list=tf.train.Int64List(value=[h])),
                        }))
                    tfrecord_writer.write(example.SerializeToString())
                except Exception:
                    # Best effort: log the failure with its traceback and
                    # move on to the next sample.
                    logging.exception('failed to convert a sample from {}'.format(gnt))
                    continue
                # Progress report once every 5000 written examples.
                if i % 5000 == 0:
                    logging.info('solved {} examples. {}: {}'.format(i, label, index))
                i += 1
    logging.info('done.')
if __name__ == "__main__":
    # The single positional argument is the directory holding the .gnt files.
    if len(sys.argv) <= 1:
        logging.error('send a pattern like this: {}'.format('./hwdb_raw/HWDB1.1trn_gnt/'))
        # Exit non-zero so shells and pipelines can detect the misuse.
        sys.exit(1)
    p = sys.argv[1]
    logging.info('converting from: {}'.format(p))
    run(p)