import os
import csv
import json
import torch
import hydra
from hydra.core.hydra_config import HydraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from peft import PeftModel
from glob import glob
import shutil
import numpy as np
import torch.nn as nn
import itertools
import re
from transformers import pipeline, DataCollatorForLanguageModeling
import copy

import src.dataset
from src.convert_data import RETAIN_TASKS, get_wikitext
from src.utils import init_script, set_progress
from src.language_models import UnlearnLLM, SmallLLM, ContrastLLM, AssistedModel, load_unlearned_model
from src.mia_util import eval_mia
from src.gen_util import ContrastGenerationMixin
from src.conv_util import create_template
from src.tofuutil import tofu_eval
from src.lightningutil.modelmodule import ForgetModule
from codetiming import Timer
from src.utils import NameTimer

@hydra.main(version_base=None, config_path="../configs", config_name="eval_config")
def main(hparams):
    """Run the evaluation selected by ``hparams.eval_func`` on an unlearned model.

    Steps performed:

    1. Initialise logging and resolve the Hydra run output directory.
    2. Load the model onto ``cuda:{hparams.gpu.gpu_id}``.
    3. Copy every ``*simpleprofile*`` file larger than 10 bytes found
       recursively under the parent of the model path into the output
       directory.
    4. Build a left-padding and a right-padding tokenizer.
    5. Dispatch on ``hparams.eval_func``:
       ``'tofu'`` runs ``tofu_eval`` and then a wikitext perplexity pass,
       ``'tofu-onlyforget'`` runs ``tofu_eval`` with
       ``only_forget_quality=True``, and ``'mia'`` runs ``eval_mia``.

    Args:
        hparams: Configuration object supplied by the ``hydra.main``
            decorator. Reads ``gpu.gpu_id``, ``data.conv_template``,
            ``eval_func`` and either ``model.model_path`` or
            ``remember.save_path``.

    Raises:
        ValueError: If ``hparams.eval_func`` is not one of the values above.
    """
    LOGGER = init_script(hparams)
    LOGGER.info("Configs", configs=hparams)
    OUTPUTDIR = HydraConfig.get().runtime.output_dir
    device = f'cuda:{hparams.gpu.gpu_id}'
    print("DEVICE", device)

    conv_template = create_template(hparams.data.conv_template)
    model = load_unlearned_model(hparams, device)

    #! Copy runtime information
    if 'remember' not in hparams:
        model_path = hparams.model.model_path
    else:
        model_path = hparams.remember.save_path
    simpleprofiles = glob(os.path.join(model_path, "../**/", "*simpleprofile*"), recursive=True)
    # Iterating an empty match list is a no-op, so no explicit length guard is needed.
    for fname in simpleprofiles:
        # Files of 10 bytes or fewer are skipped.
        if os.path.getsize(fname) > 10:
            shutil.copy(fname, OUTPUTDIR)

    # NOTE(review): `padding_size` does not look like a standard tokenizer
    # attribute; it is kept because downstream project code may read it — confirm.
    tokenizer = AutoTokenizer.from_pretrained("locuslab/tofu_ft_llama2-7b")
    tokenizer.padding_side = "left"
    tokenizer.padding_size = 'longest'
    tokenizer.pad_token = tokenizer.eos_token

    right_pad_tokenizer = AutoTokenizer.from_pretrained("locuslab/tofu_ft_llama2-7b")
    right_pad_tokenizer.padding_side = 'right'
    right_pad_tokenizer.padding_size = 'longest'
    right_pad_tokenizer.pad_token = tokenizer.eos_token

    with Timer("Evaluation", text="{name} spent: {:0.4f} seconds"):
        if hparams.eval_func == 'tofu':
            tofu_eval(OUTPUTDIR, LOGGER, hparams, model, tokenizer, right_pad_tokenizer, conv_template, only_forget_quality=False)
            eval_data = get_wikitext()
            # Work on a copy so the template passed to tofu_eval above is not
            # mutated; the question/answer markers are blanked on the copy.
            retain_conv_template = copy.deepcopy(conv_template)
            retain_conv_template.question_start_token = ""
            retain_conv_template.question_end_token = " "
            retain_conv_template.answer_token = ""
            # The result is logged and written to disk inside eval_retain_ppl.
            eval_retain_ppl(OUTPUTDIR, LOGGER, 'wikitext', eval_data, retain_conv_template, model, right_pad_tokenizer)
        elif hparams.eval_func == 'tofu-onlyforget':
            tofu_eval(OUTPUTDIR, LOGGER, hparams, model, tokenizer, right_pad_tokenizer, conv_template, only_forget_quality=True)
        elif hparams.eval_func == 'mia':
            eval_mia(OUTPUTDIR, LOGGER, hparams, model, right_pad_tokenizer, conv_template)
        else:
            raise ValueError(f"Unknown evaluation function: {hparams.eval_func!r}")

