import os, socket

import lightning

import argparse
import copy
import numpy as np
import yaml
import torch
import torch.nn.functional as F
import pandas as pd
import lightning as L
from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType
import loralib
import math
from torch.autograd import Variable

from workbench.utils.cpr_wrapper import apply_CPR
from workbench.data.pubmedqa import get_pubmedqa
from workbench.data.pubmedqa_owt import get_pubmedqa_owt
from workbench.data.medmcqa import get_medmcqa
from workbench.data.gsm8k import get_gsm8k
from workbench.data.trivia_qa import get_trivia_qa
from workbench.data.math_qa import get_math_qa
from workbench.data.piqa import get_piqa

from workbench.evaluate import evaluate_model
from workbench.configuration import Config, read_unknown_args
from workbench.folder_manager import get_experiment_folder
from workbench.utils.log_gradient import LogParamsAndGrads


# Let float32 matmuls use faster, lower-precision internal computation ('medium'),
# trading some numerical accuracy for throughput. This is a global torch setting
# and takes effect as soon as this module is executed.
torch.set_float32_matmul_precision('medium')

def trainable_params(model):
    """Return the number of parameter elements in *model* that require gradients."""
    count = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        count += param.numel()
    return count

def total_params(model):
    """Return the number of parameter elements in *model*, trainable or frozen."""
    count = 0
    for tensor in model.parameters():
        count += tensor.numel()
    return count

def MultiClassCrossEntropy(logits, labels, T=2):
    """Temperature-scaled soft cross-entropy between student and teacher logits.

    Computes ``-mean( sum_C softmax(labels / T) * log_softmax(logits / T) )``.
    The sum runs over the last (class) dimension and the mean over all
    remaining dimensions.

    Args:
        logits: student logits, class dimension last.
        labels: teacher logits, broadcastable to ``logits`` (normally the same
            shape). They are treated as constants: no gradient flows into them.
        T: softmax temperature applied to both inputs.

    Returns:
        A scalar tensor on ``logits.device`` that stays attached to the autograd
        graph of ``logits``, so ``loss.backward()`` reaches the model.
    """
    # Teacher targets are constants; move them next to the student logits
    # instead of hard-coding CUDA.
    teacher = labels.detach().to(logits.device)
    log_probs = torch.log_softmax(logits / T, dim=-1)
    soft_targets = torch.softmax(teacher / T, dim=-1)
    # Return the tensor itself: rebuilding it from ``.data`` (as was done via
    # ``Variable(outputs.data, requires_grad=True)``) detaches it from the graph
    # and makes backward() a no-op for the model.
    return -torch.mean(torch.sum(log_probs * soft_targets, dim=-1))




def loralib_set_r(model, r) -> None:
    """Assign ``r`` to the ``r`` attribute of every ``loralib.LoRALayer`` in *model*.

    NOTE(review): only the attribute is reassigned. Anything loralib derived
    from ``r`` when the layer was built is not recomputed here; confirm this
    is intended.
    """
    lora_layers = (sub for sub in model.modules() if isinstance(sub, loralib.LoRALayer))
    for layer in lora_layers:
        layer.r = r


def print_results_table(results):
    """Print a markdown table with one ``acc`` column for the given task results.

    A task whose metrics contain ``acc`` contributes one row named after the
    task. Any other task must provide ``mc1`` and ``mc2``. It contributes the
    rows ``<task>_mc1`` and ``<task>_mc2``.
    """
    accuracies = {}
    for task_name, metrics in results.items():
        if 'acc' not in metrics:
            accuracies[task_name + "_mc1"] = metrics['mc1']
            accuracies[task_name + "_mc2"] = metrics['mc2']
            continue
        accuracies[task_name] = metrics['acc']
    table = pd.DataFrame({'acc': accuracies})
    print(table.to_markdown())



def split_into_sublists(lst, max_sublists=4):
    """Deal the items of *lst* round-robin into at most *max_sublists* lists.

    Item ``i`` lands in list ``i % max_sublists``. Fewer lists are returned when
    *lst* has fewer than *max_sublists* items, and an empty *lst* gives ``[]``.
    """
    buckets = []
    for position, element in enumerate(lst):
        if position < max_sublists:
            # During the first pass, each item opens its own bucket.
            buckets.append([element])
        else:
            buckets[position % max_sublists].append(element)
    return buckets

