import torch
import torch.nn as nn
import numpy as np
import scipy
import nltk
import typing
from ..util.generate import generate_fast
import torch.nn.functional as F
from sklearn.metrics import f1_score


def test_prediction_acc(model, tok, hparams, prompts, targets, device, locality=False, vanilla_generation=False):
    """Token-level accuracy of ``model`` on ``targets`` given ``prompts``.

    Two modes are supported:

    * ``vanilla_generation=True``: greedily generate ``len(target tokens)`` new
      tokens after a single prompt string with ``model.generate`` and compare
      them position-wise with the tokenized ``targets`` string.
    * otherwise (teacher forcing): run one forward pass over
      ``prompt + ' ' + target`` and compare the argmax prediction for every
      target position with the actual target token.

    Args:
        model: causal LM; its forward output is a logits tensor or has ``.logits``.
        tok: tokenizer belonging to ``model``.
        hparams: unused; kept for interface compatibility.
        prompts: a prompt string (required in generation mode) or a list of prompts.
        targets: target string(s) aligned with ``prompts``.
        device: device for the model inputs. In teacher-forcing mode an integer
            (or digit string) is treated as a CUDA index (``cuda:{device}``);
            a ``torch.device`` or a string such as ``"cpu"`` is used as given.
        locality: if True, return predicted token ids instead of accuracies.
        vanilla_generation: select the generation mode described above.

    Returns:
        With ``locality``: a list of predicted token-id lists (one per prompt).
        Otherwise: a list of per-example accuracies in [0, 1]. In batched
        teacher-forcing mode, examples whose accuracy is NaN (no target
        tokens) are skipped.
    """
    if vanilla_generation:
        target_new_tokens = tok.encode(' ' + targets)
        if target_new_tokens[0] == tok.pad_token_id or (
                hasattr(tok, 'bos_token_id') and target_new_tokens[0] == tok.bos_token_id):
            # A leading special token was prepended: re-encode without the
            # leading space and drop that first token.
            target_new_tokens = tok.encode(targets)
            target_new_tokens = target_new_tokens[1:]
        prompt_tok = tok(
            prompts,
            return_tensors="pt",
        ).to(device)
        gen_token = model.generate(
            input_ids=prompt_tok['input_ids'],
            attention_mask=prompt_tok['attention_mask'],
            max_new_tokens=len(target_new_tokens),
            pad_token_id=tok.eos_token_id,
            use_cache=False,
        )
        n_target = len(target_new_tokens)
        n_prompt = prompt_tok['input_ids'].shape[1]
        # Assumes a decoder-only model, whose generate() output is the prompt
        # followed by the new tokens. Keep only the new tokens; if generation
        # stopped early, pad with -100 so there are always n_target entries and
        # the missing positions count as mismatches (instead of silently
        # comparing against the tail of the prompt).
        new_tokens = gen_token[0].detach().cpu().tolist()[n_prompt:n_prompt + n_target]
        new_tokens += [-100] * (n_target - len(new_tokens))
        if locality:
            return [new_tokens]
        return [np.mean(np.equal(target_new_tokens, new_tokens))]

    if isinstance(prompts, str):
        prompts, targets = [prompts, ], [targets, ]
    # Integers (or digit strings) are CUDA device indices; anything else is used as given.
    if str(device).isdigit():
        device = f"cuda:{device}"
    prompt_target = [prompt + ' ' + target for prompt, target in zip(prompts, targets)]
    prompt_target_tok = tok(
        prompt_target,
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(device)
    prompt_tok = tok(
        prompts,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    num_prompt_toks = [int((i != tok.pad_token_id).sum()) for i in prompt_tok['input_ids']]
    num_pad_toks = [int((i == tok.pad_token_id).sum()) for i in prompt_target_tok['input_ids'].cpu()]
    # Start of the target in each row = pad tokens + prompt tokens. This offset
    # is only correct when padding is on the left (pads before the prompt).
    prompt_len = [x + y for x, y in zip(num_pad_toks, num_prompt_toks)]
    with torch.no_grad():
        outputs = model(**prompt_target_tok)
        logits = outputs if isinstance(outputs, torch.Tensor) else outputs.logits
        answers = torch.argmax(logits, dim=-1).squeeze().detach().cpu().numpy().tolist()
        labels = prompt_target_tok['input_ids'].squeeze().detach().cpu().numpy().tolist()
        # The prediction at position t is for token t + 1, so predictions are
        # sliced one position earlier than the labels.
        answers = slice_list(answers, prompt_len, left=True)
        labels = slice_list(labels, prompt_len, left=False)
        if locality:
            return answers if type(answers[0]) is list else [answers, ]
        if isinstance(answers[0], list):
            res = []
            for ans, label in zip(answers, labels):
                temp_acc = np.mean(np.equal(ans, label))
                if np.isnan(temp_acc):
                    continue
                res.append(temp_acc)
            return res
        return [np.mean(np.equal(answers, labels))]


# Evaluation helper, not a pytest test; keep pytest from collecting it when imported.
test_prediction_acc.__test__ = False


def test_generation_quality(
        model,
        tok,
        prefixes: typing.List[str],
        max_out_len: int,
        vanilla_generation: bool = False
):
    """Score the diversity of text generated from ``prefixes``.

    One continuation per prefix is produced with ``generate_fast`` (with
    ``max_out_len`` and ``vanilla_generation`` passed through). The
    continuations are then summarised by ``n_gram_entropy`` using its default
    arithmetic mean.

    Returns:
        A dict with the single key ``"ngram_entropy"`` (a Python float).
    """
    continuations = generate_fast(
        model,
        tok,
        prefixes,
        n_gen_per_prompt=1,
        max_out_len=max_out_len,
        vanilla_generation=vanilla_generation,
    )
    entropy = n_gram_entropy(continuations)
    return {"ngram_entropy": entropy}


def test_seq2seq_batch_prediction_acc(model, tok, hparams, prompts, targets, device, locality=False):
    """Teacher-forced token accuracy for an encoder-decoder model.

    ``prompts`` are the encoder input. The tokenized ``targets`` are fed
    unchanged as ``decoder_input_ids`` (with their attention mask as
    ``decoder_attention_mask``).

    Args:
        model: seq2seq model; its output is a logits tensor or has ``.logits``.
        tok: tokenizer used for both prompts and targets.
        hparams: unused; kept for interface compatibility.
        prompts: a prompt string or list of prompts.
        targets: target string(s) aligned with ``prompts``.
        device: an integer (or digit string) is treated as a CUDA index
            (``cuda:{device}``); a ``torch.device`` or a string such as
            ``"cpu"`` is used as given.
        locality: if True, return the argmax token ids instead of accuracies.

    Returns:
        With ``locality``: a list of predicted token-id lists (one per example).
        Otherwise: one accuracy per example. This is the fraction of
        non-padding target tokens after the first that the model predicts from
        the preceding target tokens; it is 0.0 when there is no such token.
    """
    if isinstance(prompts, str):
        prompts, targets = [prompts, ], [targets, ]
    # Integers (or digit strings) are CUDA device indices; anything else is used as given.
    if str(device).isdigit():
        device = f"cuda:{device}"
    prompt_tok = tok(
        prompts,
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(device)

    trg_tok = tok(
        targets,
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(device)

    prompt_tok['decoder_input_ids'] = trg_tok['input_ids']
    prompt_tok['decoder_attention_mask'] = trg_tok['attention_mask']

    with torch.no_grad():
        outputs = model(**prompt_tok)
        logits = outputs if isinstance(outputs, torch.Tensor) else outputs.logits

        assert logits.size(1) == trg_tok['input_ids'].size(1)
        ans = torch.argmax(logits, dim=-1)
        if locality:
            answers = ans.squeeze().detach().cpu().numpy().tolist()
            return answers if type(answers[0]) is list else [answers, ]
        # HF encoder-decoder models consume decoder_input_ids as given (only
        # `labels` is shifted right internally). So logits[:, t] predicts
        # target token t + 1: compare shifted, and ignore padded targets.
        predicted = ans[:, :-1]
        expected = trg_tok['input_ids'][:, 1:]
        valid = trg_tok['attention_mask'][:, 1:].bool()
        hits = ((predicted == expected) & valid).sum(dim=-1).float()
        counts = valid.sum(dim=-1).clamp(min=1).float()
        return (hits / counts).detach().cpu().numpy().tolist()


# Evaluation helper, not a pytest test; keep pytest from collecting it when imported.
test_seq2seq_batch_prediction_acc.__test__ = False


def n_gram_entropy(gen_texts, agg="arith"):
    """Aggregate the per-text n-gram entropy over ``gen_texts``.

    Args:
        gen_texts: iterable of generated strings.
        agg: ``"arith"`` for the arithmetic mean (``np.mean``) or ``"geom"``
            for the geometric mean (``scipy.stats.mstats.gmean``).

    Returns:
        The aggregated entropy as a Python float.
    """
    assert agg in ["arith", "geom"]

    if agg == "geom":
        combine = scipy.stats.mstats.gmean
    else:
        combine = np.mean
    per_text = [compute_n_gram_entropy(text) for text in gen_texts]
    return combine(per_text).item()


def compute_n_gram_entropy(sentence, ns=None, weights=None, agg="arith"):
    """Weighted Shannon entropy (in bits) of the word n-gram distributions of ``sentence``.

    For each ``n`` in ``ns`` (default ``[2, 3]``), the entropy of the n-gram
    frequency distribution is computed and multiplied by the matching entry
    of ``weights`` (default ``[2/3, 4/3]``). The weighted values are then
    combined with the arithmetic (``"arith"``) or geometric (``"geom"``) mean.
    """
    ns = [2, 3] if ns is None else ns
    weights = [2 / 3, 4 / 3] if weights is None else weights
    assert agg in ["arith", "geom"]

    entropies = []
    for n in ns:
        counts = np.array(list(compute_freq(sentence, n).values()))
        probs = counts / counts.sum()
        # log base 2 via change of base, so the entropy is in bits.
        entropies.append(np.sum(-probs * np.log(probs) / np.log(2)))

    weighted = np.array(entropies) * np.array(weights)
    combine = scipy.stats.mstats.gmean if agg == "geom" else np.mean
    return combine(weighted)


def compute_freq(sentence, n=2):
    """Count the word n-grams of ``sentence`` (NLTK word tokenization) in an ``nltk.FreqDist``."""
    return nltk.FreqDist(nltk.ngrams(nltk.word_tokenize(sentence), n))


def PPL(
        model,
        tok,
        prompt: typing.Union[str, typing.List[str]],
        target_new: typing.Union[str, typing.List[str]],
        device,
        **kwargs
):
    """Perplexity of ``target_new`` as a continuation of ``prompt``, pooled over the batch.

    Each example is tokenized as ``f"{prompt} {target}"``. Prompt and padding
    positions are excluded from the loss, so only target tokens are scored.
    Model inputs are truncated to the first 1024 tokens.

    Args:
        model: causal LM; its output is a logits tensor or has ``.logits``.
        tok: tokenizer for ``model``. Both right and left padding
            (``tok.padding_side``) are handled when locating the prompt.
        prompt: a prompt string or a list of prompts.
        target_new: target string(s) aligned with ``prompt``.
        device: device the model inputs are moved to.
        **kwargs: ignored; accepted for interface compatibility.

    Returns:
        ``exp`` of the mean token NLL over all scored tokens in the batch,
        clipped to ``[0, 100]``, as a Python float.
    """
    if isinstance(prompt, str):
        prompt, target_new = [prompt, ], [target_new, ]
    full_prompt = [f"{p} {l}" for p, l in zip(prompt, target_new)]
    prompt_ids = tok(list(prompt), return_tensors="pt", padding=True, truncation=True)["input_ids"]
    num_prompt_toks = [int((ids != tok.pad_token_id).sum()) for ids in prompt_ids]
    tokens = tok(full_prompt, return_tensors="pt", padding=True, truncation=True)
    input_ids = tokens["input_ids"]

    # Only target tokens are scored: mark prompt and padding positions with
    # -100 so they are ignored below. With left padding the prompt starts
    # after the pad tokens, so those are skipped as well.
    labels = input_ids.clone()
    left_padded = getattr(tok, "padding_side", "right") == "left"
    for i, n_prompt in enumerate(num_prompt_toks):
        n_pad = int((input_ids[i] == tok.pad_token_id).sum()) if left_padded else 0
        labels[i, :n_pad + n_prompt] = -100
    labels[input_ids == tok.pad_token_id] = -100

    # Truncate all model inputs consistently, including the attention mask.
    input_ids = input_ids[:, :1024]
    attention_mask = tokens["attention_mask"][:, :1024]
    target_ids = labels[:, :1024]

    with torch.no_grad():
        outputs = model(input_ids=input_ids.to(device), attention_mask=attention_mask.to(device))
        logits = outputs if isinstance(outputs, torch.Tensor) else outputs.logits
        # logits[:, t] is scored against token t + 1.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = target_ids.to(device)[:, 1:].contiguous().unsqueeze(-1)

        neg_log_probs = -nn.functional.log_softmax(shift_logits, dim=-1)
        padding_mask = shift_labels.eq(-100)

        # gather() cannot index with -100, so clamp to 0; those positions are
        # zeroed by padding_mask right after.
        nll_loss = neg_log_probs.gather(dim=-1, index=shift_labels.clamp(min=0))
        nll_loss.masked_fill_(padding_mask, 0.0)

        num_active_elements = padding_mask.numel() - padding_mask.long().sum()
        nll = nll_loss.sum() / num_active_elements
    ppl = torch.exp(nll).clip(0, 100)
    return ppl.cpu().numpy().tolist()


def OOD_PPL(
        model,
        tok,
        prompt: typing.Union[str, typing.List[str]],
        target_new: typing.Union[str, typing.List[str]],
        device,
        threshold=0.8
):
    """Fraction of prompt tokens the model predicts with probability above ``threshold``.

    Each prompt is scored on its own tokens (next-token prediction).
    ``target_new`` is not scored; it only pairs with ``prompt`` through
    ``zip``. Padding positions are excluded from both the hit count and the
    total. Inputs are truncated to the first 1024 tokens.

    Args:
        model: causal LM whose output has ``.logits``.
        tok: tokenizer for ``model``.
        prompt: a prompt string or a list of prompts.
        target_new: string(s) paired with ``prompt``; not scored.
        device: device the model inputs are moved to.
        threshold: probability a token must exceed to count as a hit.

    Returns:
        The hit ratio as a Python float in [0, 1]. Returns 0.0 when there is no
        position to score (e.g. single-token prompts).
    """
    if isinstance(prompt, str):
        prompt, target_new = [prompt, ], [target_new, ]

    full_prompt = [f"{p}" for p, _ in zip(prompt, target_new)]
    tokens = tok(full_prompt, return_tensors="pt", padding=True, truncation=True)

    input_ids = tokens["input_ids"][:, :1024]
    target_ids = input_ids.clone()
    target_ids[input_ids == tok.pad_token_id] = -100

    model_inputs = {"input_ids": input_ids.to(device), "labels": target_ids.to(device)}
    attention_mask = tokens.get("attention_mask")
    if attention_mask is not None:
        # Keep pad tokens from being attended to in padded batches.
        model_inputs["attention_mask"] = attention_mask[:, :1024].to(device)

    with torch.no_grad():
        logits = model(**model_inputs).logits
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = target_ids.to(device)[:, 1:].contiguous()

        neg_log_probs = -nn.functional.log_softmax(shift_logits, dim=-1)
        active = shift_labels.ne(-100)
        # gather() cannot index with -100, so clamp to 0; inactive positions
        # are excluded through `active` below.
        nll = neg_log_probs.gather(dim=-1, index=shift_labels.clamp(min=0).unsqueeze(-1)).squeeze(-1)

        # nll < -log(threshold)  <=>  p(true token) > threshold
        nll_threshold = -np.log(threshold)
        num_active = int(active.sum())
        if num_active == 0:
            return 0.0
        num_confident = int(((nll < nll_threshold) & active).sum())
        return num_confident / num_active


def slice_list(matrix, start_indices, left):
    """Cut the prompt prefix off each row of token ids.

    With ``left=True`` each row becomes ``row[start - 1:-1]``: shifted one
    position earlier and dropping the last element, which aligns next-token
    predictions with their labels. Otherwise each row becomes ``row[start:]``.
    A flat (non-nested) list is treated as a single row and uses
    ``start_indices[0]``.
    """
    shift = -1 if left else 0
    stop = -1 if left else None
    if not isinstance(matrix[0], list):
        return matrix[start_indices[0] + shift:stop]
    return [row[begin + shift:stop] for row, begin in zip(matrix, start_indices)]


def gather_log_probs(logits, labels):
    """Return the log-probability assigned by ``logits`` to each label.

    ``labels`` must have exactly the shape of ``logits`` without its last
    (class) dimension. The result has that same shape.
    """
    assert labels.dim() == logits.dim() - 1
    assert labels.shape == logits.shape[:-1]
    log_probs = torch.log_softmax(logits, dim=-1)
    picked = torch.gather(log_probs, -1, labels.unsqueeze(-1))
    return picked.squeeze(-1)


def masked_mean(values, mask):
    """Mean of ``values`` over the positions where the boolean ``mask`` is True.

    ``values`` and ``mask`` must have the same shape. An all-False mask
    divides by zero.
    """
    assert mask.dtype == torch.bool
    assert values.shape == mask.shape
    selected_total = (values * mask.float()).sum()
    selected_count = mask.sum().float()
    return selected_total / selected_count


def mask_hf_labels(labels, null_token=0):
    """Split labels into a validity mask and a copy with ignored entries replaced.

    Entries equal to -100 are invalid. In the returned labels they are
    replaced by ``null_token``, presumably so the tensor can be used safely
    for indexing or gathering.

    Returns:
        ``(valid_mask, valid_labels)``.
    """
    ignored = labels.eq(-100)
    return ~ignored, labels.masked_fill(ignored, null_token)


def kl_loc_loss(pre, post, mask=None):
    """KL(pre || post) between model outputs before and after an edit.

    Supported layouts:

    * 3-D sequence logits ``(batch, seq, classes)`` with more than one class.
      ``post`` may be longer than ``pre`` along ``seq``; only its last
      ``pre.shape[1]`` positions are compared. ``mask`` is required, and the
      result is the mask-weighted mean of the per-position KL.
    * non-sequence logits whose last dimension is 1 (binary classification).
      The result is the Bernoulli KL from sigmoid probabilities, averaged
      over elements.

    Raises:
        NotImplementedError: for any other combination of rank and last dimension.
    """
    pre = pre.to(torch.float32).contiguous()
    sequence = pre.dim() == 3
    if sequence:
        # Align post with pre on the sequence axis (keep post's trailing positions).
        # Only done for sequences: 2-D inputs have no sequence axis to slice.
        post = post[:, -pre.shape[1]:, :]
    post = post.to(torch.float32).contiguous()

    pre_ = pre.view(-1, pre.shape[-1])
    post_ = post.view(pre_.shape)
    assert pre_.shape[0] == post_.shape[0]

    if not sequence:
        if pre_.shape[-1] == 1:  # No masking needed for binary classification
            return (pre.sigmoid() * (F.logsigmoid(pre) - F.logsigmoid(post))).mean() + (
                    (-pre).sigmoid() * (F.logsigmoid(-pre) - F.logsigmoid(-post))
            ).mean()
    elif pre_.shape[-1] > 1:  # Sequences of predictions; masking needed
        assert mask is not None
        mask_ = mask.view(pre_.shape[0])
        kl = (pre_.softmax(-1) * (pre_.log_softmax(-1) - post_.log_softmax(-1))).sum(-1)
        return (kl * mask_).sum() / mask_.sum()

    raise NotImplementedError


def F1(model, tok, hparams, prompts, targets, device, locality=False, vanilla_generation=True):
    """Macro-averaged F1 between target token ids and the model's predicted token ids.

    Token ids are treated as class labels for ``sklearn.metrics.f1_score``.

    * ``vanilla_generation=True`` (default): greedily generate
      ``len(target tokens)`` new tokens after a single prompt string and
      compare them with the tokenized ``targets`` string.
    * otherwise (teacher forcing): one forward pass over
      ``prompt + ' ' + target``, comparing the argmax prediction for every
      target position with the target token. Tokens of all examples in a
      batch are pooled.

    Args:
        model: causal LM; its forward output is a logits tensor or has ``.logits``.
        tok: tokenizer belonging to ``model``.
        hparams: unused; kept for interface compatibility.
        prompts: a prompt string (required in generation mode) or a list of prompts.
        targets: target string(s) aligned with ``prompts``.
        device: device for the model inputs. In teacher-forcing mode an integer
            (or digit string) is treated as a CUDA index (``cuda:{device}``).
        locality: unused; kept for interface compatibility.
        vanilla_generation: select the generation mode described above.

    Returns:
        The macro F1 score returned by ``f1_score``.
    """
    if vanilla_generation:
        target_new_tokens = tok.encode(' ' + targets)
        if target_new_tokens[0] == tok.pad_token_id or (
                hasattr(tok, 'bos_token_id') and target_new_tokens[0] == tok.bos_token_id):
            # A leading special token was prepended: re-encode without the
            # leading space and drop that first token.
            target_new_tokens = tok.encode(targets)
            target_new_tokens = target_new_tokens[1:]
        prompt_tok = tok(
            prompts,
            return_tensors="pt",
        ).to(device)
        gen_token = model.generate(
            input_ids=prompt_tok['input_ids'],
            attention_mask=prompt_tok['attention_mask'],
            max_new_tokens=len(target_new_tokens),
            pad_token_id=tok.eos_token_id,
            use_cache=False,
        )
        n_target = len(target_new_tokens)
        n_prompt = prompt_tok['input_ids'].shape[1]
        # Assumes a decoder-only model (generate() returns prompt + new tokens).
        # Keep only the new tokens and pad early-stopped output with -100 so
        # missing positions count as wrong predictions.
        new_tokens = gen_token[0].detach().cpu().tolist()[n_prompt:n_prompt + n_target]
        new_tokens += [-100] * (n_target - len(new_tokens))
        return f1_score(target_new_tokens, new_tokens, average='macro')

    if isinstance(prompts, str):
        prompts, targets = [prompts, ], [targets, ]
    # Integers (or digit strings) are CUDA device indices; anything else is used as given.
    if str(device).isdigit():
        device = f"cuda:{device}"
    prompt_target = [prompt + ' ' + target for prompt, target in zip(prompts, targets)]
    prompt_target_tok = tok(
        prompt_target,
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(device)
    prompt_tok = tok(
        prompts,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    num_prompt_toks = [int((i != tok.pad_token_id).sum()) for i in prompt_tok['input_ids']]
    num_pad_toks = [int((i == tok.pad_token_id).sum()) for i in prompt_target_tok['input_ids'].cpu()]
    # Start of the target in each row; only correct for left padding.
    prompt_len = [x + y for x, y in zip(num_pad_toks, num_prompt_toks)]
    with torch.no_grad():
        outputs = model(**prompt_target_tok)
        logits = outputs if isinstance(outputs, torch.Tensor) else outputs.logits
        answers = torch.argmax(logits, dim=-1).squeeze().detach().cpu().numpy().tolist()
        labels = prompt_target_tok['input_ids'].squeeze().detach().cpu().numpy().tolist()
        answers = slice_list(answers, prompt_len, left=True)
        labels = slice_list(labels, prompt_len, left=False)

        if answers and isinstance(answers[0], list):
            # f1_score needs flat 1-D label sequences; pool all examples' tokens.
            answers = [t for row in answers for t in row]
            labels = [t for row in labels for t in row]
        return f1_score(labels, answers, average='macro')


def find_sublist_start_indexes(list1, list2):
    """Return every index at which ``list2`` occurs as a contiguous run inside ``list1``.

    Overlapping occurrences are all reported. An empty ``list2`` matches at
    every index from 0 to ``len(list1)`` inclusive.
    """
    width = len(list2)
    return [
        start
        for start in range(len(list1) - width + 1)
        if all(x == y for x, y in zip(list1[start:start + width], list2))
    ]

