import os
import json
import torch
import numpy as np
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import DataCollatorForLanguageModeling
from sklearn.metrics import auc, roc_curve

from src.utils import set_progress
from src.tofuutil import prepare_dataset, prepare_loader
from src.tofuutil.data_module import MIATextDataset


def sweep(score, x):
    """
    Build a ROC curve for `score` against the labels `x`.

    The scores are negated before being handed to `roc_curve`, so a lower
    raw score counts as stronger evidence for the positive class.

    Returns a tuple `(fpr, tpr, auc, acc)` where `acc` is the best balanced
    accuracy, i.e. the maximum of `(tpr + (1 - fpr)) / 2` over all thresholds.
    """
    false_pos, true_pos, _thresholds = roc_curve(x, -score)
    # 1 - (FPR + FNR) / 2 is the balanced accuracy at each threshold.
    balanced_error = (false_pos + (1 - true_pos)) / 2
    best_acc = np.max(1 - balanced_error)
    area = auc(false_pos, true_pos)
    return false_pos, true_pos, area, best_acc


def do_plot(prediction, answers, sweep_fn=sweep, metric='auc', legend="", output_dir=None):
    """
    Summarise how well `prediction` separates the boolean `answers`.

    `prediction` and `answers` are converted to numpy arrays (the latter as
    bool) and passed to `sweep_fn`, which must return `(fpr, tpr, auc, acc)`.
    Despite the name nothing is drawn here; `metric`, `legend` and
    `output_dir` are accepted but not used.

    Returns `(auc, acc, low)` where `low` is the TPR at the last ROC point
    whose FPR is still below 5%.
    """
    scores = np.array(prediction)
    truth = np.array(answers, dtype=bool)
    fpr, tpr, roc_area, best_acc = sweep_fn(scores, truth)
    # Indices of every ROC point with FPR < 5%; the last one is the
    # operating point reported as TPR@5%FPR.
    below_threshold = np.where(fpr < .05)[0]
    low = tpr[below_threshold[-1]]
    return roc_area, best_acc, low


def calculatePerplexity(input_ids, model, attention_mask, tokenizer, gpu):
    """
    Score a padded batch with a causal LM.

    Args:
        input_ids: token ids, expected to be shaped (batch, seq_len).
        model: callable as `model(input_ids, attention_mask=..., labels=...)`
            whose output can be sliced as `(loss, logits, ...)`.
        attention_mask: same shape as `input_ids`; 0 marks padding.
        tokenizer: unused, kept so existing callers keep working.
        gpu: device the inputs are moved to before the forward pass.

    Returns:
        A tuple `(ppl, all_prob, loss)`:
        - `ppl`: list with `exp(loss)` per sequence,
        - `all_prob`: per sequence, the list of log-probabilities of each
          non-padded target token,
        - `loss`: list with the per-sequence loss.

    NOTE: the summed token loss is divided by the number of attended tokens
    (`attention_mask.sum`), not by the number of scored targets. That
    normalisation is kept unchanged so scores stay comparable with earlier runs.
    """
    input_ids = input_ids.to(gpu)
    attention_mask = attention_mask.to(gpu)
    # -100 is the ignore index: padded positions never count as targets.
    labels = input_ids.masked_fill(attention_mask == 0, -100)
    with torch.no_grad():
        # Forward the mask as well so real tokens cannot attend to padding
        # (this matters for left-padded batches).
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
    _, logits = outputs[:2]

    # Logits at position t score the token at position t + 1.
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:].to(shift_logits.device)
    keep = shift_labels != -100

    # A single log-softmax yields both the per-token log-probabilities and
    # the loss; the vocabulary size is taken from the logits themselves
    # rather than from model.config.
    log_probs = torch.log_softmax(shift_logits, dim=-1)
    # Ignored targets are pointed at index 0 so gather stays in range; their
    # values are masked out below.
    gather_index = shift_labels.masked_fill(~keep, 0).unsqueeze(-1)
    token_log_probs = log_probs.gather(-1, gather_index).squeeze(-1)

    token_loss = -token_log_probs.masked_fill(~keep, 0.0)
    token_count = attention_mask.sum(dim=-1).to(token_loss.device)
    loss = token_loss.sum(dim=-1) / token_count

    all_prob = [
        row[selected].tolist() for row, selected in zip(token_log_probs, keep)
    ]
    return torch.exp(loss).tolist(), all_prob, loss.tolist()

