This commit is contained in:
Your Name
2019-06-07 19:04:00 +08:00
parent c8df372f63
commit aed8c10d71
87 changed files with 782 additions and 215 deletions

View File

@@ -19,34 +19,6 @@ import os
this_dir = os.path.dirname(os.path.abspath(__file__))
class CASIAHWDBGNT(object):
"""
A .gnt file may contains many images and charactors
"""
def __init__(self, f_p):
self.f_p = f_p
def get_data_iter(self):
header_size = 10
with open(self.f_p, 'rb') as f:
while True:
header = np.fromfile(f, dtype='uint8', count=header_size)
if not header.size:
break
sample_size = header[0] + (header[1] << 8) + (
header[2] << 16) + (header[3] << 24)
tagcode = header[5] + (header[4] << 8)
width = header[6] + (header[7] << 8)
height = header[8] + (header[9] << 8)
if header_size + width * height != sample_size:
break
image = np.fromfile(f, dtype='uint8',
count=width * height).reshape(
(height, width))
yield image, tagcode
def parse_example(record):
features = tf.io.parse_single_example(record,
features={
@@ -57,34 +29,101 @@ def parse_example(record):
})
img = tf.io.decode_raw(features['image'], out_type=tf.uint8)
img = tf.cast(tf.reshape(img, (64, 64)), dtype=tf.float32)
label = tf.cast(features['label'], tf.int32)
label = tf.cast(features['label'], tf.int64)
return {'image': img, 'label': label}
def parse_example_v2(record):
"""
latest version format
:param record:
:return:
"""
features = tf.io.parse_single_example(record,
features={
'width':
tf.io.FixedLenFeature([], tf.int64),
'height':
tf.io.FixedLenFeature([], tf.int64),
'label':
tf.io.FixedLenFeature([], tf.int64),
'image':
tf.io.FixedLenFeature([], tf.string),
})
img = tf.io.decode_raw(features['image'], out_type=tf.uint8)
# we can not reshape since it stores with original size
w = features['width']
h = features['height']
img = tf.cast(tf.reshape(img, (w, h)), dtype=tf.float32)
label = tf.cast(features['label'], tf.int64)
return {'image': img, 'label': label}
def load_ds():
input_files = ['dataset/hwdb_11.tfrecord']
input_files = ['dataset/HWDB1.1trn_gnt.tfrecord']
ds = tf.data.TFRecordDataset(input_files)
ds = ds.map(parse_example)
return ds
def load_characters():
def load_val_ds():
input_files = ['dataset/HWDB1.1tst_gnt.tfrecord']
ds = tf.data.TFRecordDataset(input_files)
ds = ds.map(parse_example_v2)
return ds
a = open(os.path.join(this_dir, 'charactors.txt'), 'r').readlines()
def load_characters():
a = open(os.path.join(this_dir, 'characters.txt'), 'r').readlines()
return [i.strip() for i in a]
if __name__ == "__main__":
ds = load_ds()
val_ds = load_val_ds()
val_ds = val_ds.shuffle(100)
charactors = load_characters()
for img, label in ds.take(9):
# start training on model...
img = img.numpy()
img = np.resize(img, (64, 64))
print(img.shape)
label = label.numpy()
label = charactors[label]
print(label)
cv2.imshow('rr', img)
is_show_combine = False
if is_show_combine:
combined = np.zeros([32*10, 32*20], dtype=np.uint8)
i = 0
res = ''
for data in val_ds.take(200):
# start training on model...
img, label = data['image'], data['label']
img = img.numpy()
img = np.array(img, dtype=np.uint8)
img = cv2.resize(img, (32, 32))
label = label.numpy()
label = charactors[label]
print(label)
row = i // 20
col = i % 20
print(i, col)
print(row, col)
combined[row*32: (row+1)*32, col*32: (col+1)*32] = img
i += 1
res += label
cv2.imshow('rr', combined)
print(res)
cv2.imwrite('assets/combined.png', combined)
cv2.waitKey(0)
# break
# break
else:
i = 0
for data in val_ds.take(36):
# start training on model...
img, label = data['image'], data['label']
img = img.numpy()
img = np.array(img, dtype=np.uint8)
print(img.shape)
# img = cv2.resize(img, (64, 64))
label = label.numpy()
label = charactors[label]
print(label)
cv2.imshow('rr', img)
cv2.imwrite('assets/{}.png'.format(i), img)
i += 1
cv2.waitKey(0)
# break