Files
DeepHSV/visualize/visualization.py
2019-05-06 17:43:25 +08:00

102 lines
3.5 KiB
Python

"""
@file: visualization.py
@time: 2018/6/17 15:03
@desc: Plot ROC curves from saved score files and re-export AUC summaries
       for TensorBoard.
"""
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.metrics import confusion_matrix, accuracy_score, roc_auc_score, roc_curve
def compute_er(y_true, y_prob):
    """Print test metrics at the best ROC threshold and return the ROC curve.

    The operating point maximises ``tpr + (1 - fpr)`` (Youden's J statistic
    plus one) over the thresholds returned by ``roc_curve``.

    Args:
        y_true: Ground-truth labels; samples equal to ``True`` (i.e. 1) are the
            positive class. Column vectors of shape (n, 1) are accepted.
        y_prob: Scores for the positive class, same length as ``y_true``.
            Column vectors of shape (n, 1) are accepted.

    Returns:
        (fpr, tpr): false- and true-positive-rate arrays from ``roc_curve``.
    """
    y_true = np.ravel(y_true)
    y_prob = np.ravel(y_prob)

    fpr, tpr, thresholds = roc_curve(y_true, y_prob, pos_label=True)
    best_threshold = thresholds[np.argmax(tpr + (1 - fpr))]

    # roc_curve treats a sample as positive when score >= threshold, so the
    # same comparison is needed to reproduce the chosen operating point.
    y_pred = y_prob >= best_threshold

    cm_test = confusion_matrix(y_true, y_pred)
    acc_test = accuracy_score(y_true, y_pred)
    # AUC summarises the whole curve, so it must be computed from the
    # continuous scores rather than from the binarised predictions.
    auc_test = roc_auc_score(y_true, y_prob)

    print('Test Accuracy: %s ' % acc_test)
    print('Test AUC: %s ' % auc_test)
    print('Test Confusion Matrix:')
    print(cm_test)
    return fpr, tpr
def read_y_prob(filename):
    """Load scores and labels from a two-column text file.

    Column 0 holds the predicted score and column 1 the ground-truth label.

    Args:
        filename: Path to a whitespace-delimited file readable by ``np.loadtxt``.

    Returns:
        (y_true, y_prob): two column vectors, each of shape (n, 1).

    Raises:
        ValueError: if the file does not contain exactly two columns.
    """
    # ndmin=2 keeps a single-row file as shape (1, 2); without it loadtxt
    # squeezes to shape (2,) and column selection fails.
    data = np.loadtxt(filename, ndmin=2)
    if data.shape[1] != 2:
        raise ValueError(
            'expected 2 columns (score, label) in %s, got %d'
            % (filename, data.shape[1]))
    # Slice with ranges (0:1, 1:2) to keep the (n, 1) column-vector shape.
    y_prob = data[:, 0:1]
    y_true = data[:, 1:2]
    return y_true, y_prob
def visualize_roc(output_path='ROC_CEDAR_with_backgroud', show=True):
    """Plot the ROC curves of three models on one figure and save it.

    Each model's score file is loaded with ``read_y_prob`` and turned into an
    ROC curve by ``compute_er``, which also prints that model's test metrics.

    Args:
        output_path: File name passed to ``savefig``. The default keeps its
            original spelling so existing output file names do not change.
        show: If True, display the figure with ``plt.show()`` after saving.
    """
    # (legend label, score file); list order is the plotting and legend order.
    models = [
        ('Siamese', 'distribution_siamese_CEDAR.txt'),
        ('2ChannelCNN', 'distribution_2ChannelsCNN_CEDAR.txt'),
        ('2Channel2logit', 'distribution_2Channel2logit_CEDAR.txt'),
    ]

    fig, ax = plt.subplots(figsize=(5, 5))
    for label, path in models:
        y_true, y_prob = read_y_prob(path)
        fpr, tpr = compute_er(y_true, y_prob)
        ax.plot(fpr, tpr, label=label)

    # Chance-level diagonal; left unlabeled so it is excluded from the legend.
    ax.plot([0, 1], [0, 1], color='navy', linestyle='--')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.0])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    # Style the legend locally instead of updating the global rcParams,
    # which would otherwise leak into every later figure in the process.
    ax.legend(fontsize=15, handlelength=2)

    fig.savefig(output_path, dpi=500)
    if show:
        plt.show()
if __name__ == '__main__':
    # Run only when executed as a script, so importing this module (e.g. to
    # reuse compute_er) does not read score files or open plot windows.
    visualize_roc()
def _copy_auc_scalars(session, write_op, auc_var, event_file, writer):
    """Re-log every scalar whose tag contains 'auc' from event_file to writer."""
    for event in tf.train.summary_iterator(event_file):
        for value in event.summary.value:
            if 'auc' in value.tag:
                summary = session.run(write_op, {auc_var: value.simple_value})
                writer.add_summary(summary, event.step)
    writer.flush()


def get_auc(
        cnn_events=r'C:\work\Projects\HWS_ID\test\2Channels\2channelscnn.Deep-Ubantu',
        softmax_events=r'C:\work\Projects\HWS_ID\test\2Channels\2channelsoftmax.Deep-Ubantu',
        train_logdir=r'C:\work\Projects\HWS_ID\test\2Channels\train',
        val_logdir=r'C:\work\Projects\HWS_ID\test\2Channels\val'):
    """Copy the AUC scalars of two event files into two TensorBoard log dirs.

    Every scalar whose tag contains ``'auc'`` is re-written under the single
    tag ``'auc'`` at its original step. Scalars from ``cnn_events`` go to
    ``train_logdir``, and scalars from ``softmax_events`` go to ``val_logdir``.

    NOTE(review): the CNN run is written to the "train" dir and the softmax run
    to the "val" dir. Presumably this is done so TensorBoard overlays the two
    runs on one "auc" chart rather than to mark real train/val splits; confirm.

    Args:
        cnn_events: Path of the first source event file.
        softmax_events: Path of the second source event file.
        train_logdir: Output log directory for the ``cnn_events`` scalars.
        val_logdir: Output log directory for the ``softmax_events`` scalars.
    """
    # Build the summary op in a private graph so that repeated calls neither
    # rename the tag (auc_1, ...) nor pick up stale summaries via merge_all().
    graph = tf.Graph()
    with graph.as_default():
        auc_var = tf.Variable(0.0)
        write_op = tf.summary.scalar("auc", auc_var)
        writer_train = tf.summary.FileWriter(train_logdir)
        writer_val = tf.summary.FileWriter(val_logdir)
        try:
            with tf.Session(graph=graph) as session:
                session.run(tf.global_variables_initializer())
                _copy_auc_scalars(session, write_op, auc_var,
                                  cnn_events, writer_train)
                _copy_auc_scalars(session, write_op, auc_var,
                                  softmax_events, writer_val)
        finally:
            # Closing flushes any pending events and releases the file handles.
            writer_train.close()
            writer_val.close()