def min_k_prob(probs):
    """
    Compute the Min-K% Prob scores for one sequence.

    For each ratio in 5%, 10%, 20%, ..., 90% the lowest `ratio` share of the
    token log-probabilities in `probs` is averaged and negated.

    At least one token is always used, so short sequences (where
    `int(len(probs) * ratio)` would be 0) yield the negated minimum instead
    of NaN. An empty `probs` yields NaN for every key.

    Returns a dict keyed by `"Min_<ratio*100>% Prob"` with Python floats.
    """
    # Sort once; every ratio only needs a prefix of the ascending order.
    ordered = np.sort(probs)
    total = len(ordered)
    pred = {}
    for ratio in [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
        key = f"Min_{ratio*100}% Prob"
        if total == 0:
            # Nothing to average; report NaN without numpy's empty-mean warning.
            pred[key] = float("nan")
            continue
        k_length = max(1, int(total * ratio))
        pred[key] = -np.mean(ordered[:k_length]).item()
    return pred

def _score_loader(progress, loader, model, tokenizer, task_name, label, flatten=False):
    """
    Score every batch of `loader` and return one prediction dict per sequence.

    Each dict holds 'ppl', 'likelihood', 'label' followed by the Min-K% keys.
    When `flatten` is set, inputs are reshaped to (-1, seq_len) before scoring.
    """
    bar = progress.add_task(
        "evaltask", name=f"[yellow][{task_name}-eval]", total=len(loader),
    )
    records = []
    with torch.no_grad():
        for input_ids, _, attention_mask in loader:
            if flatten:
                input_ids = input_ids.view(-1, input_ids.shape[-1])
                attention_mask = attention_mask.view(-1, input_ids.shape[-1])
            ppls, token_probs, losses = calculatePerplexity(
                input_ids, model, attention_mask, tokenizer, gpu=model.device
            )
            for ppl, probs, seq_loss in zip(ppls, token_probs, losses):
                record = {'ppl': ppl, 'likelihood': seq_loss, 'label': label}
                record.update(min_k_prob(probs))
                records.append(record)
            progress.advance(bar)
    return records


def _write_preds(output_dir, task_name, records):
    """
    Dump `records` as indented JSON to `<output_dir>/<task_name>-pred.json`.
    """
    target = os.path.join(output_dir, f"{task_name}-pred.json")
    with open(target, 'w') as f:
        f.write(json.dumps(records, indent=2))


def eval_mia(OUTPUTDIR, LOGGER, hparams, model, tokenizer, conv_template):
    """
    Produce membership-inference scores and write them to OUTPUTDIR.

    First the split named by `hparams.data.split` is scored (label 1, written
    as "forget-pred.json"), then the "WikiMIA_length64" and
    "WikiMIA_length128" sets (label 0, one "<split>-pred.json" each).
    LOGGER is accepted but not used. Nothing is returned.
    """
    progress = set_progress(disable=os.getenv("POOR", False))
    MAX_NUM = 300  # forwarded to prepare_dataset as max_num
    batch_size = hparams.data.eval.batch_size
    with progress:
        # Only the configured split is evaluated here; further splits (e.g.
        # "retain_perturbed", "real_authors_perturbed",
        # "world_facts_perturbed") could be appended to this list.
        eval_tasks = [hparams.data.split]
        eval_task = progress.add_task(
            "evalbar",
            name="[green][Main Evaluate]",
            total=len(eval_tasks),
        )
        for eval_split in eval_tasks:
            is_forget = eval_split == hparams.data.split
            task_name = "forget" if is_forget else eval_split
            dataset = prepare_dataset(
                hparams.data.name, tokenizer, conv_template, eval_split,
                "question", "answer", max_num=MAX_NUM
            )
            loader = prepare_loader(dataset, batch_size)
            records = _score_loader(
                progress, loader, model, tokenizer, task_name,
                label=1 if is_forget else 0,
            )
            _write_preds(OUTPUTDIR, task_name, records)
            progress.advance(eval_task)

        # WikiMIA sets: every sequence is recorded with label 0.
        for task_name in ["WikiMIA_length64", "WikiMIA_length128"]:
            dataset = MIATextDataset(
                split=task_name,
                tokenizer=tokenizer,
                conv_template=conv_template,
            )
            loader = prepare_loader(dataset, batch_size)
            records = _score_loader(
                progress, loader, model, tokenizer, task_name,
                label=0, flatten=True,
            )
            _write_preds(OUTPUTDIR, task_name, records)