Commit fb9722f2 by Pamela Osuna

some debugs

parent 57c0fc14
......@@ -181,7 +181,10 @@ def run_kfold(X_train, X_test, y_train, y_test, params):
print("Area under the curve:", auc)
# confusion matrix
cm+=confusion_matrix(y_test[i].argmax(axis=1), y_pred.argmax(axis=1))
if i == 0:
cm = confusion_matrix(y_test[i].argmax(axis=1), y_pred.argmax(axis=1))
else:
cm+=confusion_matrix(y_test[i].argmax(axis=1), y_pred.argmax(axis=1))
#pr curve (contains 4 pr curves: one for each class)
recall, precision, average_prec = pr(N_CLASSES, y_test[i], y_pred)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment