Commit fb9722f2 by Pamela Osuna

some debugs

parent 57c0fc14
...@@ -181,6 +181,9 @@ def run_kfold(X_train, X_test, y_train, y_test, params): ...@@ -181,6 +181,9 @@ def run_kfold(X_train, X_test, y_train, y_test, params):
print("Area under the curve:", auc) print("Area under the curve:", auc)
# confusion matrix # confusion matrix
if i == 0:
cm = confusion_matrix(y_test[i].argmax(axis=1), y_pred.argmax(axis=1))
else:
cm+=confusion_matrix(y_test[i].argmax(axis=1), y_pred.argmax(axis=1)) cm+=confusion_matrix(y_test[i].argmax(axis=1), y_pred.argmax(axis=1))
#pr curve (contains 4 pr curves: one for each class) #pr curve (contains 4 pr curves: one for each class)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment