update code and readme
This commit is contained in:
@@ -5,6 +5,7 @@ import threading
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
|
||||
# 处理单个gnt文件获取图像与标签
|
||||
@@ -33,43 +34,53 @@ def read_from_gnt_dir(gnt_dir):
|
||||
|
||||
|
||||
def gnt_to_img(gnt_dir, img_dir):
    """Write every sample yielded by ``read_from_gnt_dir(gnt_dir)`` as a PNG.

    Each image is saved to ``<img_dir>/<class index, zero-padded to 5 digits>/
    <counter>.png``, where the class index is looked up in the module-level
    ``char_dict`` and ``counter`` is the running sample number within this call.
    The saving is fanned out to a pool of 4 worker threads.

    Args:
        gnt_dir: Directory handed to ``read_from_gnt_dir``.
        img_dir: Root directory under which per-class folders are created.
    """
    def save_img(label, image, counter):
        # The tag code is packed as a big-endian 2-byte value and decoded as
        # GB2312 to recover the character used as the key into char_dict.
        label = struct.pack('>H', label).decode('gb2312')
        img = Image.fromarray(image)
        dir_name = os.path.join(img_dir, '%0.5d' % char_dict[label])
        # exist_ok=True: several workers may reach the same class directory at
        # the same time; a check-then-mkdir sequence would race and raise.
        os.makedirs(dir_name, exist_ok=True)
        img.convert('RGB').save(dir_name + '/' + str(counter) + '.png')
        print("thread: {}, counter=".format(threading.current_thread().name), counter)

    def report_failure(future):
        # Exceptions raised inside a submitted task are stored on the future
        # and would otherwise vanish because nobody calls result().
        error = future.exception()
        if error is not None:
            print("save_img failed:", repr(error))

    counter = 0
    # 4 worker threads; leaving the with-block waits for all queued tasks and
    # shuts the pool down even if the reader raises part-way through.
    with ThreadPoolExecutor(4) as thread_pool:
        for image, label in read_from_gnt_dir(gnt_dir=gnt_dir):
            future = thread_pool.submit(save_img, label, image, counter)
            future.add_done_callback(report_failure)
            counter += 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Paths
    data_dir = r'./data'
    train_gnt_dir = os.path.join(data_dir, 'HWDB1.1trn_gnt')
    test_gnt_dir = os.path.join(data_dir, 'HWDB1.1tst_gnt')
    train_img_dir = os.path.join(data_dir, 'train')
    test_img_dir = os.path.join(data_dir, 'test')
    os.makedirs(train_img_dir, exist_ok=True)
    os.makedirs(test_img_dir, exist_ok=True)

    # Build the character -> class-index mapping once and cache it in the
    # 'char_dict' file; later runs load the cached mapping instead.
    if not os.path.exists('char_dict'):
        char_set = set()
        for _, tagcode in read_from_gnt_dir(gnt_dir=test_gnt_dir):
            tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')
            char_set.add(tagcode_unicode)
        char_list = list(char_set)
        # Sorting makes the index assigned to each character deterministic.
        char_dict = dict(zip(sorted(char_list), range(len(char_list))))
        print(len(char_dict))
        print("char_dict=", char_dict)

        with open('char_dict', 'wb') as f:
            pickle.dump(char_dict, f)
    else:
        # NOTE: pickle.load executes arbitrary code from the file; only safe
        # because 'char_dict' is a cache written by this same script.
        with open('char_dict', 'rb') as f:
            char_dict = pickle.load(f)

    # Thread.start() returns None, so keep the Thread objects themselves and
    # start them separately; otherwise join() would be called on None.
    train_thread = threading.Thread(target=gnt_to_img, args=(train_gnt_dir, train_img_dir))
    test_thread = threading.Thread(target=gnt_to_img, args=(test_gnt_dir, test_img_dir))
    train_thread.start()
    test_thread.start()
    train_thread.join()
    test_thread.join()
|
||||
|
||||
Reference in New Issue
Block a user