data ready

This commit is contained in:
Your Name
2019-06-05 00:13:32 +08:00
parent bbf8928e7a
commit 919d89af4b
23 changed files with 364 additions and 68 deletions

40
dataset/convert_to_tfrecord.py Normal file → Executable file
View File

@@ -7,6 +7,7 @@ import cv2
from alfred.utils.log import logger as logging
import tensorflow as tf
import glob
import os
class CASIAHWDBGNT(object):
@@ -39,20 +40,27 @@ def run():
logging.info('got all {} gnt files.'.format(len(all_hwdb_gnt_files)))
logging.info('gathering charset...')
charset = []
for gnt in all_hwdb_gnt_files:
hwdb = CASIAHWDBGNT(gnt)
for img, tagcode in hwdb.get_data_iter():
try:
label = struct.pack('>H', tagcode).decode('gb2312')
label = label.replace('\x00', '')
charset.append(label)
except Exception as e:
continue
charset = sorted(set(charset))
if os.path.exists('charactors.txt'):
logging.info('found exist charactors.txt...')
with open('charactors.txt', 'r') as f:
charset = f.readlines()
charset = [i.strip() for i in charset]
else:
for gnt in all_hwdb_gnt_files:
hwdb = CASIAHWDBGNT(gnt)
for img, tagcode in hwdb.get_data_iter():
try:
label = struct.pack('>H', tagcode).decode('gb2312')
label = label.replace('\x00', '')
charset.append(label)
except Exception as e:
continue
charset = sorted(set(charset))
with open('charactors.txt', 'w') as f:
f.writelines('\n'.join(charset))
logging.info('all got {} charactors.'.format(len(charset)))
with open('charactors.txt', 'w') as f:
f.writelines('\n'.join(charset))
logging.info('{}'.format(charset[:10]))
tfrecord_f = 'casia_hwdb_1.0_1.1.tfrecord'
i = 0
with tf.io.TFRecordWriter(tfrecord_f) as tfrecord_writer:
@@ -60,7 +68,7 @@ def run():
hwdb = CASIAHWDBGNT(gnt)
for img, tagcode in hwdb.get_data_iter():
try:
img = cv.resize(img, (64, 64))
img = cv2.resize(img, (64, 64))
label = struct.pack('>H', tagcode).decode('gb2312')
label = label.replace('\x00', '')
index = charset.index(label)
@@ -68,11 +76,11 @@ def run():
example = tf.train.Example(features=tf.train.Features(
feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img]))
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()]))
}))
tfrecord_writer.write(example.SerializeToString())
if i%500:
logging.info('solved {} examples.'.format(i))
logging.info('solved {} examples. {}: {}'.format(i, label, index))
i += 1
except Exception as e:
logging.error(e)