add dataset
This commit is contained in:
1
dataset/.gitignore
vendored
Normal file
1
dataset/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
hwdb_raw/
|
||||
@@ -59,28 +59,25 @@ def resize_padding_or_crop(target_size, ori_img, padding_value=255):
|
||||
if __name__ == "__main__":
    # Demo: read one writer's .gnt file and tile the first 100 glyphs into a
    # 10x10 grid of 90x90 cells, then dump the montage to sample.png.
    # (Reconstructed post-diff state; leftover per-sample cv2.imshow('rr', img)
    # / cv2.waitKey(0) / print(label) debug lines and the dead commented-out
    # 80-px variant were removed.)
    gnt = CASIAHWDBGNT('samples/1001-f.gnt')
    full_img = np.zeros([900, 900], dtype=np.uint8)
    charset = []
    i = 0
    for img, tagcode in gnt.get_data_iter():
        try:
            # Tag code is a big-endian GB2312 code point.
            label = struct.pack('>H', tagcode).decode('gb2312')
            img_padded = resize_padding_or_crop(90, img)
            col_idx = i % 10
            row_idx = i // 10
            full_img[row_idx*90:(row_idx+1)*90, col_idx*90:(col_idx+1)*90] = img_padded
            charset.append(label.replace('\x00', ''))
            if i >= 99:
                cv2.imshow('rrr', full_img)
                cv2.imwrite('sample.png', full_img)
                cv2.waitKey(0)
                print(charset)
                break
            i += 1
        except Exception as e:
            # Some tag codes are not valid GB2312; skip them.
            # print(e.with_traceback(0))
            print('decode error')
||||
|
||||
0
dataset/casia_hwdb_1.0_1.1.tfrecord
Normal file
0
dataset/casia_hwdb_1.0_1.1.tfrecord
Normal file
3755
dataset/charactors.txt
Normal file
3755
dataset/charactors.txt
Normal file
File diff suppressed because it is too large
Load Diff
83
dataset/convert_to_tfrecord.py
Normal file
83
dataset/convert_to_tfrecord.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
generates HWDB data into tfrecord
|
||||
"""
|
||||
import struct
|
||||
import numpy as np
|
||||
import cv2
|
||||
from alfred.utils.log import logger as logging
|
||||
import tensorflow as tf
|
||||
import glob
|
||||
|
||||
|
||||
class CASIAHWDBGNT(object):
    """Reader for a single CASIA-HWDB .gnt file.

    One .gnt file packs many handwritten-character samples. Each sample is a
    10-byte header -- little-endian uint32 total sample size, a 2-byte GB2312
    tag code, little-endian uint16 width and height -- followed by
    width*height grayscale pixel bytes.
    """

    def __init__(self, f_p):
        # Path of the .gnt file; opened lazily by get_data_iter().
        self.f_p = f_p

    def get_data_iter(self):
        """Yield (image, tagcode) pairs until EOF or a corrupt header.

        image is a uint8 ndarray of shape (height, width); tagcode is the
        integer GB2312 code point (big-endian byte order).
        """
        HEADER_BYTES = 10
        with open(self.f_p, 'rb') as stream:
            while True:
                head = np.fromfile(stream, dtype='uint8', count=HEADER_BYTES)
                if not head.size:
                    break  # clean end of file
                # Total record size, little-endian uint32.
                total = (int(head[0]) | (int(head[1]) << 8)
                         | (int(head[2]) << 16) | (int(head[3]) << 24))
                # Tag code bytes are stored high byte first.
                tag = int(head[5]) + (int(head[4]) << 8)
                w = int(head[6]) | (int(head[7]) << 8)
                h = int(head[8]) | (int(head[9]) << 8)
                if HEADER_BYTES + w * h != total:
                    break  # corrupt / truncated record: stop iterating
                pixels = np.fromfile(stream, dtype='uint8', count=w * h)
                yield pixels.reshape((h, w)), tag
|
||||
|
||||
|
||||
def run():
    """Build the charset file and pack all HWDB training samples into one TFRecord.

    Pass 1 scans every .gnt file to collect the set of distinct characters
    (written to charactors.txt, one per line, sorted); pass 2 re-scans the
    files, resizes each glyph to 64x64, and writes (image bytes, label index)
    Examples to casia_hwdb_1.0_1.1.tfrecord.
    """
    all_hwdb_gnt_files = glob.glob('./hwdb_raw/HWDB1.1trn_gnt/*.gnt')
    logging.info('got all {} gnt files.'.format(len(all_hwdb_gnt_files)))

    # Pass 1: gather every decodable character so label indices are stable.
    logging.info('gathering charset...')
    charset = []
    for gnt in all_hwdb_gnt_files:
        hwdb = CASIAHWDBGNT(gnt)
        for img, tagcode in hwdb.get_data_iter():
            try:
                label = struct.pack('>H', tagcode).decode('gb2312')
                label = label.replace('\x00', '')
                charset.append(label)
            except Exception:
                # some tag codes are not valid GB2312; skip them
                continue
    charset = sorted(set(charset))
    logging.info('all got {} charactors.'.format(len(charset)))
    with open('charactors.txt', 'w') as f:
        f.writelines('\n'.join(charset))

    # Index lookup table: avoids an O(len(charset)) list.index per sample.
    char_to_idx = {c: idx for idx, c in enumerate(charset)}

    # Pass 2: serialize every sample.
    tfrecord_f = 'casia_hwdb_1.0_1.1.tfrecord'
    i = 0
    with tf.io.TFRecordWriter(tfrecord_f) as tfrecord_writer:
        for gnt in all_hwdb_gnt_files:
            hwdb = CASIAHWDBGNT(gnt)
            for img, tagcode in hwdb.get_data_iter():
                try:
                    # BUG FIX: was `cv.resize` -- NameError (module is imported
                    # as cv2), so every sample was swallowed by the except below.
                    img = cv2.resize(img, (64, 64))
                    label = struct.pack('>H', tagcode).decode('gb2312')
                    label = label.replace('\x00', '')
                    index = char_to_idx[label]
                    # save img, label as example
                    # BUG FIX: BytesList needs bytes, not a numpy array.
                    example = tf.train.Example(features=tf.train.Features(
                        feature={
                            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()]))
                        }))
                    tfrecord_writer.write(example.SerializeToString())
                    # BUG FIX: `if i%500:` logged on every iteration EXCEPT
                    # multiples of 500; log every 500th example instead.
                    if i % 500 == 0:
                        logging.info('solved {} examples.'.format(i))
                    i += 1
                except Exception as e:
                    logging.error(e)
                    continue
    logging.info('done.')
|
||||
|
||||
# Script entry point: only convert when executed directly, not on import.
if __name__ == "__main__":
    run()
||||
2
dataset/get_hwdb_1.0_1.1.sh
Normal file
2
dataset/get_hwdb_1.0_1.1.sh
Normal file
# Download the CASIA HWDB1.1 training and test sets (GNT format).
wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1trn_gnt.zip
# BUG FIX: line read `wget wget http://...` -- the duplicated `wget` made the
# literal word "wget" a (failing) download target before the real URL.
wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1tst_gnt.zip
Reference in New Issue
Block a user