from typing import Any, Dict, List, Tuple
import torch
import torch.nn.functional as F
import torch.nn as nn

from copy import deepcopy
from transformers import AutoModelForCausalLM, AutoTokenizer
from .ice_hparams import ICEHyperParams
import collections

# Process-wide edit memory shared by every ICERewriteExecutor instance.
# Maps the stored text "<prompt> <target_new>" to its sentence embedding,
# in insertion order.
memories = collections.OrderedDict()


class ICERewriteExecutor:
    """Store edit requests as sentence embeddings and retrieve the closest one.

    Edits are not applied to model weights: ``apply_ice_to_model`` only
    records the request text and its embedding in the module-level
    ``memories`` dict, and ``infer_eval`` returns the stored text most
    similar to a query.  Because ``memories`` lives at module level, all
    instances in the process share the same memory.
    """

    def __init__(self, embedding_tok, embedding_model):
        """Keep the tokenizer and encoder used to embed text.

        Args:
            embedding_tok: Callable tokenizer; invoked with ``padding=True,
                truncation=True, return_tensors='pt'`` and expected to return
                a mapping that contains ``'attention_mask'``.
            embedding_model: Callable encoder; invoked with the tokenizer
                output as keyword arguments.
        """
        self.embedding_tok = embedding_tok
        self.embedding_model = embedding_model

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings along axis 1, ignoring masked positions.

        Args:
            model_output: Encoder output; element 0 holds the per-token
                embeddings (assumed ``(batch, seq, dim)`` -- TODO confirm).
            attention_mask: Mask with one entry per token; zero entries are
                excluded from the average.

        Returns:
            One pooled embedding per input row.
        """
        token_embeddings = model_output[0]
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * mask, 1)
        # Clamp the token count so a fully masked row cannot divide by zero.
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    def get_embedding(self, query):
        """Return the L2-normalised, mean-pooled embedding of ``query``.

        NOTE(review): the tokenizer output is passed to the encoder without
        any device transfer; presumably the encoder is expected to live on
        the same device the tokenizer produces tensors on (CPU) -- confirm.
        """
        encoded_input = self.embedding_tok(
            query, padding=True, truncation=True, return_tensors='pt'
        )

        # Inference only: no gradients are needed for retrieval.
        with torch.no_grad():
            model_output = self.embedding_model(**encoded_input)

        pooled = self.mean_pooling(model_output, encoded_input['attention_mask'])
        return F.normalize(pooled, p=2, dim=1)

    def search(self, embeddings):
        """Return the stored text whose embedding is most similar to ``embeddings``.

        Args:
            embeddings: Query embedding as produced by ``get_embedding``
                (a single row is broadcast against every stored embedding).

        Returns:
            The key of ``memories`` with the highest cosine similarity.

        Raises:
            RuntimeError: If nothing has been stored in memory yet.
        """
        if not memories:
            # Fail with an actionable message instead of the opaque error
            # torch.cat raises for an empty list of tensors.
            raise RuntimeError(
                'ICE memory is empty: call apply_ice_to_model before searching'
            )

        keys = list(memories)
        memory_embs = torch.cat([memories[key] for key in keys], dim=0)

        # Stay in torch so the lookup does not depend on a NumPy round-trip.
        scores = F.cosine_similarity(embeddings, memory_embs, dim=1, eps=1e-6)
        return keys[int(torch.argmax(scores))]

    def roll_back_memory(self, num=1):
        """Remove ``num`` entries from memory, oldest first.

        NOTE(review): entries are removed in insertion order (oldest first),
        not most-recent first -- confirm this is the intended rollback order.

        Args:
            num: Number of entries to remove; values <= 0 remove nothing.

        Raises:
            AssertionError: If memory holds fewer than ``num`` entries.  The
                check runs before anything is removed, so memory is left
                untouched on failure.
        """
        if len(memories) < num:
            # Raised explicitly (not via ``assert``) so the check survives
            # ``python -O`` while keeping the exception type unchanged.
            raise AssertionError(
                f'cannot roll back {num} entries: memory only holds {len(memories)}'
            )
        for _ in range(num):
            memories.popitem(last=False)

    # Annotations are string forward references, so they are never evaluated
    # when the class body is executed.
    def apply_ice_to_model(
            self,
            model: "AutoModelForCausalLM",
            tok: "AutoTokenizer",
            requests: Dict,
            hparams: "ICEHyperParams",
            copy=False,
            **kwargs: Any,
    ) -> "Tuple[AutoModelForCausalLM, Dict[str, Any]]":
        """Record an edit request in memory; the model weights are not modified.

        Args:
            model: Model to return (deep-copied first when ``copy`` is true).
            tok: Unused here.
            requests: A single request mapping with ``'prompt'`` and
                ``'target_new'`` string entries.
            hparams: Unused here.
            copy: When true, return a deep copy of ``model`` instead of
                ``model`` itself.
            **kwargs: Ignored.

        Returns:
            ``(model, None)`` -- the second element is always ``None``.
        """
        if copy:
            model = deepcopy(model)

        query = requests['prompt'] + ' ' + requests['target_new']
        # Re-storing an existing key overwrites its embedding but keeps its
        # original position in the ordered memory.
        memories[query] = self.get_embedding(query)

        return model, None

    def infer_eval(self, query):
        """Embed ``query`` and return the most similar stored edit text."""
        return self.search(self.get_embedding(query))


