import pickle
import numpy as np
import tensorflow as tf
from keras import datasets
from tensorflow import keras
from keras.utils import to_categorical
import argparse
import yaml
from keras.optimizers import Adam
from test_core.load_test_data import load_cifar_corruption, load_cinic10_test, load_svhn_test, load_cifar10, load_tinyimage_test
import time


def load_model(full_path):
    """Load the Keras model stored at ``<full_path>_modelRes.keras``.

    Args:
        full_path: path prefix; ``"_modelRes.keras"`` is appended to it.

    Returns:
        The model returned by ``tf.keras.saving.load_model``.
    """
    model_file = full_path + "_modelRes.keras"
    return tf.keras.saving.load_model(model_file)

def single_model_evaluate(model, x_test, y_test, IFAcc):
    """Predict with a single model and score its predictions.

    Args:
        model: object exposing ``predict(x)``. Its output is presumably
            per-class probabilities with classes on the last axis (TODO confirm
            softmax output).
        x_test: inputs passed to ``model.predict``.
        y_test: targets for ``tf.keras.metrics.CategoricalAccuracy``
            (presumably one-hot); only used when ``IFAcc`` is truthy.
        IFAcc: when truthy, also compute categorical accuracy.

    Returns:
        ``(pred, acc, entropy)``: the raw predictions, the accuracy from
        ``CategoricalAccuracy`` (``None`` when ``IFAcc`` is falsy), and the
        per-sample entropy in bits, summed over the last axis.
    """
    probs = model.predict(x_test)
    # Shannon entropy in bits; the tiny offset keeps log2 finite when a
    # probability is exactly zero (0 * log2(1e-12) contributes 0).
    per_sample_entropy = -(probs * np.log2(probs + 1e-12)).sum(axis=-1)

    accuracy = None
    if IFAcc:
        metric = tf.keras.metrics.CategoricalAccuracy()
        metric.update_state(y_test, probs)
        accuracy = metric.result().numpy()
        metric.reset_state()
    return probs, accuracy, per_sample_entropy

def ensembl_evaluate(preds, y_test, IFAcc):
    """Average member predictions and decompose the ensemble's uncertainty.

    Args:
        preds: stacked member predictions. Axis 0 indexes ensemble members
            and the last axis indexes classes (presumably probabilities,
            e.g. shape (members, samples, classes) -- TODO confirm).
        y_test: targets for ``tf.keras.metrics.CategoricalAccuracy``; only
            used when ``IFAcc`` is truthy.
        IFAcc: when truthy, also compute categorical accuracy of the mean
            prediction.

    Returns:
        ``(pred_ensemble, acc, entropy)`` where:

        - ``pred_ensemble`` is the mean of ``preds`` over axis 0;
        - ``acc`` is the accuracy, or ``None`` when ``IFAcc`` is falsy;
        - ``entropy`` is a dict with per-sample values in bits:
          ``'TU'`` is the entropy of the mean prediction, ``'AU'`` is the mean
          of the members' entropies, and ``'EU'`` is ``TU - AU``.
    """
    eps = 1e-12

    def entropy_bits(p):
        # Shannon entropy in bits over the last (class) axis.
        return -np.sum(p * np.log2(p + eps), axis=-1)

    mean_probs = np.mean(preds, axis=0)
    total_unc = entropy_bits(mean_probs)
    aleatoric_unc = np.mean(entropy_bits(preds), axis=0)

    uncertainty = {
        'TU': total_unc,
        'EU': total_unc - aleatoric_unc,
        'AU': aleatoric_unc,
    }

    accuracy = None
    if IFAcc:
        metric = tf.keras.metrics.CategoricalAccuracy()
        metric.update_state(y_test, mean_probs)
        accuracy = metric.result().numpy()
        metric.reset_state()
    return mean_probs, accuracy, uncertainty

def _parse_corruption(dataset):
    """Parse ``'<prefix>_<corruption type>_<severity level>'`` into two ints.

    Args:
        dataset: dataset string such as ``'CIFAR10C_3_5'``.

    Returns:
        ``(cor_type, sev_level)`` as ints.

    Raises:
        ValueError: if the string does not split into exactly three
            ``'_'``-separated parts, or the last two are not integers.
    """
    parts = dataset.split('_')
    if len(parts) != 3:
        # Without both numbers there is nothing to load; fail immediately
        # rather than continuing with undefined corruption parameters.
        raise ValueError(f"Invalid dataset string format: {dataset!r}")
    return int(parts[1]), int(parts[2])


def _load_cifar100_test():
    """Load the CIFAR-100 test split, scaled to [0, 1] and channel-normalised.

    Returns:
        ``(x_test, y_test)`` where ``y_test`` is one-hot over 100 classes.
    """
    (_, _), (x_test, y_test) = datasets.cifar100.load_data()
    x_test = (x_test / 255.0).astype('float32')
    y_test = to_categorical(y_test, 100)
    # NOTE(review): these mean/std values look like CIFAR-10 channel
    # statistics while the images are CIFAR-100 -- confirm they match the
    # preprocessing used when the models were trained.
    mean = np.array([[[0.4914, 0.4822, 0.4465]]])
    std = np.array([[[0.2023, 0.1994, 0.2010]]])
    x_test = (x_test - mean) / std
    return x_test, y_test