def batchify(data, batch_size):
    """Yield consecutive lists of at most ``batch_size`` items from ``data``.

    ``data`` only needs ``len()`` and integer indexing; items are fetched one
    by one rather than by slicing. The final batch may be shorter, and no
    empty batch is ever yielded.
    """
    total = len(data)
    full_batches = total // batch_size
    for batch_idx in range(full_batches + 1):
        start = batch_idx * batch_size
        stop = min(total, start + batch_size)
        # When the length is an exact multiple of batch_size the extra
        # iteration has nothing left to emit.
        if start >= stop:
            break
        yield [data[pos] for pos in range(start, stop)]
        
import torch.nn.functional as F
def checklogits(logits, labels):
    """Return the mean per-token log-likelihood of each sequence in a batch.

    The logits at position ``t`` are scored against the label at position
    ``t + 1`` (next-token shift). Positions whose shifted label equals
    ``-100`` are excluded from the mean.

    Args:
        logits: Tensor indexed as ``(batch, position, vocab)`` holding
            unnormalised scores.
        labels: Integer tensor indexed as ``(batch, position)``; ``-100``
            marks positions to ignore.

    Returns:
        A list with one Python float per batch row. A row with no scored
        position yields ``nan`` (mean of an empty tensor).
    """
    # Normalise over the vocabulary, then drop the last position, which has
    # no following label to predict.
    log_probs = F.log_softmax(logits.cpu(), dim=-1)[:, :-1]
    targets = labels.cpu()[:, 1:]  # shift

    log_likelihood = []
    for row_log_probs, row_targets in zip(log_probs, targets):
        keep = row_targets != -100
        # Pick the log-probability assigned to each target token.
        picked = torch.gather(row_log_probs[keep], -1, row_targets[keep].unsqueeze(-1))
        log_likelihood.append(torch.mean(picked).item())
    return log_likelihood

def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
    """Yield ``(input_tokens, pred_tokens)`` windows covering ``token_list``.

    Every token appears in exactly one ``pred_tokens`` list. The first window
    is conditioned on ``prefix_token``; later windows are conditioned on the
    ``max_seq_len`` tokens that precede their last predicted token.

    Args:
        token_list: List of token ids to cover.
        prefix_token: Token placed in front of the first window's inputs.
        max_seq_len: Maximum number of input tokens per window.
        context_len: Minimum context per window; must satisfy
            ``1 <= context_len <= max_seq_len``.

    Yields:
        Tuples of two lists: the input tokens and the tokens to predict.
    """
    #! Copied from lm-evaluation-harness/lm_eval/utils.py:177
    assert 1 <= context_len <= max_seq_len
    if not token_list:
        return

    total = len(token_list)
    # +1 offset, going from input->preds
    stride = max_seq_len - context_len + 1

    # The first window predicts everything it contains.
    head_len = min(max_seq_len, total)
    head_inputs = [prefix_token] + token_list[: head_len - 1]
    yield (head_inputs, token_list[:head_len])

    done = head_len
    while done < total:
        step = min(total - done, stride)
        end = done + step
        inputs = token_list[end - max_seq_len - 1 : end - 1]
        targets = token_list[done:end]
        yield (inputs, targets)
        done = end

def make_disjoint_window(pair):
    """Trim the context of a ``(context, targets)`` pair.

    The context is cut to its first ``len(context) - (len(targets) - 1)``
    elements; the targets are returned unchanged.
    """
    context, targets = pair
    keep = len(context) - (len(targets) - 1)
    return context[:keep], targets