def save_results(expt_dir, name, eval_results):
    """Flatten per-task metrics into ``{row_name: accuracy}`` and ``torch.save`` it.

    Row names follow the same convention as ``print_results_table``:
    ``<task>`` when the task's metrics contain ``acc``, otherwise
    ``<task>_mc1`` and ``<task>_mc2``. These names become row labels in the
    aggregated CSV and in the ``eval/<row>/acc`` log keys.

    Args:
        expt_dir: directory to write into.
        name: file name inside *expt_dir*.
        eval_results: mapping task -> metrics dict.
    """
    results_dict = {}
    for task, metrics in eval_results.items():
        if 'acc' in metrics:
            results_dict[task] = metrics['acc']
        else:
            # Use an underscore separator like print_results_table; plain
            # concatenation produced names such as "truthfulqa_mcmc1".
            results_dict[f"{task}_mc1"] = metrics['mc1']
            results_dict[f"{task}_mc2"] = metrics['mc2']
    torch.save(results_dict, os.path.join(expt_dir, name))

def start_training(config, expt_dir, model, tokenizer):
    """Fine-tune *model* on ``config.dataset`` with Lightning Fabric (DDP).

    Steps:
    1. Seed numpy/torch/lightning.
    2. Build the data loaders and the optimizer (``config.reg_type``).
    3. Evaluate once before training.
    4. Train with gradient accumulation, linear warmup and cosine LR decay.
    5. Save ``state_dict.pt`` into *expt_dir* and evaluate again.

    Returns:
        0 on completion.

    Raises:
        ValueError: unknown ``config.dataset``.
        NotImplementedError: unknown ``config.reg_type``.
    """
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)

    lightning.seed_everything(config.seed)

    train_loader, val_loader = _build_dataloaders(config, tokenizer)

    tb_logger = TensorBoardLogger(expt_dir, name="")
    log_params = LogParamsAndGrads(model, log_gradient=False, log_params=True, log_quantiles=False, log_every_n_steps=20)

    fabric = L.Fabric(devices=config.devices, precision=config.precision, strategy="ddp",
                      loggers=tb_logger, callbacks=[log_params])
    fabric.launch()

    if config.reg_type == "cpr":
        optimizer = apply_CPR(model, torch.optim.AdamW, kappa_init_param=config.cpr_param, kappa_init_method=config.cpr_init,
                              kappa_update=config.cpr_mu, lr=config.learning_rate, weight_decay=0)
    elif config.reg_type == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    else:
        raise NotImplementedError

    train_loader, val_loader = fabric.setup_dataloaders(train_loader, val_loader)
    model = fabric.setup_module(model)

    print("START EVALUATION")
    model = model.eval()
    eval_model(fabric, config, expt_dir, model, tokenizer)
    print("END EVALUATION")

    optimizer = fabric.setup_optimizers(optimizer)

    total_training_steps = len(train_loader) * config.max_epochs
    # Clamp to >= 1: when batch_size < effective_batch_size * devices the integer
    # division yields 0, which would crash on the modulo/division below.
    gradient_accumulation_iters = max(1, config.batch_size // config.effective_batch_size // config.devices)
    total_optim_steps = total_training_steps // gradient_accumulation_iters
    print("total_training_steps", total_training_steps)
    print("gradient_accumulation_iters", gradient_accumulation_iters)

    optim_step_count = 0
    step_count = 0

    fabric.log("train/optim_step_count", optim_step_count, step_count)
    fabric.log("train/epoch", 0, step_count)
    fabric.log("train/lr", 0, step_count)

    print("start training")
    for epoch in range(config.max_epochs):

        model.train()
        for batch_idx, batch in tqdm(enumerate(train_loader), total=len(train_loader)):

            inputs, target = batch
            step_count += 1
            # The optimizer steps only on every gradient_accumulation_iters-th micro-batch.
            is_accumulating = step_count % gradient_accumulation_iters != 0

            if not is_accumulating:
                lr = _learning_rate(config, optim_step_count, total_optim_steps)
                for param_group in optimizer.param_groups:
                    param_group["lr"] = lr

            # Skip the DDP gradient all-reduce on accumulation-only micro-batches.
            with fabric.no_backward_sync(model, enabled=is_accumulating):
                output = model(inputs)['logits']
                loss = F.cross_entropy(output.view(-1, output.shape[-1]), target.view(-1))
                loss_value = loss.detach().item()
                fabric.log("train/loss", loss_value, step_count)

                fabric.backward(loss / gradient_accumulation_iters)

            if not is_accumulating:
                if config.gradient_clipping > 0:
                    # Fabric unscales mixed-precision gradients before clipping;
                    # clipping the raw .grad of a scaled loss uses the wrong norm.
                    # error_if_nonfinite=False matches clip_grad_norm_'s default.
                    fabric.clip_gradients(model, optimizer, max_norm=config.gradient_clipping,
                                          error_if_nonfinite=False)
                fabric.call("on_before_optimizer_step", model, optim_step_count, fabric.loggers)

                optimizer.step()
                optimizer.zero_grad()
                optim_step_count += 1

                fabric.log("train/optim_step_count", optim_step_count, step_count)
                fabric.log("train/epoch", epoch, optim_step_count)
                fabric.log("train/lr", lr, optim_step_count)

                if optim_step_count % 10 == 0 and hasattr(optimizer, "cpr_states"):
                    for group in optimizer.param_groups:
                        if 'apply_cpr' in group and group['apply_cpr'] is True:
                            for name, param in zip(group['names'], group['params']):
                                fabric.log(f"cpr/{name}/lagmul", optimizer.cpr_states[param]["lagmul"].detach().cpu(), optim_step_count)
                                fabric.log(f"cpr/{name}/kappa", optimizer.cpr_states[param]["kappa"].detach().cpu(), optim_step_count)

            if batch_idx % 100 == 0:
                # Report the unscaled loss, consistent with "train/loss".
                print(f"epoch: {epoch} - iteration: {batch_idx} - step: {step_count} - loss {loss_value:.4f}")

    if hasattr(model, "merge_weights"):
        model.merge_weights()

    # DDP replicas hold identical weights; write the file once rather than
    # having every rank write the same path concurrently.
    if fabric.is_global_zero:
        torch.save(model.state_dict(), os.path.join(expt_dir, "state_dict.pt"))
    fabric.barrier()

    model = model.eval()
    eval_model(fabric, config, expt_dir, model, tokenizer)

    return 0


def _build_dataloaders(config, tokenizer):
    """Return ``(train_loader, val_loader)`` for the dataset named by ``config.dataset``.

    Raises:
        ValueError: if ``config.dataset`` is not a supported dataset name.
    """
    if config.dataset == "pubmedqa":
        return get_pubmedqa(tokenizer, config=config)
    if config.dataset == "pubmedqa_owt":
        return get_pubmedqa_owt(tokenizer, config=config)

    # These loaders share one positional signature. The 4th argument is
    # config.effective_batch_size (presumably the per-device micro-batch size;
    # TODO confirm in workbench.data).
    positional_loaders = {
        "trivia_qa": get_trivia_qa,
        "medmcqa": get_medmcqa,
        "math_qa": get_math_qa,
        "piqa": get_piqa,
        "gsm8k": get_gsm8k,
    }
    if config.dataset not in positional_loaders:
        raise ValueError(f"Unknown dataset: {config.dataset!r}")
    return positional_loaders[config.dataset](tokenizer, config.max_seq_length, config.val_split,
                                              config.effective_batch_size, config.cache_dir)


def _learning_rate(config, optim_step_count, total_optim_steps, decay_factor=0.1):
    """Return the learning rate to use for optimizer step ``optim_step_count``.

    Phase 1 is a linear warmup from 0 to ``config.learning_rate`` over
    ``config.warmup_steps`` steps. Phase 2 is a cosine decay towards
    ``decay_factor * config.learning_rate``, reaching that floor around step
    ``total_optim_steps - 1``. When ``warmup_steps`` is 0, the warmup phase is
    skipped instead of dividing by zero.
    """
    warmup_steps = config.warmup_steps
    if warmup_steps > 0 and optim_step_count <= warmup_steps:
        return config.learning_rate * optim_step_count / warmup_steps
    cosine = max(0.0, (1 + math.cos(math.pi * (optim_step_count + 1 - warmup_steps)
                                    / float(max(1, total_optim_steps - warmup_steps)))) / 2)
    return config.learning_rate * (decay_factor + (1 - decay_factor) * cosine)




def eval_base_model(config, expt_dir, model, tokenizer, result_file="results.csv"):
    """Evaluate *model* as-is (no fine-tuning) under a DDP Fabric launch.

    Returns:
        On rank 0, the aggregated results DataFrame produced by ``eval_model``.
        None on other ranks.
    """
    fabric = L.Fabric(devices=config.devices, precision=config.precision, strategy="ddp")
    fabric.launch()

    model = fabric.setup_module(model)
    # Same as start_training: evaluate in eval mode regardless of the caller's state.
    model = model.eval()

    return eval_model(fabric, config, expt_dir, model, tokenizer, result_file)

def eval_model(fabric, config, expt_dir, model, tokenizer, result_file="results.csv"):
    """Evaluate ``config.eval_tasks`` (sharded across ranks) and aggregate on rank 0.

    With a single process, all tasks are evaluated locally. Otherwise the tasks
    are dealt round-robin via ``split_into_sublists`` and global rank ``i``
    evaluates shard ``i``. Ranks without a shard (more ranks than tasks) skip
    evaluation. Each evaluating rank writes ``results_<rank>.pt`` into
    *expt_dir*.

    After a barrier, global rank 0 does the aggregation:
    - merges every ``results_*.pt`` in *expt_dir*;
    - prints a markdown table;
    - writes *result_file* as CSV;
    - logs ``eval/<row>/acc``.

    Returns:
        On rank 0, a DataFrame indexed by result name with an ``acc`` column.
        None on other ranks.
    """
    # Global rank: local_rank repeats across nodes and would mis-assign shards.
    rank = fabric.global_rank

    if fabric.world_size == 1:
        eval_results = evaluate_model(model, tokenizer, config.eval_tasks, fabric=fabric)
        save_results(expt_dir, f"results_{rank}.pt", eval_results)
    else:
        # Handle any number of ranks: hard-coding ranks 0-3 silently dropped the
        # shards assigned to ranks 4+ when world_size > 4.
        eval_task_split = split_into_sublists(config.eval_tasks, max_sublists=fabric.world_size)
        if rank < len(eval_task_split):
            eval_results = evaluate_model(model, tokenizer, eval_task_split[rank], fabric=fabric)
            save_results(expt_dir, f"results_{rank}.pt", eval_results)
            print(eval_results)

    # Wait until every shard file has been written.
    fabric.barrier()

    if rank != 0:
        return None

    frames = []
    for file in sorted(os.listdir(expt_dir)):
        # Only per-rank shard files; e.g. a result_file named "results_x.csv"
        # must not be fed to torch.load.
        if file.startswith("results_") and file.endswith(".pt"):
            results_dict = torch.load(os.path.join(expt_dir, file))
            frames.append(pd.DataFrame(results_dict, index=["acc"]).T)
    results_all = pd.concat(frames)
    results_all.sort_index(inplace=True)
    print(results_all.to_markdown())
    results_all.to_csv(os.path.join(expt_dir, result_file))

    for k, v in results_all['acc'].to_dict().items():
        fabric.log(f"eval/{k}/acc", v)

    return results_all


if __name__ == "__main__":

    print("CUDA AVAILABLE", torch.cuda.is_available())
    print("CUDA DEVICES", torch.cuda.device_count())

    default_config_name = "default.yaml"

    parser = argparse.ArgumentParser(description='Train LLM')
    parser.add_argument('-c', '--config', type=str, default=default_config_name, help='config file name')

    args, unknown_args = parser.parse_known_args()

    # Accept both "-c foo" and "-c foo.yaml"; configs live in ./config.
    config_name = args.config
    if not config_name.endswith('.yaml'):
        config_name += '.yaml'

    # Open the normalised name (opening args.config ignored the added suffix).
    config_file = os.path.join("config", config_name)
    with open(config_file, 'r') as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects. That is
        # acceptable for trusted local configs; prefer yaml.safe_load otherwise.
        config_dict = yaml.load(f, Loader=yaml.Loader)
    config_dict = read_unknown_args(unknown_args, config_dict)

    # "a-b-c" is shorthand for the list ["a", "b", "c"].
    if "replace_layer" in config_dict:
        if "-" in config_dict["replace_layer"]:
            config_dict["replace_layer"] = config_dict["replace_layer"].split("-")

    # A seed given as "1-2-3" runs one experiment per seed.
    if isinstance(config_dict["seed"], str):
        seed_list = [int(i) for i in config_dict["seed"].split("-")]
    else:
        seed_list = [config_dict["seed"]]

    for seed in seed_list:

        config_dict_copy = copy.deepcopy(config_dict)
        config_dict_copy["seed"] = seed
        config_dict_copy["experiment"]["session_name"] += f"_seed-{seed}"

        print("start training with seed", seed)

        config = Config(config_dict=config_dict_copy)
        expt_dir = get_experiment_folder(**config.experiment, new_folder=True, count_folder=False)
        config.save_config(expt_dir)

        print("Experiment dir:", expt_dir)

        model = AutoModelForCausalLM.from_pretrained(config.model_name, cache_dir=config.cache_dir)
        tokenizer = AutoTokenizer.from_pretrained(config.model_name, cache_dir=config.cache_dir)
        # Reuse EOS as the padding token (presumably the tokenizer defines none).
        tokenizer.pad_token = tokenizer.eos_token

        if config.lora_type == "peft":
            print("PEFT lora fine tuning")
            # NOTE(review): here config.lora_alpha is a multiplier of r, while the
            # loralib branch uses it as the absolute alpha. Confirm this is intended.
            lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=config.r, lora_alpha=int(config.r * config.lora_alpha), target_modules=config.replace_layer)
            model = get_peft_model(model, lora_config)
            model.print_trainable_parameters()

            trainable_p = trainable_params(model)
            total_p = total_params(model)
            print(f"Total params: {total_p:,d} || Trainable params: {trainable_p:,d} || Fraction %: {100 * trainable_p / total_p:.2f}")

        elif config.lora_type == "loralib":

            lora_config = {
                "r": config.r,
                "lora_alpha": int(config.lora_alpha),
                "lora_dropout": config.lora_dropout,}

            # A single name stays a str when it contains no "-". Iterating a str
            # would test each character as a substring and match nearly every module.
            replace_keys = config.replace_layer
            if isinstance(replace_keys, str):
                replace_keys = [replace_keys]

            with torch.no_grad():
                # Snapshot the module list before swapping children in place.
                for name, module in list(model.named_modules()):
                    if any(replace_key in name for replace_key in replace_keys):
                        parent = model.get_submodule(".".join(name.split(".")[:-1]))
                        target_name = name.split(".")[-1]
                        target = model.get_submodule(name)
                        if isinstance(target, torch.nn.Linear):
                            new_module = loralib.Linear(target.in_features, target.out_features, bias=target.bias is not None,
                                                        **lora_config)
                            new_module.weight.copy_(target.weight)
                            if target.bias is not None:
                                new_module.bias.copy_(target.bias)
                        else:
                            raise NotImplementedError

                        setattr(parent, target_name, new_module)

            loralib.mark_only_lora_as_trainable(model)

            trainable_p = trainable_params(model)
            total_p = total_params(model)
            print(f"loralib Total params: {total_p:,d} || Trainable params: {trainable_p:,d} || Fraction %: {100 * trainable_p / total_p:.2f}")

        elif config.lora_type == "base":
            print("Base model performance evaluation")

            eval_base_model(config, expt_dir, model, tokenizer)

            # Stop after the first seed; remaining seeds are not run for base evaluation.
            raise SystemExit

        # Any other lora_type falls through with the unmodified model.
        start_training(config, expt_dir, model, tokenizer)
