data ready
This commit is contained in:
102
dataset/casia_hwdb.py
Normal file → Executable file
102
dataset/casia_hwdb.py
Normal file → Executable file
@@ -6,9 +6,13 @@ we using this class to get .png and label from raw
|
||||
.gnt data
|
||||
|
||||
"""
|
||||
from alfred.dl.tf.common import mute_tf
|
||||
|
||||
mute_tf()
|
||||
import struct
|
||||
import numpy as np
|
||||
import cv2
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class CASIAHWDBGNT(object):
|
||||
@@ -24,61 +28,57 @@ class CASIAHWDBGNT(object):
|
||||
with open(self.f_p, 'rb') as f:
|
||||
while True:
|
||||
header = np.fromfile(f, dtype='uint8', count=header_size)
|
||||
if not header.size:
|
||||
if not header.size:
|
||||
break
|
||||
sample_size = header[0] + (header[1]<<8) + (header[2]<<16) + (header[3]<<24)
|
||||
tagcode = header[5] + (header[4]<<8)
|
||||
width = header[6] + (header[7]<<8)
|
||||
height = header[8] + (header[9]<<8)
|
||||
if header_size + width*height != sample_size:
|
||||
sample_size = header[0] + (header[1] << 8) + (
|
||||
header[2] << 16) + (header[3] << 24)
|
||||
tagcode = header[5] + (header[4] << 8)
|
||||
width = header[6] + (header[7] << 8)
|
||||
height = header[8] + (header[9] << 8)
|
||||
if header_size + width * height != sample_size:
|
||||
break
|
||||
image = np.fromfile(f, dtype='uint8', count=width*height).reshape((height, width))
|
||||
image = np.fromfile(f, dtype='uint8',
|
||||
count=width * height).reshape(
|
||||
(height, width))
|
||||
yield image, tagcode
|
||||
|
||||
|
||||
def resize_padding_or_crop(target_size, ori_img, padding_value=255):
|
||||
if len(ori_img.shape) == 3:
|
||||
res = np.zeros([ori_img.shape[0], target_size, target_size])
|
||||
else:
|
||||
res = np.ones([target_size, target_size])*padding_value
|
||||
end_x = target_size
|
||||
end_y = target_size
|
||||
start_x = 0
|
||||
start_y = 0
|
||||
if ori_img.shape[0] < target_size:
|
||||
end_x = int((target_size + ori_img.shape[0])/2)
|
||||
if ori_img.shape[1] < target_size:
|
||||
end_y = int((target_size + ori_img.shape[1])/2)
|
||||
if ori_img.shape[0] < target_size:
|
||||
start_x = int((target_size - ori_img.shape[0])/2)
|
||||
if ori_img.shape[1] < target_size:
|
||||
start_y = int((target_size - ori_img.shape[1])/2)
|
||||
res[start_x:end_x, start_y:end_y] = ori_img
|
||||
return np.array(res, dtype=np.uint8)
|
||||
def parse_example(record):
|
||||
features = tf.io.parse_single_example(record,
|
||||
features={
|
||||
'label':
|
||||
tf.io.FixedLenFeature([], tf.int64),
|
||||
'image':
|
||||
tf.io.FixedLenFeature([], tf.string),
|
||||
})
|
||||
img = tf.io.decode_raw(features['image'], out_type=tf.uint8)
|
||||
label = tf.cast(features['label'], tf.int32)
|
||||
return img, label
|
||||
|
||||
|
||||
def load_ds():
|
||||
input_files = ['casia_hwdb_1.0_1.1.tfrecord']
|
||||
ds = tf.data.TFRecordDataset(input_files)
|
||||
ds = ds.map(parse_example)
|
||||
return ds
|
||||
|
||||
|
||||
def load_charactors():
|
||||
a = open('charactors.txt', 'r').readlines()
|
||||
return [i.strip() for i in a]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
gnt = CASIAHWDBGNT('samples/1001-f.gnt')
|
||||
|
||||
full_img = np.zeros([900, 900], dtype=np.uint8)
|
||||
charset = []
|
||||
i = 0
|
||||
for img, tagcode in gnt.get_data_iter():
|
||||
# cv2.imshow('rr', img)
|
||||
try:
|
||||
label = struct.pack('>H', tagcode).decode('gb2312')
|
||||
img_padded = resize_padding_or_crop(90, img)
|
||||
col_idx = i%10
|
||||
row_idx = i//10
|
||||
full_img[row_idx*90:(row_idx+1)*90, col_idx*90:(col_idx+1)*90] = img_padded
|
||||
charset.append(label.replace('\x00', ''))
|
||||
if i >= 99:
|
||||
cv2.imshow('rrr', full_img)
|
||||
cv2.imwrite('sample.png', full_img)
|
||||
cv2.waitKey(0)
|
||||
print(charset)
|
||||
break
|
||||
i += 1
|
||||
except Exception as e:
|
||||
# print(e.with_traceback(0))
|
||||
print('decode error')
|
||||
continue
|
||||
ds = load_ds()
|
||||
charactors = load_charactors()
|
||||
for img, label in ds.take(9):
|
||||
# start training on model...
|
||||
img = img.numpy()
|
||||
img = np.resize(img, (64, 64))
|
||||
print(img.shape)
|
||||
label = label.numpy()
|
||||
label = charactors[label]
|
||||
print(label)
|
||||
cv2.imshow('rr', img)
|
||||
cv2.waitKey(0)
|
||||
# break
|
||||
|
||||
Reference in New Issue
Block a user