import argparse
import datasets
import numpy as np
import torch
import ray.tune as tune
from sklearn.metrics import balanced_accuracy_score
from datasets import load_from_disk, concatenate_datasets
from typing import List, Dict, Any
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    AutoConfig,
    Trainer,
    TrainingArguments,
    DefaultDataCollator,
    DataCollatorWithPadding,
    set_seed
)

# Command-line options. Descriptions are passed as help= so that `--help`
# actually shows them (they used to be trailing comments only).
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='imdb_goodbad_allpara',
                    help='path of the saved dataset dict that is loaded from disk')
parser.add_argument('--model', type=str, default='roberta-base',
                    help="name of the embedding model; 'bow' selects the bag-of-words code path")
parser.add_argument('--num_steps', type=int, default=500,
                    help='maximum number of training steps')
parser.add_argument('--tune', action='store_true',
                    help='whether to run hparam sweep')
parser.add_argument('--weak_val', action='store_true',
                    help='early stopping based on weak validation instead of strong')
parser.add_argument('--gold_train', action='store_true',
                    help='use the gold labels for training (establish blue-sky perf)')
parser.add_argument('--eval_only', action='store_true',
                    help='skip the training call and only run the evaluations')
parser.add_argument('--train_fraction', type=float, default=1.0,
                    help='fraction of the training set to train on')
parser.add_argument('--train_segment', type=int, default=0,
                    help='if train_fraction < 1, which segment to use for training; '
                         'max segment * train_fraction=1.0')
parser.add_argument('--debug', action='store_true',
                    help='subsamples test datasets for faster eval')
parser.add_argument('--robust', action='store_true',
                    help='whether to enforce that f(x) = f(n(x))')
parser.add_argument('--robust_coef', type=float, default=1.0,
                    help='coef to use on TV loss term')
parser.add_argument('--wd', type=float, default=0.1,
                    help='coef to use for weight decay')
parser.add_argument('--linear', action='store_true',
                    help='whether to just tune CLS token repr')
parser.add_argument('--projection_model', type=str, default=None,
                    help='what pretrained projection to use')
parser.add_argument('--seed', type=int, default=0,
                    help='random seed passed to set_seed')
args = parser.parse_args()

# Seed before any random permutation or weight initialisation below.
set_seed(args.seed)
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') # for bow

