
import os, socket
import pandas as pd

# if "juwels" in socket.gethostname() or "jureca" in socket.gethostname():
#     os.environ["HF_DATASETS_CACHE"]= "/p/scratch/transfernetx/franke5/datasets/cache/"
#     os.environ["HF_HOME"]= "/p/scratch/transfernetx/franke5/model/cache/"
#     os.environ["HUGGINGFACE_HUB_CACHE"]= "/p/scratch/transfernetx/franke5/model/cache/"



import json
import sys
from pathlib import Path
from typing import Dict, List, Literal, Optional
import torch
import yaml

from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
from peft.peft_model import PeftModel
from typing import Any, Literal, Optional
import lightning as L
import torch
from lightning.fabric.plugins import BitsandbytesPrecision
from lm_eval import base, evaluator, tasks
from lm_eval.base import BaseLM

# Support running this file without installing the project as a package:
# resolve the parent of the directory that contains this file and append it
# to `sys.path` so imports relative to that directory can be found.
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

# from generate.base import generate
# from lit_gpt import GPT, Config, Tokenizer
# from lit_gpt.utils import (
#     check_valid_checkpoint_dir,
#     get_default_supported_precision,
#     gptq_quantization,
#     load_checkpoint,
# )

def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
    if torch._dynamo.is_compiling():
        # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly
        distribution = torch.empty_like(probs).exponential_(1)
        return torch.argmax(probs / distribution, dim=-1, keepdim=True)
    return torch.multinomial(probs, num_samples=1)

def sample(logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None) -> torch.Tensor:
    """Pick the next token id from the logits of the last position.

    Only `logits[0, -1]` is used, i.e. the final position of the first batch
    entry.

    Args:
        logits: Model output indexed as `[batch, position, vocab]`.
        temperature: Divides the logits before the softmax. Any value that is
            not strictly positive selects the arg-max instead of sampling.
        top_k: If given, every logit outside the `top_k` largest is set to
            `-inf` before selecting.

    Returns:
        A tensor with a single element holding the chosen token index.
    """
    last_logits = logits[0, -1]

    if top_k is not None:
        k = min(top_k, last_logits.size(-1))
        top_values, top_indices = torch.topk(last_logits, k)
        # Scatter the kept values into a -inf canvas. `torch.where` (as in
        # nanogpt) is avoided because it will repeat top-k collisions.
        masked = torch.full_like(last_logits, float("-inf"))
        last_logits = masked.scatter_(-1, top_indices, top_values)

    if not temperature > 0.0:
        # Greedy choice.
        return torch.argmax(last_logits, dim=-1, keepdim=True)

    scaled = last_logits / temperature
    probs = torch.nn.functional.softmax(scaled, dim=-1)
    return multinomial_num_samples_1(probs)

