add
This commit is contained in:
Binary file not shown.
2
dataset/.gitignore
vendored
2
dataset/.gitignore
vendored
@@ -1,3 +1,5 @@
|
||||
hwdb_raw/
|
||||
*.tfrecord
|
||||
casia_hwdb.pyhwdb_11.tfrecord
|
||||
HWDB1.1tst_gnt.tfrecord
|
||||
HWDB1.1trn_gnt.tfrecord
|
||||
Binary file not shown.
@@ -19,34 +19,6 @@ import os
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
class CASIAHWDBGNT(object):
    """
    Reader for a single CASIA-HWDB ``.gnt`` file.

    A .gnt file may contain many images and characters. Each sample is a
    10-byte header followed by ``width * height`` bytes of pixel data.
    """

    def __init__(self, f_p):
        """
        :param f_p: path of the .gnt file to read
        """
        self.f_p = f_p

    def get_data_iter(self):
        """
        Iterate over the samples stored in the file.

        Iteration stops at end of file, at a truncated header or image, or
        at the first sample whose declared size does not equal
        ``header size + width * height``.

        :return: generator of ``(image, tagcode)`` where ``image`` is a
            uint8 array of shape ``(height, width)`` and ``tagcode`` is an int
        """
        header_size = 10
        with open(self.f_p, 'rb') as f:
            while True:
                header = np.fromfile(f, dtype='uint8', count=header_size)
                if header.size < header_size:
                    # Empty read means end of file; a partial header means
                    # the file is truncated. Either way there is no sample.
                    break
                # Decode with Python ints: shifting np.uint8 scalars can
                # overflow 8 bits (they stay uint8 under NumPy 2 promotion).
                raw = header.tobytes()
                sample_size = int.from_bytes(raw[0:4], 'little')
                # Same byte order as header[5] + (header[4] << 8).
                tagcode = int.from_bytes(raw[4:6], 'big')
                width = int.from_bytes(raw[6:8], 'little')
                height = int.from_bytes(raw[8:10], 'little')
                pixel_count = width * height
                if header_size + pixel_count != sample_size:
                    break
                pixels = np.fromfile(f, dtype='uint8', count=pixel_count)
                if pixels.size != pixel_count:
                    # Image data cut short; reshape would raise otherwise.
                    break
                yield pixels.reshape((height, width)), tagcode
|
||||
|
||||
|
||||
def parse_example(record):
|
||||
features = tf.io.parse_single_example(record,
|
||||
features={
|
||||
@@ -57,34 +29,101 @@ def parse_example(record):
|
||||
})
|
||||
img = tf.io.decode_raw(features['image'], out_type=tf.uint8)
|
||||
img = tf.cast(tf.reshape(img, (64, 64)), dtype=tf.float32)
|
||||
label = tf.cast(features['label'], tf.int32)
|
||||
label = tf.cast(features['label'], tf.int64)
|
||||
return {'image': img, 'label': label}
|
||||
|
||||
|
||||
def parse_example_v2(record):
    """
    Decode one serialized example written in the latest record layout.

    In this layout the image bytes are stored at their original size, and
    the two dimensions are carried in the record as the 'width' and
    'height' features.

    :param record: one serialized example, as accepted by
        tf.io.parse_single_example
    :return: dict with 'image' (float32, reshaped to the stored
        ('width', 'height') pair) and 'label' (int64)
    """
    spec = {
        name: tf.io.FixedLenFeature([], tf.int64)
        for name in ('width', 'height', 'label')
    }
    spec['image'] = tf.io.FixedLenFeature([], tf.string)
    parsed = tf.io.parse_single_example(record, features=spec)

    pixels = tf.io.decode_raw(parsed['image'], out_type=tf.uint8)
    # The image is stored at its original size, so the target shape must
    # come from the record itself rather than from a fixed constant.
    dims = (parsed['width'], parsed['height'])
    image = tf.cast(tf.reshape(pixels, dims), dtype=tf.float32)
    return {'image': image, 'label': tf.cast(parsed['label'], tf.int64)}
|
||||
|
||||
|
||||
def load_ds():
|
||||
input_files = ['dataset/hwdb_11.tfrecord']
|
||||
input_files = ['dataset/HWDB1.1trn_gnt.tfrecord']
|
||||
ds = tf.data.TFRecordDataset(input_files)
|
||||
ds = ds.map(parse_example)
|
||||
return ds
|
||||
|
||||
|
||||
def load_characters():
|
||||
def load_val_ds():
    """
    Build the dataset read from 'dataset/HWDB1.1tst_gnt.tfrecord'.

    :return: a TFRecordDataset whose records are decoded by parse_example_v2
    """
    record_paths = ['dataset/HWDB1.1tst_gnt.tfrecord']
    return tf.data.TFRecordDataset(record_paths).map(parse_example_v2)
