network runs
This commit is contained in:
BIN
dataset/__pycache__/casia_hwdb.cpython-36.pyc
Executable file
BIN
dataset/__pycache__/casia_hwdb.cpython-36.pyc
Executable file
Binary file not shown.
@@ -7,13 +7,17 @@ we using this class to get .png and label from raw
|
||||
|
||||
"""
|
||||
from alfred.dl.tf.common import mute_tf
|
||||
|
||||
mute_tf()
|
||||
import struct
|
||||
import numpy as np
|
||||
import cv2
|
||||
import tensorflow as tf
|
||||
|
||||
import os
|
||||
|
||||
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
class CASIAHWDBGNT(object):
|
||||
"""
|
||||
@@ -52,8 +56,9 @@ def parse_example(record):
|
||||
tf.io.FixedLenFeature([], tf.string),
|
||||
})
|
||||
img = tf.io.decode_raw(features['image'], out_type=tf.uint8)
|
||||
img = tf.cast(tf.reshape(img, (64, 64)), dtype=tf.float32)
|
||||
label = tf.cast(features['label'], tf.int32)
|
||||
return img, label
|
||||
return {'image': img, 'label': label}
|
||||
|
||||
|
||||
def load_ds():
|
||||
@@ -63,14 +68,15 @@ def load_ds():
|
||||
return ds
|
||||
|
||||
|
||||
def load_charactors():
|
||||
a = open('charactors.txt', 'r').readlines()
|
||||
def load_characters():
|
||||
|
||||
a = open(os.path.join(this_dir, 'charactors.txt'), 'r').readlines()
|
||||
return [i.strip() for i in a]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ds = load_ds()
|
||||
charactors = load_charactors()
|
||||
charactors = load_characters()
|
||||
for img, label in ds.take(9):
|
||||
# start training on model...
|
||||
img = img.numpy()
|
||||
|
||||
Reference in New Issue
Block a user