def next_token(model, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
    """Run `model` on `x` and select one follow-up token.

    Args:
        model: Callable whose result can be indexed with `"logits"`.
        x: Input token ids passed to the model unchanged.
        **kwargs: Forwarded to `sample` (e.g. `temperature`, `top_k`).

    Returns:
        The selected token, cast to the dtype of `x`.
    """
    output = model(x)
    chosen = sample(output["logits"], **kwargs)
    return chosen.to(dtype=x.dtype)

@torch.inference_mode()
def generate(
    model,
    prompt: torch.Tensor,
    max_returned_tokens: int,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    eos_id: Optional[int] = None,
) -> torch.Tensor:
    """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    The implementation of this function is modified from A. Karpathy's nanoGPT.

    `next_token` calls the model with the input ids only and passes no cache
    state, so every step feeds the complete running sequence (prompt plus all
    tokens generated so far). Feeding only the most recent token would make
    each prediction ignore the prompt.

    Args:
        model: The model to use. If it exposes a `max_seq_length` attribute,
            that value is enforced; the model is never modified.
        prompt: Tensor of shape (T) with indices of the prompt sequence.
        max_returned_tokens: The maximum number of tokens to return (given plus generated).
        temperature: Scales the predicted logits by 1 / temperature.
        top_k: If specified, only sample among the tokens with the k highest probabilities.
        eos_id: If specified, stop generating any more token once the <eos> token is triggered.

    Returns:
        Tensor holding the prompt followed by the generated tokens. Generation
        stops early right after `eos_id` has been produced (the <eos> token is
        included in the result).

    Raises:
        NotImplementedError: If the model declares a `max_seq_length` smaller
            than `max_returned_tokens - 1`.
    """
    T = prompt.size(0)
    assert max_returned_tokens > T, (
        f"max_returned_tokens ({max_returned_tokens}) must exceed the prompt length ({T})"
    )

    # Only honour a limit the model itself declares. Writing a made-up default
    # onto the model would make later calls with longer sequences fail.
    max_seq_length = getattr(model, "max_seq_length", None)
    if max_seq_length is not None and max_seq_length < max_returned_tokens - 1:
        raise NotImplementedError(f"max_seq_length {max_seq_length} needs to be >= {max_returned_tokens - 1}")

    sequence = prompt
    for _ in range(max_returned_tokens - T):
        # Re-feed the whole sequence so the prediction sees the full context.
        token = next_token(model, sequence.view(1, -1), temperature=temperature, top_k=top_k)
        sequence = torch.cat((sequence, token))
        # Checked on every step, including the very first generated token.
        if eos_id is not None and token == eos_id:
            break
    return sequence


class EvalHarness(BaseLM):
    """Adapter that exposes a causal LM and its tokenizer through the `BaseLM` interface.

    The model is expected to be callable on a tensor of token ids and to return
    something indexable with `"logits"`. The tokenizer is expected to offer the
    Hugging Face tokenizer attributes used below (`encode`, `decode`,
    `eos_token_id`, `vocab_size`).
    """

    # Credits:
    # https://github.com/EleutherAI/gpt-neox/blob/main/eval_tasks/eval_adapter.py
    def __init__(self, fabric: L.Fabric, model, tokenizer, batch_size: int, max_length=500):
        """Store the collaborators and sizing options.

        Args:
            fabric: Fabric instance providing device, rank and world size.
            model: The model to evaluate.
            tokenizer: Tokenizer matching `model`.
            batch_size: Batch size per device.
            max_length: Value reported through the `max_length` property.
        """
        super().__init__()
        self.fabric = fabric
        self.model = model
        self.tokenizer = tokenizer
        # Coerce to int: values parsed by `create_from_arg_string` arrive as
        # strings, and `str * int` in `batch_size` would repeat the string.
        self.batch_size_per_gpu = int(batch_size)
        self.max_length_value = int(max_length)

    @classmethod
    def create_from_arg_string(cls, arg_string, additional_config=None):
        """Build an instance from a `"key=value,key=value"` string.

        Args:
            arg_string: Comma separated `key=value` pairs; empty entries are
                skipped. Values are split on the first `=` only, so they may
                themselves contain `=`.
            additional_config: Optional mapping of further constructor
                arguments. `None` is treated as empty.

        Raises:
            ValueError: If a non-empty entry contains no `=`.
        """
        kwargs = {}
        for entry in arg_string.split(","):
            entry = entry.strip()
            if not entry:
                continue
            key, sep, value = entry.partition("=")
            if not sep:
                raise ValueError(f"Expected 'key=value' but got {entry!r}")
            kwargs[key] = value
        return cls(**kwargs, **(additional_config or {}))

    @property
    def eot_token_id(self):
        # Read the id directly. Encoding the eos string and taking element 0
        # yields the BOS id for tokenizers that prepend BOS when encoding.
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        return self.max_length_value

    @property
    def vocab_size(self):
        return self.tokenizer.vocab_size

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # Per-device batch size scaled by the number of processes.
        return self.batch_size_per_gpu * self.fabric.world_size

    @property
    def device(self):
        return self.fabric.device

    def tok_encode(self, string: str) -> List[int]:
        """Encode `string` to token ids without adding special tokens.

        Context and continuation are encoded separately by the harness and
        concatenated afterwards, so automatically inserted special tokens
        (e.g. a leading BOS) would end up inside the scored sequence.
        """
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens: List[int]) -> str:
        """Decode a list of token ids back to text."""
        t = torch.tensor(tokens)
        return self.tokenizer.decode(t)

    @torch.inference_mode()
    def _model_call(self, inps):
        """Run the model on `inps` and return its `"logits"` entry."""
        ret = self.model(inps)
        return ret['logits']

    @torch.inference_mode()
    def _model_generate(self, context, max_length, eos_token_id) -> torch.Tensor:
        """Generate a continuation for a single context row.

        Decoding is greedy (`temperature=0.0`) so that generation-based tasks
        are deterministic.
        """
        # this only supports batch size 1
        assert context.shape[0] == 1
        out = generate(self.model, context[0], max_length, temperature=0.0, eos_id=eos_token_id)
        return out.unsqueeze(0)

    @torch.inference_mode()
    def run_eval(
        self, eval_tasks: List[str], num_fewshot: int, limit: Optional[int], bootstrap_iters: int, no_cache: bool
    ) -> Dict:
        """Evaluate the wrapped model on every registered task matching `eval_tasks`.

        Args:
            eval_tasks: Task names or `fnmatch` patterns matched against
                `tasks.ALL_TASKS`.
            num_fewshot: Forwarded to `evaluator.evaluate`.
            limit: Forwarded to `evaluator.evaluate`.
            bootstrap_iters: Forwarded to `evaluator.evaluate`.
            no_cache: When false, the model is wrapped in `base.CachingLM`
                backed by `lm_cache/lit-gpt.db`.

        Returns:
            The dictionary returned by `evaluator.evaluate`.
        """
        import fnmatch

        # Collect every registered task that matches at least one pattern.
        # Sorted so the task order is the same on every rank and every run
        # (plain set iteration order of strings is not).
        matched = {name for pattern in eval_tasks for name in fnmatch.filter(tasks.ALL_TASKS, pattern)}
        eval_tasks = sorted(matched)
        print(f"Found tasks: {eval_tasks}")

        # **HACK INCOMING**:
        # first get task dict on local main rank
        # the tasks are downloaded *as they are initialized*, and the downloads don't like multithreading.
        # so we download them once on the local main rank, wait, and then initialize them on all other ranks, which *should* load from the cache.
        if self.fabric.local_rank == 0:
            tasks.get_task_dict(eval_tasks)
        # torch barrier
        self.fabric.barrier()
        task_dict = tasks.get_task_dict(eval_tasks)

        lm = self
        if not no_cache:
            lm = base.CachingLM(lm, "lm_cache/lit-gpt.db")

        results = evaluator.evaluate(
            lm=lm,
            task_dict=task_dict,
            num_fewshot=num_fewshot,
            limit=limit,
            bootstrap_iters=bootstrap_iters,
        )
        return results


