data ready

This commit is contained in:
Your Name
2019-06-05 00:13:32 +08:00
parent bbf8928e7a
commit 919d89af4b
23 changed files with 364 additions and 68 deletions

BIN
dataset/.DS_Store vendored Executable file

Binary file not shown.

BIN
dataset/._.DS_Store Executable file

Binary file not shown.

3
dataset/.gitignore vendored Normal file → Executable file
View File

@@ -1,2 +1,3 @@
hwdb_raw/
*.tfrecord
*.tfrecord
casia_hwdb.pyhwdb_11.tfrecord

102
dataset/casia_hwdb.py Normal file → Executable file
View File

@@ -6,9 +6,13 @@ we use this class to get .png images and labels from raw
.gnt data
"""
from alfred.dl.tf.common import mute_tf
mute_tf()
import struct
import numpy as np
import cv2
import tensorflow as tf
class CASIAHWDBGNT(object):
@@ -24,61 +28,57 @@ class CASIAHWDBGNT(object):
with open(self.f_p, 'rb') as f:
while True:
header = np.fromfile(f, dtype='uint8', count=header_size)
if not header.size:
if not header.size:
break
sample_size = header[0] + (header[1]<<8) + (header[2]<<16) + (header[3]<<24)
tagcode = header[5] + (header[4]<<8)
width = header[6] + (header[7]<<8)
height = header[8] + (header[9]<<8)
if header_size + width*height != sample_size:
sample_size = header[0] + (header[1] << 8) + (
header[2] << 16) + (header[3] << 24)
tagcode = header[5] + (header[4] << 8)
width = header[6] + (header[7] << 8)
height = header[8] + (header[9] << 8)
if header_size + width * height != sample_size:
break
image = np.fromfile(f, dtype='uint8', count=width*height).reshape((height, width))
image = np.fromfile(f, dtype='uint8',
count=width * height).reshape(
(height, width))
yield image, tagcode
def resize_padding_or_crop(target_size, ori_img, padding_value=255):
    """Center an image on a square canvas of side ``target_size``.

    Dimensions smaller than ``target_size`` are padded with
    ``padding_value``; dimensions larger are center-cropped (the original
    implementation crashed on oversized inputs despite the function name).
    Both grayscale ``(H, W)`` and color ``(H, W, C)`` inputs are supported.

    Args:
        target_size: side length of the square output canvas.
        ori_img: input image as a numpy array, ``(H, W)`` or ``(H, W, C)``.
        padding_value: fill value for padded regions (default 255 = white,
            matching the white background of HWDB glyph images).

    Returns:
        ``np.uint8`` array of shape ``(target_size, target_size[, C])``.
    """
    img = ori_img
    h, w = img.shape[:2]
    # Center-crop any dimension that exceeds the target so the paste
    # below always fits inside the canvas.
    if h > target_size:
        off = (h - target_size) // 2
        img = img[off:off + target_size]
        h = target_size
    if w > target_size:
        off = (w - target_size) // 2
        img = img[:, off:off + target_size]
        w = target_size
    # Build the output canvas, preserving the channel axis when present.
    if img.ndim == 3:
        res = np.full([target_size, target_size, img.shape[2]],
                      padding_value, dtype=np.uint8)
    else:
        res = np.full([target_size, target_size], padding_value,
                      dtype=np.uint8)
    # Paste the (possibly cropped) image centered on the canvas.
    start_x = (target_size - h) // 2
    start_y = (target_size - w) // 2
    res[start_x:start_x + h, start_y:start_y + w] = img
    return np.array(res, dtype=np.uint8)
def parse_example(record):
    """Decode one serialized TFRecord example into ``(image, label)``.

    Each record is expected to carry a raw-bytes ``'image'`` feature and
    an int64 ``'label'`` feature.  The image is returned as a flat uint8
    tensor; the label is cast down to int32.
    """
    feature_spec = {
        'label': tf.io.FixedLenFeature([], tf.int64),
        'image': tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(record, features=feature_spec)
    image = tf.io.decode_raw(parsed['image'], out_type=tf.uint8)
    label = tf.cast(parsed['label'], tf.int32)
    return image, label
def load_ds():
    """Build a ``tf.data`` pipeline over the combined HWDB 1.0/1.1 TFRecord.

    Returns a dataset yielding ``(image, label)`` pairs as produced by
    ``parse_example``.
    """
    record_files = ['casia_hwdb_1.0_1.1.tfrecord']
    dataset = tf.data.TFRecordDataset(record_files)
    return dataset.map(parse_example)
def load_charactors():
    """Load the character set from ``charactors.txt``, one character per line.

    Returns:
        list[str]: stripped lines, index-aligned with the integer labels
        stored in the TFRecord.
    """
    # 'with' guarantees the file handle is closed; the original
    # ``open(...).readlines()`` leaked it.
    with open('charactors.txt', 'r') as f:
        return [line.strip() for line in f]
if __name__ == "__main__":
    # Smoke test: tile the first 100 glyphs of one writer's .gnt file into
    # a 10x10 mosaic image, then spot-check the TFRecord pipeline.
    gnt = CASIAHWDBGNT('samples/1001-f.gnt')
    # 900x900 canvas = 10x10 grid of 90px tiles.
    full_img = np.zeros([900, 900], dtype=np.uint8)
    charset = []
    i = 0
    for img, tagcode in gnt.get_data_iter():
        # cv2.imshow('rr', img)
        try:
            # GB2312 tagcode -> character; strip the trailing NUL byte
            # some codes decode with.
            label = struct.pack('>H', tagcode).decode('gb2312')
            img_padded = resize_padding_or_crop(90, img)
            # Fill the mosaic row-major, 10 tiles per row.
            col_idx = i%10
            row_idx = i//10
            full_img[row_idx*90:(row_idx+1)*90, col_idx*90:(col_idx+1)*90] = img_padded
            charset.append(label.replace('\x00', ''))
            if i >= 99:
                # Grid is full: display/save the mosaic and stop.
                cv2.imshow('rrr', full_img)
                cv2.imwrite('sample.png', full_img)
                cv2.waitKey(0)
                print(charset)
                break
            i += 1
        except Exception as e:
            # Some tagcodes are not valid GB2312; skip those samples.
            # print(e.with_traceback(0))
            print('decode error')
            continue

    # Sanity-check the TFRecord dataset: decode a few examples and show them.
    ds = load_ds()
    charactors = load_charactors()
    for img, label in ds.take(9):
        # start training on model...
        img = img.numpy()
        # assumes the stored raw image is 64*64 bytes — TODO confirm against
        # the writer in convert_to_tfrecord.py.
        img = np.resize(img, (64, 64))
        print(img.shape)
        label = label.numpy()
        label = charactors[label]
        print(label)
        cv2.imshow('rr', img)
        cv2.waitKey(0)
        # break

0
dataset/charactors.txt Normal file → Executable file
View File

40
dataset/convert_to_tfrecord.py Normal file → Executable file
View File

@@ -7,6 +7,7 @@ import cv2
from alfred.utils.log import logger as logging
import tensorflow as tf
import glob
import os
class CASIAHWDBGNT(object):
@@ -39,20 +40,27 @@ def run():
logging.info('got all {} gnt files.'.format(len(all_hwdb_gnt_files)))
logging.info('gathering charset...')
charset = []
for gnt in all_hwdb_gnt_files:
hwdb = CASIAHWDBGNT(gnt)
for img, tagcode in hwdb.get_data_iter():
try:
label = struct.pack('>H', tagcode).decode('gb2312')
label = label.replace('\x00', '')
charset.append(label)
except Exception as e:
continue
charset = sorted(set(charset))
if os.path.exists('charactors.txt'):
logging.info('found exist charactors.txt...')
with open('charactors.txt', 'r') as f:
charset = f.readlines()
charset = [i.strip() for i in charset]
else:
for gnt in all_hwdb_gnt_files:
hwdb = CASIAHWDBGNT(gnt)
for img, tagcode in hwdb.get_data_iter():
try:
label = struct.pack('>H', tagcode).decode('gb2312')
label = label.replace('\x00', '')
charset.append(label)
except Exception as e:
continue
charset = sorted(set(charset))
with open('charactors.txt', 'w') as f:
f.writelines('\n'.join(charset))
logging.info('all got {} charactors.'.format(len(charset)))
with open('charactors.txt', 'w') as f:
f.writelines('\n'.join(charset))
logging.info('{}'.format(charset[:10]))
tfrecord_f = 'casia_hwdb_1.0_1.1.tfrecord'
i = 0
with tf.io.TFRecordWriter(tfrecord_f) as tfrecord_writer:
@@ -60,7 +68,7 @@ def run():
hwdb = CASIAHWDBGNT(gnt)
for img, tagcode in hwdb.get_data_iter():
try:
img = cv.resize(img, (64, 64))
img = cv2.resize(img, (64, 64))
label = struct.pack('>H', tagcode).decode('gb2312')
label = label.replace('\x00', '')
index = charset.index(label)
@@ -68,11 +76,11 @@ def run():
example = tf.train.Example(features=tf.train.Features(
feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img]))
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()]))
}))
tfrecord_writer.write(example.SerializeToString())
if i%500:
logging.info('solved {} examples.'.format(i))
logging.info('solved {} examples. {}: {}'.format(i, label, index))
i += 1
except Exception as e:
logging.error(e)

View File

0
dataset/get_hwdb_1.0_1.1.sh Normal file → Executable file
View File