@torch.no_grad()
def eval_retain_ppl(OUTPUTDIR, LOGGER, NAME, data, conv_template, model, tokenizer):
    """Compute a word-level perplexity of ``model`` over ``data``.

    For every document the ``'newdoc'`` text is tokenized, split into windows
    by ``get_rolling_token_windows`` (``context_len=1``,
    ``max_seq_len=tokenizer.model_max_length``) and the log-probabilities of
    the scored tokens are summed. The corpus-level sum is normalised by the
    number of whitespace-separated words in the documents' ``'page'`` field.

    Side effects: writes ``{NAME}-loglikelihood.json`` into ``OUTPUTDIR`` and
    logs the result through ``LOGGER.info``.

    Args:
        OUTPUTDIR: Directory that receives the JSON file.
        LOGGER: Logger whose ``info`` accepts an event name plus keyword fields.
        NAME: Label used in the progress bar, file name and result key.
        data: Sized, integer-indexable collection of mappings with
            ``'newdoc'`` and ``'page'`` keys.
        conv_template: Not used by this function; kept for interface
            compatibility.
        model: Callable as ``model(input_ids=...)`` returning an object with
            ``.logits``; must expose ``.device``.
        tokenizer: Tokenizer providing ``eos_token_id`` and
            ``model_max_length``.

    Returns:
        A dict with the single key ``f"{NAME}-word-ppl"``.
    """
    # NOTE(review): os.environ.get returns a string when POOR is set, so any
    # non-empty value (including "0") is truthy — confirm this is intended.
    progress = set_progress(disable=os.environ.get("POOR", False))
    collator = DataCollatorForLanguageModeling(mlm=False, tokenizer=tokenizer)
    #! Follow the implementation in lm-evluation-harness
    all_logprobs = []
    all_word_cnt = []

    def eval_chunk_logits(chunk_inputs):
        """Return the log-probability of each scored token in one chunk."""
        collated = collator([*chunk_inputs])
        outputs = model(input_ids=collated['input_ids']).logits
        outputs = F.log_softmax(outputs, dim=-1)
        # Next-token shift: logits at position t are scored against label t + 1.
        logits = outputs[:, :-1]
        labels = collated['labels'][:, 1:]
        # Drop positions the collator marked as ignored (-100).
        logits = logits[labels != -100]
        labels = labels[labels != -100].unsqueeze(-1)
        logprob = torch.gather(logits, 1, labels)
        return logprob

    with progress:
        # Only the first tokenized document of each batch is read below, so
        # the batch size must stay at 1.
        batch_size = 1
        task = progress.add_task(f"[green]Evaluating {NAME}", name=f"{NAME}", total=len(data) // batch_size)
        for batchitem in batchify(data, batch_size):
            rawtexts = [x['newdoc'] for x in batchitem]
            inputs = tokenizer(rawtexts).input_ids[0]
            doc_logprobs = []
            for chunk_inputs in get_rolling_token_windows(inputs, prefix_token=tokenizer.eos_token_id, max_seq_len=tokenizer.model_max_length, context_len=1):
                # NOTE(review): only the prediction half of each window is fed
                # to the model, and eval_chunk_logits shifts by one, so the
                # first token of every chunk serves as context and is not
                # scored — confirm this is the intended metric.
                chunk_inputs = make_disjoint_window(chunk_inputs)[1]
                chunk_inputs = torch.tensor(chunk_inputs).unsqueeze(0).to(model.device)
                chunk_logprob = eval_chunk_logits(chunk_inputs)
                doc_logprobs.append(chunk_logprob.cpu().tolist())

            doc_logprobs = list(itertools.chain.from_iterable(doc_logprobs))
            progress.advance(task)
            all_logprobs.append(np.sum(doc_logprobs))
            # NOTE(review): words are counted on 'page' while the scored text
            # comes from 'newdoc' — presumably the un-preprocessed original;
            # verify against the dataset builder.
            all_word_cnt.append(len(re.split(r"\s+", batchitem[0]['page'])))

    with open(os.path.join(OUTPUTDIR, f'{NAME}-loglikelihood.json'), 'w') as f:
        json.dump({
            'logprob': all_logprobs,
        }, f, indent=2)

    result = {
        f"{NAME}-word-ppl" : np.exp(- np.sum(all_logprobs) / np.sum(all_word_cnt)).item()
    }
    LOGGER.info(f"{NAME}-result", **result)
    return result
    

# Script entry point. main() is wrapped by @hydra.main, so it is called
# without arguments here and receives its configuration from the decorator.
if __name__ == "__main__":
    main()