def convert_torch_scalars_to_floats(d):
    """Replace single-element tensors in a nested dictionary, in place.

    Walks `d` and every dictionary nested inside it. Each value that is a
    tensor with exactly one element is replaced by the Python scalar returned
    by its `.item()`. All other values are left untouched. Nothing is
    returned; the dictionaries are modified directly.
    """
    for name, entry in d.items():
        if isinstance(entry, dict):
            # Descend into the nested mapping.
            convert_torch_scalars_to_floats(entry)
            continue
        if torch.is_tensor(entry) and entry.numel() == 1:
            # Overwriting an existing key is safe while iterating.
            d[name] = entry.item()

def evaluate_model(
    model,
    tokenizer,
    eval_tasks,
    fabric=None,
    *,
    batch_size=4,
    max_length=1024,
    num_fewshot=0,
    limit=None,
    bootstrap_iters=100000,
    no_cache=True,
):
    """Evaluate `model` on `eval_tasks` and return the per-task results.

    Args:
        model: The model to evaluate. It is switched to eval mode.
        tokenizer: Tokenizer matching `model`.
        eval_tasks: Task names or patterns, forwarded to `EvalHarness.run_eval`.
        fabric: Fabric instance to use. When `None`, a single-device Fabric
            with `"16-true"` precision is created and the model is set up with
            it. A provided fabric is used as is and the model is not set up
            again.
        batch_size: Per-device batch size handed to `EvalHarness`.
        max_length: Maximum length handed to `EvalHarness`.
        num_fewshot: Forwarded to `EvalHarness.run_eval`.
        limit: Forwarded to `EvalHarness.run_eval`.
        bootstrap_iters: Forwarded to `EvalHarness.run_eval`.
        no_cache: Forwarded to `EvalHarness.run_eval`.

    Returns:
        The `"results"` entry of the evaluation output, with single-element
        tensors converted to Python scalars.
    """
    if fabric is None:
        fabric = L.Fabric(devices=1, precision="16-true", plugins=None)
        model = fabric.setup_module(model)
    model.eval()

    eval_harness = EvalHarness(fabric, model, tokenizer, batch_size, max_length=max_length)

    results = eval_harness.run_eval(
        eval_tasks,
        num_fewshot=num_fewshot,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
        no_cache=no_cache,
    )

    # Converts in place, so the nested "results" entry is affected as well.
    convert_torch_scalars_to_floats(results)

    return results["results"]


if __name__ == "__main__":
    # Stand-alone run: load a checkpoint, wrap it with Fabric and evaluate it
    # on the tasks listed below, printing the raw result dictionary.

    # Download cache directory used for the tokenizer.
    cache_dir = 'cache'

    # Checkpoint to evaluate. Alternatives that were used before:
    #   model "mistralai/Mistral-7B-Instruct-v0.2" with tokenizer "mistralai/Mistral-7B-v0.1"
    #   "TinyLlama/TinyLlama-1.1B-Chat-v1.0" for both model and tokenizer
    model_name = tokenizer_name = "distilgpt2"

    # Model and the tokenizer
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, cache_dir=cache_dir)
    tokenizer.pad_token = tokenizer.eos_token

    fabric = L.Fabric(devices=1, precision="16-true", plugins=None)
    fabric.launch()

    model.eval()
    model = fabric.setup_module(model)

    # Other task selections:
    #   ["arc_challenge"]
    #   ["hendrycksTest-*"]
    #   ["arc_challenge", "piqa", "hellaswag"]
    #   ["arc_challenge", "piqa", "hellaswag", "hendrycksTest-*"]
    eval_tasks = ["medmcqa", "pubmedqa"]

    eval_harness = EvalHarness(fabric, model, tokenizer, 1, max_length=4096)
    results = eval_harness.run_eval(
        eval_tasks,
        num_fewshot=0,
        limit=None,
        bootstrap_iters=100000,
        no_cache=True,
    )

    print(results)