# Files
# hand-writing-recognition/process_gnt.py
# 2019-08-25 19:07:44 +08:00
#
# 87 lines
# 3.1 KiB
# Python

import os
import struct
import pickle
import threading
import numpy as np
from PIL import Image
from concurrent.futures import ThreadPoolExecutor
# Read the images and labels stored in every .gnt file of a directory
def read_from_gnt_dir(gnt_dir):
    """Yield ``(image, label)`` pairs from every ``.gnt`` file in *gnt_dir*.

    Each record is a 10-byte header followed by ``width * height`` bytes of
    8-bit pixels. The header is decoded as:

    * bytes 0-3: total record size, little-endian
    * bytes 4-5: label, with byte 4 as the high byte, so that
      ``struct.pack('>H', label)`` reproduces the two bytes in file order
    * bytes 6-7: width, little-endian
    * bytes 8-9: height, little-endian

    Reading of a file stops at end of file, at a truncated header or pixel
    payload, or at the first record whose declared size does not equal
    header size plus ``width * height``.

    Args:
        gnt_dir: Directory scanned (non-recursively) for ``*.gnt`` files.
            Files are visited in sorted name order so the sequence of
            samples is the same on every run.

    Yields:
        ``(image, label)`` where ``image`` is a ``uint8`` array of shape
        ``(height, width)`` and ``label`` is a Python ``int``.
    """
    header_size = 10

    def one_file(f):
        """Yield the records of one open binary ``.gnt`` file."""
        while True:
            header = np.fromfile(f, dtype='uint8', count=header_size)
            # An empty read is EOF; a short read is a truncated trailing header.
            if header.size < header_size:
                break
            # Decode with struct instead of uint8 arithmetic: shifting and
            # adding NumPy uint8 scalars stays in uint8 on NumPy >= 2 and
            # silently overflows, corrupting size, width and height.
            sample_size, tag, width, height = struct.unpack(
                '<I2sHH', header.tobytes())
            label = int.from_bytes(tag, 'big')
            pixel_count = width * height
            if header_size + pixel_count != sample_size:
                break
            pixels = np.fromfile(f, dtype='uint8', count=pixel_count)
            # Truncated payload: stop instead of failing inside reshape().
            if pixels.size != pixel_count:
                break
            yield pixels.reshape((height, width)), label

    for file_name in sorted(os.listdir(gnt_dir)):
        if file_name.endswith('.gnt'):
            file_path = os.path.join(gnt_dir, file_name)
            with open(file_path, 'rb') as f:
                yield from one_file(f)
def gnt_to_img(gnt_dir, img_dir):
    """Export every sample found under *gnt_dir* as a PNG grouped by class.

    Each image is saved as ``<img_dir>/<class index>/<counter>.png``. The
    class index is looked up in the module-level ``char_dict`` using the
    GB2312-decoded label and is zero-padded to five digits. ``counter`` is
    the running sample number in read order. Saving is spread over a pool
    of four worker threads, and the call returns once all of them finish.

    Args:
        gnt_dir: Directory containing the ``.gnt`` files to convert.
        img_dir: Output directory; per-class sub-directories are created
            inside it as needed.
    """
    def save_img(label, image, counter):
        """Write one sample to its class directory as an RGB PNG."""
        label = struct.pack('>H', label).decode('gb2312')
        img = Image.fromarray(image)
        dir_name = os.path.join(img_dir, '%0.5d' % char_dict[label])
        # exist_ok avoids the check-then-create race: several workers can
        # reach the same not-yet-created class directory at the same time.
        os.makedirs(dir_name, exist_ok=True)
        img.convert('RGB').save(os.path.join(dir_name, str(counter) + '.png'))
        print("thread: {}, counter=".format(threading.current_thread().name), counter)

    def report_failure(future):
        """Print the exception of a failed save instead of dropping it."""
        error = future.exception()
        if error is not None:
            print("thread: {}, save failed: {!r}".format(
                threading.current_thread().name, error))

    counter = 0
    # Four worker threads; leaving the with-block waits for all pending
    # saves and shuts the pool down even if reading raises.
    with ThreadPoolExecutor(4) as thread_pool:
        for image, label in read_from_gnt_dir(gnt_dir=gnt_dir):
            future = thread_pool.submit(save_img, label, image, counter)
            future.add_done_callback(report_failure)
            counter += 1
if __name__ == "__main__":
    # Paths
    data_dir = r'./data'
    train_gnt_dir = os.path.join(data_dir, 'HWDB1.1trn_gnt')
    test_gnt_dir = os.path.join(data_dir, 'HWDB1.1tst_gnt')
    train_img_dir = os.path.join(data_dir, 'train')
    test_img_dir = os.path.join(data_dir, 'test')
    os.makedirs(train_img_dir, exist_ok=True)
    os.makedirs(test_img_dir, exist_ok=True)

    # Build the character -> class index mapping from the test set once and
    # cache it in the file 'char_dict'; later runs reuse the cached mapping.
    # char_dict must stay a module-level name because gnt_to_img reads it.
    if not os.path.exists('char_dict'):
        char_set = set()
        for _, tagcode in read_from_gnt_dir(gnt_dir=test_gnt_dir):
            tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')
            char_set.add(tagcode_unicode)
        # Indices follow the sorted order of the characters.
        char_dict = {char: index for index, char in enumerate(sorted(char_set))}
        print(len(char_dict))
        print("char_dict=", char_dict)
        with open('char_dict', 'wb') as f:
            pickle.dump(char_dict, f)
    else:
        # NOTE: pickle.load can execute arbitrary code; this is acceptable
        # only because 'char_dict' is the cache written by the branch above.
        with open('char_dict', 'rb') as f:
            char_dict = pickle.load(f)

    # Thread.start() returns None, so keep the Thread objects themselves and
    # start them separately; otherwise join() is called on None.
    train_thread = threading.Thread(target=gnt_to_img, args=(train_gnt_dir, train_img_dir))
    test_thread = threading.Thread(target=gnt_to_img, args=(test_gnt_dir, test_img_dir))
    train_thread.start()
    test_thread.start()
    train_thread.join()
    test_thread.join()