from cnn import CNN_Antifrag
from parser import parse_data
import itertools
from confusion_matrix import plot_confusion_matrix

"""
(c,b,e,o) will be read from the command line or a script
(c,b,e,o) corresponds to the combinations of the specific hyperparameters 
to build the model
c belongs to {0,1,2,3} and represents the layer architecture
b belongs to {0,1} and represents the batch size
e belongs to {0,1} and represents the number of epochs
o belongs to {0,1,2} and represents the balancing method
"""

# Candidate indices for each hyperparameter, in (c, b, e, o) order.
# b_ and e_ are pinned to a single option, so only the layer architecture
# and the balancing method are actually varied by the search.
c_ = list(range(4))
b_ = [1]
e_ = [0]
o_ = list(range(3))

n_experiences = 100

# Load the data set that every candidate model is evaluated on.
input_, output_ = parse_data(n_experiences, kind='linear')

# Cartesian product of all candidate values. itertools.product returns a
# one-shot iterator, so the grid can only be walked once.
combinations = itertools.product(c_, b_, e_, o_)
#%%
# Grid search: train one model per hyperparameter combination and keep the
# combination with the highest average AUC (second value returned by run_nn).
# Start from -inf / None so the first comparable score is always accepted and
# a search that selects nothing is reported explicitly instead of surfacing
# later as a NameError on max_params.
max_avg_auc = float("-inf")
max_params = None

for params in combinations:
    cnn = CNN_Antifrag(name="CNN_%d_%d_%d_%d" % params)
    avg_acc, avg_auc = cnn.run_nn(input_, output_, params)
    # Strict comparison: on ties the earliest combination wins. A NaN score
    # never compares greater, so it is never selected.
    if avg_auc > max_avg_auc:
        max_avg_auc = avg_auc
        max_params = params

if max_params is None:
    raise RuntimeError(
        "Grid search selected no parameters: the combination grid was empty "
        "or no run produced a comparable AUC."
    )

#%%
print("Best params:",max_params)
# once we have chosen the optimal parameters we can do the normal kfold
cnn = CNN_Antifrag(name="CNN_%d_%d_%d_%d" % max_params)
# run_kfold yields four results, unpacked as accuracy, AUC, confusion matrix
# and pr (presumably precision-recall data -- confirm against CNN_Antifrag).
acc, auc, cm, pr = cnn.run_kfold(input_, output_, max_params)
# TODO: add the precision-recall curve

#%%
# Display labels passed to the confusion-matrix plot, one per class.
# NOTE(review): presumably '~' marks negation of R / E and the order matches
# the row/column order of cm -- confirm against the model's class encoding.
labels = [
    '[~R & ~E]', 
    '[~R &  E]', 
    '[ R & ~E]', 
    '[ R &  E]'
    ]
#this function saves the matrix image automatically
plot_confusion_matrix(cm, labels) 

# Persist the summary metrics. The context manager guarantees the file is
# closed even if one of the writes raises.
# NOTE: open() does not create directories, so "out/" must already exist.
with open("out/acc_auc.txt", 'w+') as f:
    f.write("Average accuracy: " + str(acc)+"\n")
    f.write("Average area under the curve: " + str(auc))

