85 lines
2.4 KiB
Python
Executable File
85 lines
2.4 KiB
Python
Executable File
"""
|
|
|
|
this is a wrapper handle CASIA_HWDB dataset
|
|
since original data is complicated
|
|
we using this class to get .png and label from raw
|
|
.gnt data
|
|
|
|
"""
|
|
from alfred.dl.tf.common import mute_tf
|
|
|
|
mute_tf()
|
|
import struct
|
|
import numpy as np
|
|
import cv2
|
|
import tensorflow as tf
|
|
|
|
|
|
class CASIAHWDBGNT(object):
    """
    Reader for a single CASIA-HWDB ``.gnt`` file.

    A .gnt file may contain many images and characters. Each sample is a
    10-byte header followed by ``width * height`` bytes of pixel data.
    """

    # Header layout, as decoded below:
    #   bytes 0-3  sample size (little-endian), header included
    #   bytes 4-5  tag code, byte 4 is the high byte
    #   bytes 6-7  image width (little-endian)
    #   bytes 8-9  image height (little-endian)
    _HEADER = struct.Struct('<IBBHH')

    def __init__(self, f_p):
        """Remember the path ``f_p`` of the .gnt file to read."""
        self.f_p = f_p

    def get_data_iter(self):
        """
        Yield ``(image, tagcode)`` for every sample in the file.

        ``image`` is a ``uint8`` array of shape ``(height, width)`` and
        ``tagcode`` is the 16-bit tag code as a Python int.

        Iteration stops at end of file, at a truncated header or pixel
        payload, or at the first sample whose declared size does not equal
        ``header size + width * height``.
        """
        header_size = self._HEADER.size
        with open(self.f_p, 'rb') as f:
            while True:
                header = np.fromfile(f, dtype='uint8', count=header_size)
                # Empty read means EOF; a short read means a truncated file.
                if header.size < header_size:
                    break
                # Decode through struct so the fields are plain Python ints.
                # Arithmetic on uint8 scalars can wrap around (e.g. a shift
                # by 8 stays uint8 under NumPy 2 promotion rules).
                sample_size, tag_hi, tag_lo, width, height = \
                    self._HEADER.unpack(header.tobytes())
                tagcode = (tag_hi << 8) + tag_lo
                pixel_count = width * height
                if header_size + pixel_count != sample_size:
                    break
                pixels = np.fromfile(f, dtype='uint8', count=pixel_count)
                # Stop on a truncated payload rather than failing in reshape.
                if pixels.size < pixel_count:
                    break
                image = pixels.reshape((height, width))
                yield image, tagcode
|
|
|
|
|
|
def parse_example(record):
    """
    Parse one serialized example into an ``(img, label)`` pair.

    The record is expected to carry an int64 ``'label'`` feature and a
    string ``'image'`` feature. The image bytes are decoded into a uint8
    tensor and the label is cast to int32.
    """
    feature_spec = {
        'label': tf.io.FixedLenFeature([], tf.int64),
        'image': tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(record, features=feature_spec)
    img = tf.io.decode_raw(parsed['image'], out_type=tf.uint8)
    label = tf.cast(parsed['label'], tf.int32)
    return img, label
|
|
|
|
|
|
def load_ds(input_files=None):
    """
    Build a dataset of parsed ``(img, label)`` pairs from TFRecord files.

    Args:
        input_files: TFRecord path(s) to read. Defaults to
            ``['dataset/hwdb_11.tfrecord']`` when omitted.

    Returns:
        A ``tf.data.TFRecordDataset`` mapped through ``parse_example``.
    """
    if input_files is None:
        input_files = ['dataset/hwdb_11.tfrecord']
    ds = tf.data.TFRecordDataset(input_files)
    return ds.map(parse_example)
|
|
|
|
|
|
def load_charactors(path='charactors.txt'):
    """
    Read the character table, one entry per line.

    Args:
        path: Text file to read. Defaults to ``'charactors.txt'`` in the
            current working directory.

    Returns:
        A list with one whitespace-stripped string per line of the file.
    """
    # The context manager closes the handle; the file is opened with the
    # platform default encoding.
    with open(path, 'r') as f:
        return [line.strip() for line in f]
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual check: show the first nine samples with their characters.
    dataset = load_ds()
    charactors = load_charactors()
    for img_tensor, label_tensor in dataset.take(9):
        # start training on model...
        # np.resize fits the decoded buffer into a 64x64 array.
        pixels = np.resize(img_tensor.numpy(), (64, 64))
        print(pixels.shape)
        # The label indexes into the character table.
        character = charactors[label_tensor.numpy()]
        print(character)
        cv2.imshow('rr', pixels)
        # Wait for a key press before moving on to the next sample.
        cv2.waitKey(0)
|