def _load_eval_data(dataset):
    """Resolve ``dataset`` to its test inputs and targets.

    Returns:
        ``(x, y, with_acc)``. ``with_acc`` is True only for the CIFAR
        variants, the only ones for which accuracy is computed.

    Raises:
        ValueError: for an unknown dataset name or a malformed
            ``CIFAR10C`` string.
    """
    # Exact 'CIFAR10' must be tested before the 'CIFAR10C' substring check.
    if dataset == 'CIFAR10':
        x, y = _load_cifar100_test()
        return x, y, True
    if 'CIFAR10C' in dataset:
        cor_type, sev_level = _parse_corruption(dataset)
        x, y = load_cifar_corruption(cor_type, sev_level)
        return x, y, True
    if dataset == 'SVHN':
        x, y = load_svhn_test()
        return x, y, False
    if dataset == 'TinyImage':
        x, y = load_tinyimage_test()
        return x, y, False
    raise ValueError(f"Invalid Dataset Name: {dataset!r}. Try again...")


def _new_result(with_acc):
    """Return an empty result dict; the ``'acc'`` key exists only if ``with_acc``."""
    if with_acc:
        return {'pred': [], 'acc': [], 'entro': []}
    return {'pred': [], 'entro': []}


def _append_result(result, pred, acc, entropy):
    """Append one evaluation to ``result``; ``acc`` is stored only if tracked."""
    result['pred'].append(pred)
    if 'acc' in result:
        result['acc'].append(acc)
    result['entro'].append(entropy)


def _evaluate_ensembles(preds_all, groups, y, with_acc, n_ensembles):
    """Evaluate ``n_ensembles`` sub-ensembles whose members are ``groups[str(j)]``."""
    result = _new_result(with_acc)
    for j in range(n_ensembles):
        # Indexing with a one-element tuple selects along axis 0 even if the
        # stored member index is itself a tuple.
        members = preds_all[groups[str(j)],]
        pred_ens, acc, entropy = ensembl_evaluate(members, y, IFAcc=with_acc)
        _append_result(result, pred_ens, acc, entropy)
    return result


def snn_evaluation(dataset):
    """Evaluate the 15 SNN models and their 3- and 5-member ensembles.

    Args:
        dataset: one of ``'CIFAR10'`` (which loads the CIFAR-100 test split),
            ``'CIFAR10C_<type>_<level>'``, ``'SVHN'`` or ``'TinyImage'``.

    Returns:
        ``(cifar, cifar3, cifar5)``: result dicts for the single models and
        for the 3- and 5-member ensembles. Each holds lists under ``'pred'``
        and ``'entro'``, plus ``'acc'`` for the CIFAR variants. For
        ``'CIFAR10'`` the single-model dict also stores the one-hot labels
        under ``'label'``. Ensemble ``'entro'`` entries are dicts with keys
        ``'TU'``, ``'EU'`` and ``'AU'``.

    Raises:
        ValueError: if ``dataset`` is not recognised or a ``CIFAR10C`` string
            is malformed. This is checked before any model or pickle is loaded.
    """
    seeds = [0, 66, 99, 314, 524, 803, 888, 908, 1103, 1208, 7509, 11840, 40972, 46857, 54833]
    # Each ensemble pickle is read with keys '0'..'14'.
    n_ensembles = 15

    # Resolve (and validate) the dataset first so bad input fails fast.
    x, y, with_acc = _load_eval_data(dataset)

    ######### Member indices of the 3- and 5-model ensembles #########
    # NOTE(review): pickle.load can execute arbitrary code; these files must
    # come from a trusted source.
    with open('three_ensembles', 'rb') as file:
        three_groups = pickle.load(file)
    with open('five_ensembles', 'rb') as file:
        five_groups = pickle.load(file)

    ######### Load models #########
    model_list = [load_model('train_resultsSNNCIFAR100/' + str(seed)) for seed in seeds]

    ######### Single models #########
    cifar = _new_result(with_acc)
    for model in model_list:
        pred, acc, entropy = single_model_evaluate(model, x, y, IFAcc=with_acc)
        _append_result(cifar, pred, acc, entropy)

    ######### Ensembles #########
    preds_all = np.stack(cifar['pred'])
    cifar3 = _evaluate_ensembles(preds_all, three_groups, y, with_acc, n_ensembles)
    cifar5 = _evaluate_ensembles(preds_all, five_groups, y, with_acc, n_ensembles)

    if dataset == 'CIFAR10':
        cifar['label'] = y
    return cifar, cifar3, cifar5


def load_config(yaml_file):
    """Parse a YAML file with PyYAML's safe loader.

    Args:
        yaml_file: path to the YAML configuration file.

    Returns:
        The parsed document. ``main`` indexes it by key, so it is expected to
        be a mapping.
    """
    with open(yaml_file, 'r') as stream:
        return yaml.safe_load(stream)


def main():
    """Command-line entry point: evaluate the dataset named in a YAML config.

    Reads the ``Dataset`` key from the YAML file given on the command line and
    runs ``snn_evaluation`` on it. It prints the elapsed wall-clock seconds,
    then pickles the three result dicts to
    ``test_resultsSNNCIFAR100/<Dataset>_result``, ``..._result3`` and
    ``..._result5``.
    """
    import os

    # Accept a YAML file as a command-line argument
    parser = argparse.ArgumentParser(description='Process parameters from a YAML file.')
    parser.add_argument('config_file', type=str, help='Path to the YAML configuration file')
    args = parser.parse_args()

    config = load_config(args.config_file)
    dataset_name = config['Dataset']

    # Create the output directory before the (long) evaluation so results are
    # not lost to a FileNotFoundError when they are written.
    out_dir = 'test_resultsSNNCIFAR100'
    os.makedirs(out_dir, exist_ok=True)
    full_path = os.path.join(out_dir, dataset_name)

    start_time = time.time()
    result, result3, result5 = snn_evaluation(dataset_name)
    end_time = time.time()
    print(end_time - start_time)

    # Save test history
    outputs = (('_result', result), ('_result3', result3), ('_result5', result5))
    for suffix, payload in outputs:
        with open(full_path + suffix, 'wb') as file:
            pickle.dump(payload, file)

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
