""" Distillation evaluation on EC2 using dataset stored in s3.
    First need to run setupEC2.sh or launch instance from existing AMI.
"""

import boto3
import fire
import shutil, os, time
from datetime import datetime
import numpy as np
import pandas as pd

import autogluon as ag
from autogluon import TabularPrediction as task
from autogluon.utils.tabular.ml.constants import MULTICLASS, BINARY, REGRESSION
from autogluon.utils.tabular.ml.utils import default_holdout_frac

from autogluon_utils.benchmarking.distill_benchmark.utils import get_gibbs_name, s3_sync_folder, filter_gibbs_models
from autogluon_utils.benchmarking.distill_benchmark.configs import *


def evaluate(dataset_name, profile, methods, tag=None):
    """ Fit (or load) an AutoGluon predictor for one dataset, run the requested distillation methods,
    and sync the resulting models / leaderboards to s3.

    User specifies command-line args (see run() for how to invoke this):
        dataset_name (str): dataset to evaluate. Must be listed in MULTICLASS_DATASETS or BINARY_DATASETS
            (matched on the lower-cased name) or in REGRESSION_DATASETS.
        profile (str): key of PROFILES selecting subsample size, time limits and distillation factors.
        methods (list[str] or str): methods to run. A non-list value must be one of
            ALL_METHODS, ALL_METHODS_GIBBS_SOME, ALL_METHODS_NO_GIBBS or FIT.
            A list may contain ALL_DISTILL / ALL_DISTILL_GIBBS_SOME / ALL_DISTILL_NO_GIBBS, which are expanded,
            and an entry starting with LOAD_FROM_TAG to load a predictor saved under another profile/tag.
            The list passed in is not modified.
        tag (str, optional): TAG to apply to s3 folders generated from this run.
            If None or empty, a UTC timestamp is used.

    Example:
        dataset_name = 'Helena'
        profile = 'PROFILE_FAST'
        methods = 'ALL_METHODS'
        tag = None  # or tag = 'run_test'

    Raises:
        ValueError: unknown profile, invalid method specification, unknown dataset name,
            or no usable Gibbs-sample files found for a requested Gibbs method.

    Side effects:
        Writes under /home/ubuntu/distillmodels/, repeatedly syncs that folder to
        s3://BUCKET/results/<dataset_name>/<profile>/<tag>/, and may rebind the module-level GIBBS_PREFIX.
    """
    global GIBBS_PREFIX  # rebound below depending on dataset type / requested methods
    ##  Other args ##
    verbosity = 3
    FOLD = 0  # not config for now
    fold_str = str(FOLD)
    small_model_only = True  # only gibbs-augmented data from smaller transformer is considered
    ## End of Args ##

    ## Parse user args:
    profile_name = profile
    if profile in PROFILES:
        profile = PROFILES[profile]
    else:
        raise ValueError(f"{profile} is unknown profile.")

    method_possibilities = [FIT] + DISTILL_METHODS_GIBBS_POSSIBILITIES + DISTILL_METHODS_LIST
    if not isinstance(methods, list):
        if methods == ALL_METHODS:
            methods = [FIT] + DISTILL_METHODS_LIST
        elif methods == ALL_METHODS_GIBBS_SOME:
            methods = [FIT] + DISTILL_METHODS_GIBBS_SOME
        elif methods == ALL_METHODS_NO_GIBBS:
            methods = [FIT] + DISTILL_METHODS_NO_GIBBS
        elif methods == FIT:
            methods = [FIT]
        else:
            raise ValueError(f"method {methods} must be specified in list unless it is one of the ALL_METHODS_* methods")
    else:
        methods = list(methods)  # work on a copy so the caller's list is never mutated
        if ALL_DISTILL in methods:
            methods.remove(ALL_DISTILL)
            methods += DISTILL_METHODS_LIST
        elif ALL_DISTILL_GIBBS_SOME in methods:
            methods.remove(ALL_DISTILL_GIBBS_SOME)
            methods += DISTILL_METHODS_GIBBS_SOME
        elif ALL_DISTILL_NO_GIBBS in methods:
            methods.remove(ALL_DISTILL_NO_GIBBS)
            methods += DISTILL_METHODS_NO_GIBBS

    if tag is None or tag == "":
        tag = datetime.utcnow().strftime('%Y_%m_%d-%H_%M_%S')
        print(f"No tag provided, will set tag={tag}")

    subsample_size = profile['subsample_size']
    time_limits = profile['time_limits']
    distill_time = time_limits * profile['distill_time_factor']
    distill_size_factor = profile['distill_size_factor']
    gan_epochs = profile.get('gan_epochs', FULL_GAN_EPOCHS)
    print(f"Run tag={tag} on dataset={dataset_name} with: \n")
    print(profile)
    print(methods)

    # Ensure all valid methods and parse tag loader
    tag_loader_method = None
    for meth in methods:
        is_tag_loader = (meth[:len(LOAD_FROM_TAG)] == LOAD_FROM_TAG)
        if is_tag_loader:
            tag_loader_method = meth  # if several are given, the last one wins
        if not is_tag_loader and meth not in method_possibilities:
            raise ValueError(f"{meth} is invalid method.")
    if tag_loader_method is not None:
        print(f"tag_loader detected: {tag_loader_method}")

    ## Paths/names:
    dataset_shortname = get_gibbs_name(dataset_name)

    save_path = "/home/ubuntu/distillmodels/"
    s3prefix_notagprofile = f"results/{dataset_name}/"
    s3prefix_save_path = f"results/{dataset_name}/{profile_name}/{tag}/"  # where save_path will be synced in s3
    os.makedirs(save_path, exist_ok=True)

    def sync_to_s3():
        # Upload everything currently in save_path to this run's s3 results folder.
        s3_sync_folder(local_folder=save_path, s3_bucket=BUCKET, s3_prefix=s3prefix_save_path, to_s3=True)

    train_filename = 'raw_train.csv'
    test_filename = 'raw_test.csv'
    postdistill_ldr_csv = save_path + 'DistillLeaderboard.csv'  # where final leaderboard is saved
    predistill_ldr_csv = save_path + 'PredistillLeaderboard.csv'  # where original fit() leaderboard (pre-distillation) is saved
    metadata_csv = save_path + 'metadata.csv'

    # Pick the s3 folder holding Gibbs sampled data for augmentation:
    if dataset_name in MIXED_DATASETS:
        GIBBS_PREFIX = GIBBS_PREFIX_MIXEDDATA
    elif dataset_name in REGRESSION_DATASETS:
        GIBBS_PREFIX = GIBBS_PREFIX_REGRESSDATA
    if DISTILL_GIBBS_200 in methods:  # takes precedence over the dataset-specific prefixes above
        GIBBS_PREFIX = GIBBS_PREFIX_200

    agmodel_folder = 'agModels/'
    savedir = save_path + agmodel_folder

    label_column = '__label__'
    if dataset_name.lower() in MULTICLASS_DATASETS:
        problem_type = MULTICLASS
        eval_metric = 'accuracy'  # Only acc for now
        s3_dataset_folder = 's3://'+BUCKET+'/rawdata/'+dataset_name+'/fold_'+fold_str+'/'
    elif dataset_name.lower() in BINARY_DATASETS:
        problem_type = BINARY
        eval_metric = 'accuracy'  # Only acc for now
        s3_dataset_folder = 's3://'+BUCKET+'/rawdata/'+dataset_name+'/fold_'+fold_str+'/'
    elif dataset_name in REGRESSION_DATASETS:
        eval_metric = None  # RMSE by default
        problem_type = REGRESSION
        s3_dataset_folder = 's3://'+BUCKET+'/Regression/RegressionDataframes/'+dataset_name+'/fold_'+fold_str+'/'
    else:
        raise ValueError(f"{dataset_name} is invalid dataset-name.")

    # Load data:
    train_data = task.Dataset(file_path=s3_dataset_folder + train_filename)
    test_data = task.Dataset(file_path=s3_dataset_folder + test_filename)
    num_train = len(train_data)  # sizes recorded before any subsampling
    num_test = len(test_data)

    s3 = boto3.client("s3")
    gibbs_files = None
    aug_data = None
    run_gibbs_distill = False
    if (DISTILL_GIBBS_R1 in methods or DISTILL_GIBBS_R5 in methods or DISTILL_GIBBS_R10 in methods or
        DISTILL_GIBBS_SOME in methods or DISTILL_GIBBS_ALL in methods or DISTILL_GIBBS_200 in methods):
        run_gibbs_distill = True
        if (DISTILL_GIBBS_R1 in methods) + (DISTILL_GIBBS_R5 in methods) + (DISTILL_GIBBS_R10 in methods) + (DISTILL_GIBBS_SOME in methods) + (DISTILL_GIBBS_ALL in methods) > 1:
            raise ValueError(f"cannot have more than one of: {DISTILL_GIBBS_R1},{DISTILL_GIBBS_R5},{DISTILL_GIBBS_R10},{DISTILL_GIBBS_SOME},{DISTILL_GIBBS_ALL} in methods")
        # NOTE(review): a single list_objects_v2 call is not paginated; if the prefix holds many objects
        # some keys may be missing from this listing - confirm the folder is small enough.
        all_objects = s3.list_objects_v2(Bucket=GIBBS_BUCKET, Prefix=GIBBS_PREFIX)
        if 'Contents' not in all_objects:
            raise ValueError(f"s3://{GIBBS_BUCKET}/{GIBBS_PREFIX} not found, all_objects is empty.")
        gibbs_files = [obj['Key'] for obj in all_objects['Contents'] if (dataset_shortname in obj['Key'])]
        if len(gibbs_files) < 1:
            raise ValueError(f"No gibbs samples found for dataset: {dataset_name} in s3://{GIBBS_BUCKET}/{GIBBS_PREFIX}")
        # Narrow the file list to the sampling rounds requested (DISTILL_GIBBS_ALL keeps every file):
        if DISTILL_GIBBS_R1 in methods:
            gibbs_files = [filename for filename in gibbs_files if ('_r1_' in filename)]
            if len(gibbs_files) < 1:
                raise ValueError(f"_r1_ not present in any gibbs sampling file for dataset: {dataset_name} in s3://{GIBBS_BUCKET}/{GIBBS_PREFIX}")
            else:
                gibbs_files = [gibbs_files[0]]
        elif DISTILL_GIBBS_R5 in methods:
            gibbs_files = [filename for filename in gibbs_files if ('_r5_' in filename)]
            if len(gibbs_files) < 1:
                raise ValueError(f"_r5_ not present in any gibbs sampling file for dataset: {dataset_name} in s3://{GIBBS_BUCKET}/{GIBBS_PREFIX}")
            else:
                gibbs_files = [gibbs_files[0]]
        elif DISTILL_GIBBS_R10 in methods:
            gibbs_files = [filename for filename in gibbs_files if ('_r10_' in filename)]
            if len(gibbs_files) < 1:
                raise ValueError(f"_r10_ not present in any gibbs sampling file for dataset: {dataset_name} in s3://{GIBBS_BUCKET}/{GIBBS_PREFIX}")
            else:
                gibbs_files = [gibbs_files[0]]
        elif DISTILL_GIBBS_SOME in methods:
            gibbs_files = [filename for filename in gibbs_files if (('_r1_' in filename) or ('_r5_' in filename) or ('_r10_' in filename))]
            if len(gibbs_files) < 1:
                raise ValueError(f"_r1_,_r5_,_r10_ not present in any gibbs sampling file for dataset: {dataset_name} in s3://{GIBBS_BUCKET}/{GIBBS_PREFIX}")
        elif DISTILL_GIBBS_200 in methods:
            gibbs_files = [filename for filename in gibbs_files if (('_r20_' in filename) or ('_r40_' in filename) or ('_r100_' in filename) or ('_r200_' in filename))]
            if len(gibbs_files) < 1:
                raise ValueError(f"_r20_,_r40_,_r100_,_r200_ not present in any gibbs sampling file for dataset: {dataset_name} in s3://{GIBBS_BUCKET}/{GIBBS_PREFIX}")
        if small_model_only and (dataset_name != 'jasmine'):  # jasmine model names switched
            gibbs_files = filter_gibbs_models(gibbs_files)
            if len(gibbs_files) < 1:
                # Fail with context instead of an IndexError on gibbs_files[0] below.
                raise ValueError(f"No gibbs sampling files left after filter_gibbs_models for dataset: {dataset_name} in s3://{GIBBS_BUCKET}/{GIBBS_PREFIX}")

        print(f"Will run Gibbs augmentation for these {len(gibbs_files)} files: \n {gibbs_files}")
        # Read the first file now; its length caps num_augmented_samples below. The remaining files are read in the distill loop.
        aug_data_path = 's3://'+GIBBS_BUCKET+'/'+gibbs_files[0]
        aug_data = task.Dataset(file_path=aug_data_path)

    # Training time:
    if subsample_size is not None:
        train_data = train_data.head(subsample_size)
        test_data = test_data.head(subsample_size)

    print("train_data.head()", train_data.head())
    print("train_data.shape", train_data.shape)
    print("test_data.shape", test_data.shape)

    additional_fit_kwargs = {}
    additional_distill_kwargs = {}
    if dataset_name in HIGH_MEM_DATASETS:
        print(f"{dataset_name} is high-disk-space dataset, will reduce disk size of autogluon models...")
        additional_fit_kwargs = {'hyperparameters': HIGH_MEM_HYPERPARAMS}
        additional_distill_kwargs = additional_fit_kwargs
    # NOTE(review): `profile` is the PROFILES[...] value at this point, not the profile name -
    # confirm SMALL_STUDENT_PROFILES holds the same kind of object, otherwise this never matches.
    if profile in SMALL_STUDENT_PROFILES:
        print("small student profile specified, students will use more-efficient hyperparameters:")
        print(SMALL_STUDENT_HYPERPARAMS)
        additional_distill_kwargs = {'hyperparameters': SMALL_STUDENT_HYPERPARAMS}

    if FIT in methods:
        shutil.rmtree(savedir, ignore_errors=True) # Delete AutoGluon output directory to ensure previous runs' information has been removed.
        predictor = task.fit(train_data=train_data, label=label_column, output_directory=savedir, problem_type=problem_type, eval_metric=eval_metric,
                             enable_fit_continuation=True, auto_stack=True, time_limits=time_limits, **additional_fit_kwargs)
        sync_to_s3()
    else:
        print("Not running FIT, just loading previously-trained predictor.")
        if DISTILL_HARD_MUNGE not in methods:
            del train_data  # training data is not referenced again in this case
        # Get agModels file from s3 to savedir:  # TODO: consider that TAG may be different now as it's named after fleet
        if tag_loader_method is not None:
            # Loader entry splits on TAG_SPLIT into: prefix, profile name, tag name.
            profile_to_load = tag_loader_method.split(TAG_SPLIT)[1]
            tag_to_load = tag_loader_method.split(TAG_SPLIT)[2]
            folder_to_load = s3prefix_notagprofile + profile_to_load + '/' + tag_to_load + '/'
        else:
            folder_to_load = s3prefix_save_path  # Contains agModels/ folder
        print(f"Loading predictor from: {folder_to_load}")
        s3_sync_folder(local_folder=save_path, s3_bucket=BUCKET, s3_prefix=folder_to_load, to_s3=False)  # fetch files from TAG s3 folder.
        logfile = save_path + 'python_output.log'
        if os.path.exists(logfile):
            os.remove(logfile)  # drop the log file fetched along with the loaded run
        predictor = task.load(savedir, verbosity=verbosity)

    predictor.leaderboard()
    ldrs_predistill = predictor.leaderboard(test_data) # verify ensemble is the best
    ldrs_predistill.to_csv(predistill_ldr_csv, index=False)  # save leaderboard to file
    sync_to_s3()
    learner = predictor._learner
    trainer = learner.load_trainer()
    if trainer.model_best != ldrs_predistill.loc[0, 'model']:
        print(f"WARNING: {ldrs_predistill.loc[0, 'model']} better than weighted ensemble on test data")
        ensemble_is_best = False
    else:
        ensemble_is_best = True
    metadata = pd.DataFrame({'ensemble_is_best':[ensemble_is_best], 'num_train':[num_train], 'num_test':[num_test]})
    metadata.to_csv(metadata_csv, index=False)
    sync_to_s3()

    # Determine num_augmented: scaled training-set size, clamped to [min_aug_size, max_aug_size]
    holdout_frac = default_holdout_frac(trainer.num_rows_train)
    min_aug_size = 1000
    max_aug_size = 100000
    num_augmented_samples = min(max_aug_size, max(min_aug_size, int((1-holdout_frac)*trainer.num_rows_train*distill_size_factor)))
    if aug_data is not None:
        num_augmented_samples = min(num_augmented_samples, len(aug_data))

    augment_args = {'num_augmented_samples': num_augmented_samples, 'epochs': gan_epochs}

    # Gibbs augmentation methods: one distillation round per Gibbs file, saving and syncing after each.
    if run_gibbs_distill:
        for gibbs_file in gibbs_files:
            gibbs_name = "_".join(gibbs_file.split("_")[-3:-1])  # used as name-suffixes for Gibbs student models
            aug_data_path = 's3://'+GIBBS_BUCKET+'/'+gibbs_file
            aug_data = task.Dataset(file_path=aug_data_path)
            aug_data = aug_data.head(num_augmented_samples)
            if gibbs_file == gibbs_files[0]:
                print("aug_data.shape", aug_data.shape)
                print("aug_data.head()", aug_data.head())

            learner.distill(augmentation_data=aug_data, models_name_suffix="GIB_"+gibbs_name, time_limits=distill_time,
                            teacher_preds='soft', augment_args=augment_args, verbosity=verbosity, **additional_distill_kwargs)
            learner.save()
            sync_to_s3()

    if DISTILL_BASELINE in methods:
        learner.distill(time_limits=distill_time, teacher_preds=None, augment_args=augment_args, verbosity=verbosity, **additional_distill_kwargs)  # Baseline
        learner.save()
        sync_to_s3()
    if DISTILL_SOFT_NONE in methods:
        # NOTE(review): uses time_limits rather than distill_time, unlike most other methods - confirm intended.
        learner.distill(time_limits=time_limits, teacher_preds='soft', augment_method=None, augment_args=augment_args, verbosity=verbosity, **additional_distill_kwargs)  # SOFT teacher predictions on original data, no augmentation
        learner.save()
        sync_to_s3()
    if DISTILL_HARD_MUNGE in methods:
        # Reloads the full (un-subsampled) training file from s3.
        # NOTE(review): uses time_limits (not distill_time) and passes no augment_args/verbosity - confirm intended.
        train_data = task.Dataset(file_path=s3_dataset_folder + train_filename)
        learner.distill(train_data, time_limits=time_limits, teacher_preds='hard', augment_method='munge', **additional_distill_kwargs)  # Hard teacher predictions with MUNGE augmentation
        learner.save()
        sync_to_s3()
    if DISTILL_SOFT_MUNGE in methods:
        learner.distill(time_limits=distill_time, teacher_preds='soft', augment_method='munge', augment_args=augment_args, verbosity=verbosity, **additional_distill_kwargs)
        learner.save()
        sync_to_s3()
    if DISTILL_SOFT_SPUNGE in methods:
        learner.distill(time_limits=distill_time, teacher_preds='soft', augment_method='spunge', augment_args=augment_args, verbosity=verbosity, **additional_distill_kwargs)
        learner.save()
        sync_to_s3()
    if DISTILL_SOFT_GAN in methods:
        learner.distill(time_limits=distill_time, teacher_preds='soft', augment_method='gan', augment_args=augment_args, verbosity=verbosity, **additional_distill_kwargs)
        learner.save()
        sync_to_s3()

    leaders_postdistill = learner.leaderboard(test_data)
    leaders_postdistill.to_csv(postdistill_ldr_csv, index=False)
    sync_to_s3()

    print(f"Completed run {tag} on dataset {dataset_name} with: \n")
    print(profile)
    print(methods)


