Commit 5b2de05d by Pamela Osuna

pr curve

parent cef6a9d2
import models as m
import roc_auc as ra
from prec_recall import pr, avg_pr
import sys
from sklearn.metrics import confusion_matrix
......@@ -9,6 +10,7 @@ from imblearn.over_sampling import SMOTE
from sklearn.model_selection import KFold
from tensorflow.keras.utils import to_categorical
## global variables
N_SPLITS = 5 # for the kfold
N_CLASSES = 4
......@@ -124,13 +126,12 @@ def run_nn(input_, output_, n_experiences, params):
print("Average accuracy: ", total_acc)
print("Average area under the curve: ", total_auc)
return total_acc, total_auc, X_train_kfold, X_test_kfold, train_Y_one_hot, validation_Y_one_hot
return total_acc, total_auc, X_train_kfold, X_test_kfold, y_train_kfold, y_test_kfold
def run_kfold(X_train, X_test, y_train, y_test, params):
c, b, e = params
for i in range(N_SPLITS):
# change the labels from categorical to one-hot encoding
y_train[i] = to_categorical(y_train[i], num_classes = 4)
......@@ -156,8 +157,9 @@ def run_kfold(X_train, X_test, y_train, y_test, params):
total_acc = 0
total_auc = 0
cm_tab = []
pr_tab = []
precs_k = [] #it will contain the average pr curve for each class
recs_k = []
avgs_k = []
bs, ep = m.choose_batch_epochs(b,e)
......@@ -179,19 +181,23 @@ def run_kfold(X_train, X_test, y_train, y_test, params):
print("Area under the curve:", auc)
# confusion matrix
cm_tab.append(confusion_matrix(y_test[i].argmax(axis=1), y_pred.argmax(axis=1)))
cm+=confusion_matrix(y_test[i].argmax(axis=1), y_pred.argmax(axis=1))
#pr curve (contains 4 pr curves: one for each class)
recall, precision, average_prec = pr(N_CLASSES, y_test[i], y_pred)
recs_k.append(recall)
precs_k.append(precision)
avgs_k.append(average_prec)
#pr curve (1 for each class)
#average of acc, auc, cm, pr
total_acc = total_acc/(N_SPLITS)
total_auc = total_acc/(N_SPLITS)
cm = sum([cm_tab[j] for j in range(N_SPLITS)])/5
cm = cm/N_SPLITS
pr = avg_pr(N_SPLITS, N_CLASSES, recs_k, precs_k, avgs_k)
print("Average accuracy: ", total_acc)
print("Average area under the curve: ", total_auc)
return total_acc, total_auc, cm_tab, pr_tab
return total_acc, total_auc, cm, pr
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment