import torch
from transformers import AutoTokenizer, AutoModel
from typing import Union, List
import numpy as np


def extract_feature_set(quantile: float, cosine_similarities, gen_desc):
    """Extract feature set where all features are at the given quantile distance from the original descriptions.

    quantile=0. would fetch the MOST DISSIMILAR feature description for each dataset (most dissimilar to the
    original feature descriptions; lowest cosine similarity); quantile=1. would fetch MOST SIMILAR.
    Intermediate quantiles pick the rank ``int(quantile * (n - 1))`` in ascending-similarity order
    (i.e. truncated toward the less similar side).

    Args:
        quantile: Position in [0, 1] along the ascending cosine-similarity ordering.
        cosine_similarities: Mapping from key to a 1-D sequence of cosine similarities. Entry ``i`` is
            assumed to score ``gen_desc[key][i]`` (aligned by index — TODO confirm with caller).
        gen_desc: Mapping from key to an indexable sequence of candidate descriptions.

    Returns:
        Dict mapping every key of ``cosine_similarities`` to the selected description.

    Raises:
        AssertionError: If ``quantile`` is outside [0, 1].
        ValueError: If a similarity sequence is empty.
    """
    assert 0.0 <= quantile <= 1.0
    out = {}
    for k, cos_sims in cosine_similarities.items():
        sims = np.asarray(cos_sims)
        if sims.size == 0:
            raise ValueError(f"No cosine similarities for key {k!r}")
        # np.argsort sorts in ASCENDING order: order[r] is the index of the element with rank r,
        # so rank 0 is the lowest similarity and rank n-1 the highest.
        order = np.argsort(sims)
        quant_rank = int(quantile * (len(order) - 1))  # quantized rank to fetch
        # Map rank -> original index (not index -> rank), then fetch that description.
        idx = int(order[quant_rank])
        out[k] = gen_desc[k][idx]
    return out


class TextEncoder:
    """Encode text into L2-normalized sentence embeddings with a Hugging Face encoder model.

    Sentence embeddings are taken from the first token position of the model's first output
    (CLS pooling) and then L2-normalized, so dot products between them are cosine similarities.
    """

    # Default checkpoint. At the time of writing this was the smallest model in the top 10 of the
    # MTEB leaderboard: https://huggingface.co/spaces/mteb/leaderboard
    DEFAULT_MODEL_NAME = "khoa-klaytn/bge-base-en-v1.5-angle"

    def __init__(self, model_name: str = DEFAULT_MODEL_NAME):
        """Load the tokenizer and model.

        Args:
            model_name: Hugging Face Hub id or local path. The tokenizer and the model are both
                loaded from this single name so they always match.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def encode_text(self, x: Union[str, List[str]]) -> torch.Tensor:
        """Embed one string or a list of strings.

        Args:
            x: A single string or a list of strings. Inputs are padded to a common length and
                truncated to the tokenizer's maximum length.

        Returns:
            A 2-D tensor with one L2-normalized row per input (a single string yields one row).
            The row width is the model's hidden size.
        """
        encoded_input = self.tokenizer(
            x, padding=True, truncation=True, return_tensors="pt"
        )

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            model_output = self.model(**encoded_input)
            # CLS pooling: first token of the first output (presumably the last hidden
            # state, shape (batch, seq_len, hidden)).
            sentence_embeddings = model_output[0][:, 0]
        # L2-normalize each row so dot products are cosine similarities.
        sentence_embeddings = torch.nn.functional.normalize(
            sentence_embeddings, p=2, dim=1
        )
        return sentence_embeddings


# Alternative encoder using DistilBERT (commented out; not used by the code above).
# model = DistilBertModel.from_pretrained("distilbert-base-uncased")
# tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# def encode_text(x: Union[str, List[str]]):
#     """Encode text as a Tensor of shape [batch_size, hidden_size] (the [CLS] embedding of each input)."""
#     tokenized = tokenizer(x, return_tensors='pt', truncation=True)
#     # enc has shape [batch_size, seq_len, 768]
#     enc = model(**tokenized)['last_hidden_state']
#     # the first token is the [CLS] token; take that one only
#     enc = enc[:, 0, :]
#     return enc