def run(dataset_name, profile, methods, tag=None):
    """ Command-line entry point: forwards all arguments unchanged to evaluate().

    Usage from the shell:

    DATASET=blood-transfusion
    PROFILE=PROFILE_1H_CONSTRAINED
    METHODS=ALL_METHODS_GIBBS_SOME  # Or a list: METHODS=[DISTILL_GIBBS_SOME,DISTILL_BASELINE]
    TAG=RUN1  # optional

    export PYTHONPATH='/home/ubuntu/autogluon-utils/autogluon_utils/benchmarking/distill_benchmark'
    python -u /home/ubuntu/autogluon-utils/autogluon_utils/benchmarking/distill_benchmark/run_dataset.py $DATASET $PROFILE $METHODS $TAG

    Notes:
      * METHODS must be a list unless it is one of the ALL_* methods.
      * If FIT is not included in METHODS, a previously-trained predictor is loaded instead. Either PROFILE & TAG
        must correspond to an existing PROFILE & TAG to load the predictor from (ie. overwriting the old result),
        or METHODS must include a loader entry made of LOAD_FROM_TAG, the profile name and the tag name
        joined by the TAG_SPLIT separator, eg: LOAD_FROM_TAG__PROFILE_30M_CONSTRAINED__flt1_Ami6_dstl
        (NOTE(review): confirm the exact separator against TAG_SPLIT in configs).
    """
    evaluate(dataset_name=dataset_name, profile=profile, methods=methods, tag=tag)

if __name__ == "__main__":
    # Script entry point: expose run() as a command-line interface via Fire.
    fire.Fire(run)