class LinearTrainer(Trainer):
    """Trainer whose loss is cross-entropy of ``model(inputs['text_embeds'])``.

    The wrapped model is called directly on the ``text_embeds`` entry of each
    batch, and its output is treated as the class logits.

    Extra keyword arguments (both optional, stored on the instance):
        do_onehot: flag kept as ``self.do_onehot`` (default ``False``).
        embedding_dim: value kept as ``self.embedding_dim`` (default ``None``).
    All remaining keyword arguments are forwarded to ``Trainer.__init__``.
    """

    def __init__(self, **kwargs):
        # Remove the subclass-only options so they are not forwarded to Trainer.
        do_onehot = kwargs.pop('do_onehot', False)
        embedding_dim = kwargs.pop('embedding_dim', None)

        super().__init__(**kwargs)
        self.loss = torch.nn.CrossEntropyLoss()
        self.do_onehot = do_onehot
        self.embedding_dim = embedding_dim

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the cross-entropy loss for one batch.

        Args:
            model: callable applied to ``inputs['text_embeds']``; returns logits.
            inputs: batch mapping with ``labels`` and ``text_embeds``. The
                ``labels`` entry (and ``weak_label``, if present) is removed
                from the mapping in place.
            return_outputs: when true, also return a dict holding the logits
                and the loss.
            num_items_in_batch: accepted and ignored. Newer ``Trainer``
                versions pass this keyword to ``compute_loss``; without the
                parameter the call would raise ``TypeError``.

        Returns:
            ``loss`` or ``(loss, {'logits': ..., 'loss': ...})``.
        """
        # eval time, using true labels: the weak label is not needed here
        inputs.pop("weak_label", None)

        labels = inputs.pop("labels")
        logits = model(inputs['text_embeds'])
        loss = self.loss(logits, labels)

        if return_outputs:
            return (loss, {'logits': logits, 'loss': loss})
        return loss


def compute_metrics(eval_pred):
    """Return plain and balanced accuracy for a ``(scores, labels)`` pair.

    The predicted class of each row is the argmax of the scores over the
    last axis.
    """
    scores, gold = eval_pred
    predicted = scores.argmax(axis=-1)
    balanced = balanced_accuracy_score(gold, predicted)
    hit_rate = (predicted == gold).astype(float).mean()
    return {"accuracy": hit_rate, "balacc": balanced}


if __name__ == '__main__':
    #### Load dataset and group its splits by S/T prefix and train/test suffix
    dsdict = load_from_disk(args.dataset)
    print(dsdict)

    def _matching_splits(prefix, suffix):
        """Split names in dsdict that start with `prefix` and end with `suffix`."""
        return [name for name in dsdict if name.startswith(prefix) and name.endswith(suffix)]

    S_keys_train = _matching_splits('S', 'train')
    T_keys_train = _matching_splits('T', 'train')

    S_keys_test = _matching_splits('S', 'test')
    T_keys_test = _matching_splits('T', 'test')

    Strain = concatenate_datasets([dsdict[name] for name in S_keys_train])

    if not args.gold_train:
        # train on the weak labels: drop the gold column and put weak_label in its place
        Strain = Strain.remove_columns(["label"]).rename_column("weak_label", "label")

    Ttrain = concatenate_datasets([dsdict[name] for name in T_keys_train])
    Stest = concatenate_datasets([dsdict[name] for name in S_keys_test])
    Ttest = concatenate_datasets([dsdict[name] for name in T_keys_test])

    if not args.weak_val:
        validation = dsdict['weaktraining_validation']
    else:
        # carve the validation set out of the training pool instead
        shuffled = np.random.permutation(len(Strain))
        n_val = len(dsdict['weaktraining_validation']) # go for same size as gold set
        validation = Strain.select(shuffled[:n_val])
        Strain = Strain.select(shuffled[n_val:])
        print("validation data:")
        print(validation)

    if args.train_fraction < 1:
        # keep one contiguous segment of a random permutation of the training set
        shuffled = np.random.permutation(len(Strain))
        print(shuffled[:50])
        seg_len = int(args.train_fraction*len(Strain))
        seg_start = args.train_segment*seg_len

        Strain = Strain.select(shuffled[seg_start:seg_start + seg_len])

    # Classifier input width: tokenizer vocabulary size for bag-of-words,
    # otherwise the leading dimension of the first stored training embedding.
    if args.model != 'bow':
        embedding_dim = torch.tensor(Strain[0]['text_embeds']).shape[0]
    else:
        embedding_dim = tokenizer.vocab_size

    # Width before any projection; used later to build placeholder embeddings.
    orig_embedding_dim = embedding_dim
    if args.projection_model:
        from safetensors.torch import load_model
        from train_contrastive_projection import ProjectionModel
        projection_wrapper = ProjectionModel(embedding_dim=embedding_dim)
        load_model(projection_wrapper, args.projection_model)
        # only the first linear layer of the loaded model is used as the projection
        projection = projection_wrapper.linear1
        embedding_dim = projection.out_features

    # two-class linear classifier on top of the (possibly projected) embeddings
    model = torch.nn.Linear(in_features=embedding_dim, out_features=2, bias=True)

    def encode(examples):
        """Convert the stored embedding columns to tensors, projecting if configured.

        For each of 'text', 'Tpara1' and 'Spara1' that is present in
        `examples`, the matching '<name>_embeds' entry is replaced by a
        tensor; when --projection_model is set the tensor is passed through
        `projection` first. Returns the same `examples` object.
        """
        for k in ('text', 'Tpara1', 'Spara1'):
            if k not in examples:
                continue
            embeds = torch.tensor(examples[f'{k}_embeds'])
            if args.projection_model:
                # The projection is only applied here, never trained by this
                # script, so skip building an autograd graph for every batch.
                with torch.no_grad():
                    embeds = projection(embeds)
            examples[f'{k}_embeds'] = embeds
        return examples


    # bag-of-words examples are mapped one at a time, everything else in batches
    batched = args.model != 'bow'
    encoded_train = Strain.map(encode, batched=batched)
    encoded_val = validation.map(encode, batched=batched)

    # Run name: dataset, model and seed, plus one suffix per enabled option.
    modelname = args.model.split('/')[-1]
    optional_suffixes = [
        (not args.gold_train, '_weaktrain'),
        (args.weak_val, '_weakval'),
        (args.linear, '_linear'),
        (args.robust, f'_robust_coef{args.robust_coef}'),
        (args.train_fraction < 1.0, f'_frac{args.train_fraction}_seg{args.train_segment}'),
        (args.projection_model, '_projection'),
    ]
    run_name = f"{args.dataset}_{modelname}_seed{args.seed}"
    run_name += ''.join(suffix for enabled, suffix in optional_suffixes if enabled)

    # debug runs all share one fixed name
    if args.debug:
        run_name = "debug"

    lr = 1e-2

    # evaluate and checkpoint every 10 steps; the best checkpoint by
    # eval_accuracy is reloaded when training finishes
    training_args = TrainingArguments(
        output_dir=f"fixedbug/{run_name}",
        evaluation_strategy="steps",
        eval_steps=10,
        save_steps=10,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model='eval_accuracy',
        max_steps=args.num_steps,
        learning_rate=lr,
        weight_decay=args.wd,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=128,
        disable_tqdm=True,
        remove_unused_columns=False,
        label_names=['labels'],
    )

    if args.model == 'bow':
        collator = DataCollatorWithPadding(tokenizer=tokenizer, max_length=512)
    else:
        collator = DefaultDataCollator()

    trainer = LinearTrainer(
        model=model,
        args=training_args,
        train_dataset=encoded_train,
        eval_dataset=encoded_val,
        data_collator=collator,
        compute_metrics=compute_metrics,
        do_onehot=(args.model == 'bow'),
        embedding_dim=embedding_dim,
    )

    if not args.eval_only:
        trainer.train()

    #final_eval = encoded_val if not args.weak_val else dsdict['weaktraining_validation'].map(encode, batched=True)
    #print("val metrics, true label:")
    #trainer.evaluate(eval_dataset=final_eval)

    # prevent encoding issues from None's later: give Ttest placeholder values
    # for the columns it gets here, so it can be concatenated with Stest
    n_target = len(Ttest)
    filler_embeds = [[-1.0]*orig_embedding_dim]*n_target
    Ttest = Ttest.add_column('weak_label', [-1]*n_target)
    Ttest = Ttest.add_column('Tpara1', ['']*n_target)
    Ttest = Ttest.add_column('Tpara1_embeds', filler_embeds)

    if 'Spara1' in Stest.column_names:
        Ttest = Ttest.add_column('Spara1', ['']*n_target)
        Ttest = Ttest.add_column('Spara1_embeds', filler_embeds)

    def _debug_subsample(ds):
        """Return a random subset of `ds` holding int(0.05 * len(ds)) rows."""
        shuffled = np.random.permutation(len(ds))
        return ds.select(shuffled[:int(0.05*len(ds))])

    encoded_test = concatenate_datasets([Stest, Ttest]).map(encode, batched=True)
    if args.debug:
        encoded_test = _debug_subsample(encoded_test)

    print("combined test metrics, true label:")
    # largely for debugging. this should equal weighted sum of the loop outputs below
    trainer.evaluate(eval_dataset=encoded_test)

    # per-split evaluation; results are collected by split name
    test_keys = S_keys_test + T_keys_test
    testsize = sum(len(dsdict[name]) for name in test_keys)
    all_metrics = {}
    for key in test_keys:
        split = dsdict[key]
        if args.debug:
            split = _debug_subsample(split)
        encoded = split.map(encode, batched=True)
        print(key, len(split)/testsize)
        all_metrics[key] = trainer.evaluate(eval_dataset=encoded) # this also prints
        #print(metrics)

    # now measure S-T expansion
    print("S--T expansion stats")
    print("=" * 50)

    def encode(examples, key='text'):
        """Store the '<key>_embeds' column as the 'text_embeds' model input.

        This lets the trainer score either the original text or one of its
        paraphrases through the same 'text_embeds' entry. When
        --projection_model is set, the embeddings are passed through
        `projection`, because the classifier was built with the projected
        width; feeding raw embeddings would not match its input size.
        """
        embeds = torch.tensor(examples[f'{key}_embeds'])
        if args.projection_model:
            # inference only, so no autograd graph is needed
            with torch.no_grad():
                embeds = projection(embeds)
        examples['text_embeds'] = embeds
        return examples

    def _weak_label_correct(ex):
        """True when the weak label of this example equals its gold label."""
        return ex['weak_label'] == ex['label']

    def _weak_label_wrong(ex):
        """True when the weak label of this example differs from its gold label."""
        return ex['weak_label'] != ex['label']

    for i, (Skey, Tkey) in enumerate(zip(S_keys_test, T_keys_test)):
        Si = dsdict[Skey]

        # the same rows twice: model input taken from the T-paraphrase / from the original text
        encoded_para = Si.map(encode, batched=True, fn_kwargs={'key': 'Tpara1'})
        encoded_orig = Si.map(encode, batched=True, fn_kwargs={'key': 'text'})

        # ---- good expansion: rows whose weak label is correct ----
        # (both filters select the same indices)
        Sigood_para = encoded_para.filter(_weak_label_correct)
        Sigood_orig = encoded_orig.filter(_weak_label_correct)

        good_para_out = trainer.predict(test_dataset=Sigood_para)
        good_orig_out = trainer.predict(test_dataset=Sigood_orig)

        good_labels = good_para_out.label_ids
        good_para_cls = good_para_out.predictions.argmax(axis=1)
        good_orig_cls = good_orig_out.predictions.argmax(axis=1)

        # rows where paraphrase and original receive the same prediction
        good_agree = np.where(good_para_cls == good_orig_cls)[0]
        p_robust_good = len(good_agree) / len(Sigood_para)
        err_on_robust = (good_para_cls[good_agree] != good_labels[good_agree]).astype(float).mean()

        # P(n(x) \in U, f(x) = f(n(x)) | Sigood), relative to the error on the T split
        acc_uncovered = all_metrics[Tkey]['eval_accuracy']
        err_uncovered = 1.0 - acc_uncovered
        good_expansion = (p_robust_good*err_on_robust) / err_uncovered

        # ---- bad expansion: rows whose weak label is wrong ----
        # (both filters select the same indices)
        Sibad_para = encoded_para.filter(_weak_label_wrong)
        Sibad_orig = encoded_orig.filter(_weak_label_wrong)

        bad_para_out = trainer.predict(test_dataset=Sibad_para)
        bad_orig_out = trainer.predict(test_dataset=Sibad_orig)

        bad_labels = bad_para_out.label_ids
        bad_para_cls = bad_para_out.predictions.argmax(axis=1)
        bad_orig_cls = bad_orig_out.predictions.argmax(axis=1)

        bad_agree = np.where(bad_para_cls == bad_orig_cls)[0]
        p_robust_bad = len(bad_agree) / len(Sibad_para)
        acc_on_robust = (bad_para_cls[bad_agree] == bad_labels[bad_agree]).astype(float).mean()

        # P(n(x) \in T\U, f(x) = f(n(x)) | Sibad), relative to the accuracy on the T split
        bad_expansion = (p_robust_bad*acc_on_robust) / acc_uncovered

        # disagreement with the weak labels on the whole split, and share of wrong weak labels
        full_out = trainer.predict(test_dataset=Si.map(encode, batched=True))
        full_cls = full_out.predictions.argmax(axis=1)
        weak_labels = np.array(Si['weak_label'])
        weak_err = (weak_labels != full_cls).mean()
        alpha = len(Sibad_orig) / len(Si)

        bound_numerator = (weak_err - bad_expansion * alpha)
        bound_denominator = (good_expansion - (good_expansion + bad_expansion)*alpha)

        # a negative numerator or denominator is reported as an infinite bound
        if min(bound_numerator, bound_denominator) < 0:
            bound_value = float('inf')
        else:
            bound_value = bound_numerator / bound_denominator

        # earlier formulation, kept for reference:
        #Sibad = encoded.filter(lambda ex: ex['weak_label'] != ex['label'])
        #metrics_good = trainer.evaluate(eval_dataset=Sigood) # this also prints
        #metrics_bad = trainer.evaluate(eval_dataset=Sibad)
        #err_good_paraphrase = 1.0 - metrics_good['eval_accuracy']
        #acc_bad_paraphrase = metrics_bad['eval_accuracy']
        #acc_uncovered = all_metrics[Tkey]['eval_accuracy']
        #err_uncovered = 1.0 - acc_uncovered
        #good_expansion = err_good_paraphrase / err_uncovered
        #bad_expansion = acc_bad_paraphrase / acc_uncovered

        all_metrics.update({
            f"{Skey}_{Tkey}_expansion": [good_expansion, bad_expansion],
            f"{Skey}_{Tkey}_bdval": bound_value,
            f"{Skey}_weakerr": weak_err,
            f"{Skey}_alpha": alpha,
        })

        print(i, good_expansion, bad_expansion)
        print(f"bound_value={bound_value}")
        print('-' * 10)
        #print(metrics)


    # now measure S-S expansion
    print("S--S expansion stats")
    print("=" * 50)

    def _predict_classes(ds):
        """Run trainer.predict on `ds`; return (argmax class per row, label ids)."""
        out = trainer.predict(test_dataset=ds)
        return out.predictions.argmax(axis=1), out.label_ids

    for i, Skey in enumerate(S_keys_test):
        Si = dsdict[Skey]

        # the same rows twice: model input taken from the S-paraphrase / from the original text
        encoded_para = Si.map(encode, batched=True, fn_kwargs={'key': 'Spara1'})
        encoded_orig = Si.map(encode, batched=True, fn_kwargs={'key': 'text'})

        # rows with a correct weak label (the two filters select the same indices)
        Sigood_para = encoded_para.filter(lambda row: row['weak_label'] == row['label'])
        Sigood_orig = encoded_orig.filter(lambda row: row['weak_label'] == row['label'])

        # rows with a wrong weak label (again the same indices in both)
        Sibad_para = encoded_para.filter(lambda row: row['weak_label'] != row['label'])
        Sibad_orig = encoded_orig.filter(lambda row: row['weak_label'] != row['label'])

        # predictions for each of the four datasets; labels come from the paraphrase runs
        good_para_preds, good_labels = _predict_classes(Sigood_para)
        good_orig_preds, _ = _predict_classes(Sigood_orig)

        bad_para_preds, bad_labels = _predict_classes(Sibad_para)
        bad_orig_preds, _ = _predict_classes(Sibad_orig)

        # rows where paraphrase and original receive the same prediction
        good_robust_inds = np.where(good_para_preds == good_orig_preds)[0]
        bad_robust_inds = np.where(bad_para_preds == bad_orig_preds)[0]

        # Sigood-->Sibad expansion:
        # err(Sigood_para) / err(Sibad_orig)
        p_robust_good = len(good_robust_inds) / len(Sigood_para)
        err_on_robust = (good_para_preds[good_robust_inds] != good_labels[good_robust_inds]).astype(float).mean()
        err_bad_orig = (bad_orig_preds != bad_labels).astype(float).mean()
        good_expansion = (p_robust_good * err_on_robust) / err_bad_orig

        # Sibad-->Sigood expansion:
        # acc(Sibad_para) / acc(Sigood_orig)
        p_robust_bad = len(bad_robust_inds) / len(Sibad_para)
        acc_on_robust = (bad_para_preds[bad_robust_inds] == bad_labels[bad_robust_inds]).astype(float).mean()
        acc_good_orig = (good_orig_preds == good_labels).astype(float).mean()
        bad_expansion = (p_robust_bad * acc_on_robust) / acc_good_orig

        # disagreement with the weak labels on the whole split, and share of wrong weak labels
        full_preds, _ = _predict_classes(Si.map(encode, batched=True))
        weak_labels = np.array(Si['weak_label'])
        weak_err = (weak_labels != full_preds).mean()
        alpha = len(Sibad_orig) / len(Si)

        bound_top = weak_err - bad_expansion * alpha
        bound_bottom = good_expansion - (good_expansion + bad_expansion)*alpha
        bound_value = bound_top / bound_bottom

        badgood_c_lower_bound = (1-alpha)*weak_err / (2*alpha*(1-alpha) - alpha*weak_err)

        scaled_bad = bad_expansion*alpha
        ratio_top = 1-alpha+scaled_bad
        ratio_bottom = 1-alpha-scaled_bad
        new_bound_value = ratio_top/ratio_bottom*(weak_err + alpha) - 2*bad_expansion*alpha/ratio_bottom

        all_metrics[f"{Skey}good_{Skey}bad_expansion"] = good_expansion
        all_metrics[f"{Skey}bad_{Skey}good_expansion"] = bad_expansion

        print(i, good_expansion, bad_expansion)
        print(f"bound_value={bound_value}")
        print(f"new_bound_value={new_bound_value}")
        print(f"Lower bound on badgood expansion: {badgood_c_lower_bound}")
        print(f"actual badgood expansion: {bad_expansion}")

        print('-' * 10)

    import json
    import os

    # Write every collected metric to one JSON file per run.
    stats_path = f"horribleincredible_mistral_frac1/{run_name}_stats.json"
    # Create the parent directory first: open() would otherwise fail only
    # after the whole run, either because the results directory is missing
    # or because run_name (which embeds --dataset) contains a '/'.
    os.makedirs(os.path.dirname(stats_path), exist_ok=True)
    with open(stats_path, 'w') as fh:
        json.dump(all_metrics, fh)