|
||||
|
||||
a = open(os.path.join(this_dir, 'charactors.txt'), 'r').readlines()
|
||||
|
||||
def load_characters():
|
||||
a = open(os.path.join(this_dir, 'characters.txt'), 'r').readlines()
|
||||
return [i.strip() for i in a]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ds = load_ds()
|
||||
val_ds = load_val_ds()
|
||||
val_ds = val_ds.shuffle(100)
|
||||
charactors = load_characters()
|
||||
for img, label in ds.take(9):
|
||||
# start training on model...
|
||||
img = img.numpy()
|
||||
img = np.resize(img, (64, 64))
|
||||
print(img.shape)
|
||||
label = label.numpy()
|
||||
label = charactors[label]
|
||||
print(label)
|
||||
cv2.imshow('rr', img)
|
||||
|
||||
is_show_combine = False
|
||||
if is_show_combine:
|
||||
combined = np.zeros([32*10, 32*20], dtype=np.uint8)
|
||||
i = 0
|
||||
res = ''
|
||||
for data in val_ds.take(200):
|
||||
# start training on model...
|
||||
img, label = data['image'], data['label']
|
||||
img = img.numpy()
|
||||
img = np.array(img, dtype=np.uint8)
|
||||
img = cv2.resize(img, (32, 32))
|
||||
label = label.numpy()
|
||||
label = charactors[label]
|
||||
print(label)
|
||||
row = i // 20
|
||||
col = i % 20
|
||||
print(i, col)
|
||||
print(row, col)
|
||||
combined[row*32: (row+1)*32, col*32: (col+1)*32] = img
|
||||
i += 1
|
||||
res += label
|
||||
cv2.imshow('rr', combined)
|
||||
print(res)
|
||||
cv2.imwrite('assets/combined.png', combined)
|
||||
cv2.waitKey(0)
|
||||
# break
|
||||
# break
|
||||
else:
|
||||
i = 0
|
||||
for data in val_ds.take(36):
|
||||
# start training on model...
|
||||
img, label = data['image'], data['label']
|
||||
img = img.numpy()
|
||||
img = np.array(img, dtype=np.uint8)
|
||||
print(img.shape)
|
||||
# img = cv2.resize(img, (64, 64))
|
||||
label = label.numpy()
|
||||
label = charactors[label]
|
||||
print(label)
|
||||
cv2.imshow('rr', img)
|
||||
cv2.imwrite('assets/{}.png'.format(i), img)
|
||||
i += 1
|
||||
cv2.waitKey(0)
|
||||
# break
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
generates HWDB data into tfrecord
|
||||
"""
|
||||
import sys
|
||||
import struct
|
||||
import numpy as np
|
||||
import cv2
|
||||
@@ -23,69 +24,83 @@ class CASIAHWDBGNT(object):
|
||||
with open(self.f_p, 'rb') as f:
|
||||
while True:
|
||||
header = np.fromfile(f, dtype='uint8', count=header_size)
|
||||
if not header.size:
|
||||
if not header.size:
|
||||
break
|
||||
sample_size = header[0] + (header[1]<<8) + (header[2]<<16) + (header[3]<<24)
|
||||
tagcode = header[5] + (header[4]<<8)
|
||||
width = header[6] + (header[7]<<8)
|
||||
height = header[8] + (header[9]<<8)
|
||||
if header_size + width*height != sample_size:
|
||||
sample_size = header[0] + (header[1] << 8) + (header[2] << 16) + (header[3] << 24)
|
||||
tagcode = header[5] + (header[4] << 8)
|
||||
width = header[6] + (header[7] << 8)
|
||||
height = header[8] + (header[9] << 8)
|
||||
if header_size + width * height != sample_size:
|
||||
break
|
||||
image = np.fromfile(f, dtype='uint8', count=width*height).reshape((height, width))
|
||||
image = np.fromfile(f, dtype='uint8', count=width * height).reshape((height, width))
|
||||
yield image, tagcode
|
||||
|
||||
|
||||
def run():
|
||||
all_hwdb_gnt_files = glob.glob('./hwdb_raw/HWDB1.1trn_gnt/*.gnt')
|
||||
def run(p):
|
||||
all_hwdb_gnt_files = glob.glob(os.path.join(p, '*.gnt'))
|
||||
logging.info('got all {} gnt files.'.format(len(all_hwdb_gnt_files)))
|
||||
logging.info('gathering charset...')
|
||||
charset = []
|
||||
if os.path.exists('charactors.txt'):
|
||||
logging.info('found exist charactors.txt...')
|
||||
with open('charactors.txt', 'r') as f:
|
||||
if os.path.exists('characters.txt'):
|
||||
logging.info('found exist characters.txt...')
|
||||
with open('characters.txt', 'r') as f:
|
||||
charset = f.readlines()
|
||||
charset = [i.strip() for i in charset]
|
||||
else:
|
||||
for gnt in all_hwdb_gnt_files:
|
||||
hwdb = CASIAHWDBGNT(gnt)
|
||||
for img, tagcode in hwdb.get_data_iter():
|
||||
try:
|
||||
label = struct.pack('>H', tagcode).decode('gb2312')
|
||||
label = label.replace('\x00', '')
|
||||
charset.append(label)
|
||||
except Exception as e:
|
||||
continue
|
||||
charset = sorted(set(charset))
|
||||
with open('charactors.txt', 'w') as f:
|
||||
f.writelines('\n'.join(charset))
|
||||
logging.info('all got {} charactors.'.format(len(charset)))
|
||||
if 'trn' in p:
|
||||
for gnt in all_hwdb_gnt_files:
|
||||
hwdb = CASIAHWDBGNT(gnt)
|
||||
for img, tagcode in hwdb.get_data_iter():
|
||||
try:
|
||||
label = struct.pack('>H', tagcode).decode('gb2312')
|
||||
label = label.replace('\x00', '')
|
||||
charset.append(label)
|
||||
except Exception as e:
|
||||
continue
|
||||
charset = sorted(set(charset))
|
||||
with open('characters.txt', 'w') as f:
|
||||
f.writelines('\n'.join(charset))
|
||||
logging.info('all got {} characters.'.format(len(charset)))
|
||||
logging.info('{}'.format(charset[:10]))
|
||||
|
||||
tfrecord_f = 'casia_hwdb_1.0_1.1.tfrecord'
|
||||
|
||||
tfrecord_f = os.path.basename(os.path.dirname(p)) + '.tfrecord'
|
||||
logging.info('tfrecord file saved into: {}'.format(tfrecord_f))
|
||||
i = 0
|
||||
with tf.io.TFRecordWriter(tfrecord_f) as tfrecord_writer:
|
||||
for gnt in all_hwdb_gnt_files:
|
||||
hwdb = CASIAHWDBGNT(gnt)
|
||||
for img, tagcode in hwdb.get_data_iter():
|
||||
try:
|
||||
img = cv2.resize(img, (64, 64))
|
||||
label = struct.pack('>H', tagcode).decode('gb2312')
|
||||
# why do you need resize?
|
||||
w = img.shape[0]
|
||||
h = img.shape[1]
|
||||
# img = cv2.resize(img, (64, 64))
|
||||
label = struct.pack('>H', tagcode).decode('gb2312')
|
||||
label = label.replace('\x00', '')
|
||||
index = charset.index(label)
|
||||
# save img, label as example
|
||||
example = tf.train.Example(features=tf.train.Features(
|
||||
feature={
|
||||
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
|
||||
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()]))
|
||||
}))
|
||||
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
|
||||
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()])),
|
||||
'width': tf.train.Feature(int64_list=tf.train.Int64List(value=[w])),
|
||||
'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[h])),
|
||||
}))
|
||||
tfrecord_writer.write(example.SerializeToString())
|
||||
if i%500:
|
||||
if i % 5000:
|
||||
logging.info('solved {} examples. {}: {}'.format(i, label, index))
|
||||
i += 1
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
e.with_traceback()
|
||||
continue
|
||||
logging.info('done.')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
if len(sys.argv) <= 1:
|
||||
logging.error('send a pattern like this: {}'.format('./hwdb_raw/HWDB1.1trn_gnt/'))
|
||||
else:
|
||||
p = sys.argv[1]
|
||||
logging.info('converting from: {}'.format(p))
|
||||
run(p)
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1trn_gnt.zip
|
||||
wget wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1tst_gnt.zip
|
||||
wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1tst_gnt.zip
|
||||
|
||||
Reference in New Issue
Block a user