########### PRECISION - RECALL CURVE ##########
from sklearn.metrics import average_precision_score
from sklearn.metrics import precision_recall_curve
import numpy as np
import matplotlib.pyplot as plt

def create_pr(num_classes, y_test, y_pred):
    """Compute precision-recall curves and average-precision scores.

    Per-class curves are computed from column ``i`` of ``y_test`` / ``y_pred``
    for ``i`` in ``range(num_classes)``. A micro-averaged curve and score are
    also computed, stored under the key ``"micro"``. The micro AP is printed.

    Args:
        num_classes: Number of class columns to evaluate.
        y_test: Ground truth, indexed as ``y_test[:, i]`` (2-D). Presumably a
            binary indicator matrix of shape (n_samples, num_classes) — confirm
            against the caller.
        y_pred: Scores with the same layout as ``y_test``.

    Returns:
        ``(recall, precision, average_precision)``: three dicts keyed by class
        index ``0..num_classes-1`` plus ``"micro"``.
    """
    precision, recall, average_precision = {}, {}, {}

    for cls in range(num_classes):
        truth, scores = y_test[:, cls], y_pred[:, cls]
        cls_precision, cls_recall, _thresholds = precision_recall_curve(truth, scores)
        precision[cls] = cls_precision
        recall[cls] = cls_recall
        average_precision[cls] = average_precision_score(truth, scores)

    # Micro-average: flatten every (sample, class) pair into a single binary
    # problem and score it jointly.
    flat_truth, flat_scores = y_test.ravel(), y_pred.ravel()
    precision["micro"], recall["micro"], _thresholds = precision_recall_curve(flat_truth, flat_scores)
    average_precision["micro"] = average_precision_score(y_test, y_pred, average="micro")
    print('Average precision score, micro-averaged over all classes: {0:0.2f}'.format(average_precision["micro"]))

    return recall, precision, average_precision

def plot_pr(recall, precision, average_precision):
    """Plot the micro-averaged precision-recall curve and save it as a PDF.

    Only the ``"micro"`` entries of the three dicts are used. The figure is
    written to ``out/precision_recall_curve.pdf`` and then closed. Nothing here
    creates the ``out`` directory, so it must already exist.
    """
    micro_recall = recall["micro"]
    micro_precision = precision["micro"]
    micro_ap = average_precision["micro"]

    fig = plt.figure()
    # Translucent blue step outline with the area underneath filled in.
    plt.step(micro_recall, micro_precision, color='b', alpha=0.2, where='post')
    plt.fill_between(micro_recall, micro_precision, alpha=0.2, color='b')

    # Fixed axis ranges; a little head-room above precision 1.0.
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.title('Average precision score, micro-averaged over all classes: AP={0:0.2f}'.format(micro_ap))

    plt.savefig("out/precision_recall_curve.pdf")
    plt.close(fig)


def avg_pr(n_splits, num_classes, recs_k, precs_k, avgs_k, *, out_dir="out", plot=True):
    """Average per-class precision-recall curves over cross-validation folds.

    For every class, each fold's precision curve is linearly interpolated onto
    the union of all recall values seen in any fold. The interpolated curves
    are then averaged, and the per-fold average-precision scores are averaged
    as well. Unless ``plot`` is False, one PDF per class is written to
    ``<out_dir>/pr_curve_class_<i>.pdf``.

    Args:
        n_splits: Number of folds. ``recs_k``, ``precs_k`` and ``avgs_k`` are
            indexed ``[fold][class]`` for folds ``0..n_splits-1``.
        num_classes: Number of classes ``0..num_classes-1`` to process.
        recs_k: Per-fold recall curves. Presumably the ``recall`` dicts
            returned by ``create_pr``, with recall in decreasing order as
            sklearn's ``precision_recall_curve`` returns it; verify against
            the caller.
        precs_k: Per-fold precision curves aligned element-wise with ``recs_k``.
        avgs_k: Per-fold average-precision scores, indexed ``[fold][class]``.
        out_dir: Directory the PDFs are written to (must already exist).
        plot: If False, only compute; no figures are created or saved.

    Returns:
        ``(recall_grids, mean_precisions, mean_ap)``:
        - ``recall_grids``: per-class lists of the sorted recall grids.
        - ``mean_precisions``: the fold-averaged precision on each grid.
        - ``mean_ap``: a ``(num_classes,)`` array of fold-averaged AP scores.

    Raises:
        ValueError: If ``n_splits`` is less than 1.
    """
    if n_splits < 1:
        raise ValueError("n_splits must be at least 1, got {0}".format(n_splits))

    recall_grids = []
    mean_precisions = []
    mean_ap = np.zeros(num_classes)

    for i in range(num_classes):
        # Force float so the in-place accumulation below cannot fail with a
        # casting error when the inputs happen to be integer-valued.
        fold_recalls = [np.asarray(recs_k[k][i], dtype=float) for k in range(n_splits)]
        fold_precs = [np.asarray(precs_k[k][i], dtype=float) for k in range(n_splits)]

        # Common sorted, de-duplicated recall grid for this class.
        grid = np.unique(np.concatenate(fold_recalls))

        mean_prec = np.zeros_like(grid)
        for rec, prec in zip(fold_recalls, fold_precs):
            # np.interp needs increasing x-coordinates; the recall curves are
            # presumably stored in decreasing order (sklearn's convention),
            # hence the reversal of both arrays.
            mean_prec += np.interp(grid, rec[::-1], prec[::-1])
        mean_prec /= n_splits

        mean_ap[i] = sum(avgs_k[k][i] for k in range(n_splits)) / n_splits

        recall_grids.append(grid)
        mean_precisions.append(mean_prec)

        if plot:
            fig = plt.figure()
            plt.step(grid, mean_prec, color='b', alpha=0.2, where='post')
            plt.fill_between(grid, mean_prec, alpha=0.2, color='b')
            plt.xlabel('Recall')
            plt.ylabel('Precision')
            plt.ylim([0.0, 1.05])
            plt.xlim([0.0, 1.0])
            plt.title('Average precision score, over class {0}: AP={1:0.2f}'.format(i, mean_ap[i]))
            plt.savefig(out_dir + "/pr_curve_class_" + str(i) + ".pdf")
            # Close each per-class figure; otherwise one stays open per class
            # for the life of the process.
            plt.close(fig)

    return recall_grids, mean_precisions, mean_ap
