Commit 57c0fc14 by Pamela Osuna

added function thatcalculates average of pr curve for each class

parent 5b2de05d
......@@ -8,7 +8,7 @@ def pr(num_classes, y_test, y_pred):
recall = dict()
average_precision = dict()
for i in range(num_classes):
precision[i], recall[i], _ = precision_recall_curve(y_test_one_hot[:, i], y_pred[:, i])
precision[i], recall[i], _ = precision_recall_curve(y_test[:, i], y_pred[:, i])
average_precision[i] = average_precision_score(y_test[:, i], y_pred[:, i])
# A "micro-average": quantifying score on all classes jointly
......@@ -32,3 +32,37 @@ def plot_pr(recall, precision, average_precision):
plt.savefig("precision_recall_curve")
#plt.show()
plt.close()
def avg_pr(n_splits, num_classes, recs_k, precs_k, avgs_k):
prec_per_class = [[precs_k[k][i] for k in range(n_splits)] for i in range(num_classes)]
rec_per_class = [[recs_k[k][i] for k in range(n_splits)] for i in range(num_classes)]
avg_prec_per_class = [[avgs_k[k][i] for k in range(n_splits)] for i in range(num_classes)]
# First aggregate all points for every recall curve of one class
all_recall_per_class = []
for i in range(num_classes):
all_recall_per_class.append(np.unique(np.concatenate([recs_k[k][i] for k in range(n_splits)])))
mean_prec_per_class = []
for i in range(num_classes):
mean_prec_per_class.append(np.zeros_like(all_recall_per_class[i]))
avg_prec = np.zeros(num_classes)
for i in range(num_classes): # for a determinated class
for k in range(n_splits):
mean_prec_per_class[i] += np.interp(all_recall_per_class[i], rec_per_class[i][k][::-1],prec_per_class[i][k][::-1])
avg_prec[i]+= avg_prec_per_class[i][k]
print(avg_prec[i])
mean_prec_per_class[i] /= n_splits
avg_prec[i] /= n_splits
plt.figure()
plt.step(all_recall_per_class[i], mean_prec_per_class[i], color='b', alpha=0.2, where='post')
plt.fill_between(all_recall_per_class[i], mean_prec_per_class[i], alpha=0.2, color='b')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.ylim([0.0, 1.05])
plt.xlim([0.0, 1.0])
plt.title('Average precision score, over class {0}: AP={1:0.2f}'.format(i, avg_prec[i]))
plt.savefig("pr_curve_class_" +str(i))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment