import tensorflow as tf
from tensorflow import keras
from keras.optimizers import Adam
from keras import datasets
from keras.utils import to_categorical
import numpy as np
import time
import argparse
import yaml
import logging
import pickle

from models.crenet import CreNetModel
from models.training_loop import training_loop


def load_config(yaml_file):
    """Read a YAML configuration file and return the parsed document.

    Parsing uses PyYAML's safe loader, so only standard YAML tags are
    accepted (no arbitrary Python object construction).

    Args:
        yaml_file: path to the YAML file.

    Returns:
        The parsed YAML document (``None`` for an empty file).
    """
    # YAML streams are UTF-8; be explicit so a non-UTF-8 platform locale
    # does not mis-decode non-ASCII values in the config.
    with open(yaml_file, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)
    return config


def _build_early_stopping(es):
    """Return the early-stopping predicate configured by the ``EarlyStopping`` section.

    The ``Monitor``, ``MinDelta``, ``Patience`` and ``StartEpoch`` entries of
    ``es`` become the defaults of the returned callable. The helper is only
    called when early stopping is enabled, so those keys are required only
    in that case.

    Args:
        es: the ``EarlyStopping`` mapping from the YAML config.

    Returns:
        A callable ``MyEarlyStopping(log, epoch, ...)`` returning a bool.
    """

    def MyEarlyStopping(log, epoch, monitor=es['Monitor'], min_delta=es['MinDelta'],
                        patience=es['Patience'], verbose=True, mode="auto",
                        start_from_epoch=es['StartEpoch']):
        """Decide whether training should stop after ``epoch``.

        Args:
            log: metric history; ``log[monitor][i]`` is read as the value for
                0-based epoch ``i`` (presumably a dict of per-epoch lists built
                by ``training_loop`` -- TODO confirm).
            epoch: 1-based number of the epoch that just finished.
            monitor: metric key to watch.
            min_delta: minimum change that counts as an improvement.
            patience: how many preceding epochs the current value is compared to.
            verbose: print a message when stopping is triggered.
            mode: "min", "max" or "auto". "auto" maximises metrics whose name
                ends in "Acc-U"/"Acc-L" and minimises everything else.
            start_from_epoch: stopping is never triggered while
                ``epoch - 1 <= start_from_epoch``.

        Returns:
            True if the current value fails to improve by ``min_delta`` over
            every one of the previous ``patience`` epochs, otherwise False.
        """
        # Determine the comparison direction.
        if mode not in ("auto", "min", "max"):
            logging.warning(
                "EarlyStopping mode %s is unknown, fallback to auto mode.",
                mode,
            )
            mode = "auto"
        if mode == "min":
            monitor_op = np.less
        elif mode == "max":
            monitor_op = np.greater
        elif monitor.endswith(("Acc-U", "Acc-L")):
            monitor_op = np.greater
        else:
            monitor_op = np.less

        # Sign min_delta so that "current - min_delta" beating a past value
        # always means a genuine improvement in the monitored direction.
        if monitor_op is np.less:
            min_delta *= -1

        # Epoch numbering starts from 1; convert to an index into the log.
        index = epoch - 1
        if index <= start_from_epoch or index < patience:
            return False

        stalled = all(
            not monitor_op(log[monitor][index] - min_delta, log[monitor][i])
            for i in range(index - patience, index)
        )
        if stalled and verbose:
            print("Early stopping triggered.")
        return stalled

    return MyEarlyStopping


def main():
    """Train a CreNetModel on CIFAR-10 with hyper-parameters from a YAML file.

    The YAML path is the single positional command-line argument. The last
    10,000 training images are held out for validation. ``training_loop``
    receives ``train_results/<Delta>/<ExpNum>_<Seed>`` as its ``path``, and its
    return value is pickled to ``train_results/<Delta>/his/<ExpNum>_<Seed>_result``.
    """
    import os  # only needed to create the output directory

    # Accept a YAML file as a command-line argument
    parser = argparse.ArgumentParser(description='Process parameters from a YAML file.')
    parser.add_argument('config_file', type=str, help='Path to the YAML configuration file')
    args = parser.parse_args()

    config = load_config(args.config_file)

    # Access hyperparameters from the loaded configuration
    exp_num = config['ExpNum']
    seed = config['Seed']
    delta = config['Delta']
    batch_size = config['BatchSize']
    epochs = config['Epochs']
    learning_rate = config['LearningRate']
    pre_weights = config['PreWeights']
    verbose = config['Verbose']
    backbone = config['Backbone']
    es = config['EarlyStopping']

    # Define the save paths
    full_path = 'train_results/' + str(delta) + '/' + str(exp_num) + '_' + str(seed)
    full_path_his = 'train_results/' + str(delta) + '/his/' + str(exp_num) + '_' + str(seed)
    # Create the history directory before training, so a finished run is not
    # lost to a missing folder when the result is pickled at the end.
    os.makedirs(os.path.dirname(full_path_his), exist_ok=True)

    # Set random seed
    keras.utils.set_random_seed(seed)
    tf.config.experimental.enable_op_determinism()

    # Prepare the training dataset. The official CIFAR-10 test split is not
    # used here; validation is carved from the training split instead.
    (x_train, y_train), _ = datasets.cifar10.load_data()
    # Cast before scaling: the float32 result is bit-identical to dividing in
    # float64 and casting afterwards, but avoids a float64 copy of the data.
    x_train = x_train.astype('float32') / 255.0
    y_train = to_categorical(y_train, 10)

    # Hold out the last 10,000 training samples for validation.
    num_val = 10000
    x_val, y_val = x_train[-num_val:], y_train[-num_val:]
    x_train, y_train = x_train[:-num_val], y_train[:-num_val]

    # Early-stopping predicate (None disables it in training_loop).
    my_stopping = _build_early_stopping(es) if es['Enable'] else None

    # Build and compile CreNet
    opt = Adam(learning_rate=learning_rate)
    creNet = CreNetModel(backbone=backbone, classes=10, input_shape=(32, 32, 3), delta=delta, pre_weights=pre_weights, task='multi')
    creNet.compile(optimizer=opt)

    # Training process (wall-clock duration is printed afterwards)
    start = time.perf_counter()
    result = training_loop(
        creNet, x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=verbose,
        val_data=(x_val, y_val), val_batch_size=batch_size, path=full_path, lr_schedule=None,
        early_stopping=my_stopping,
    )
    print(time.perf_counter() - start)

    # Save training history
    with open(full_path_his + '_result', 'wb') as file:
        pickle.dump(result, file)

# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
    