import torch
import torch.nn.functional as F
import argparse
from torch import Tensor
from transformers import AutoTokenizer, AutoModel, set_seed
from datasets import load_from_disk
import tqdm
# Dataset directories (as written by `datasets.save_to_disk`) this script accepts.
DATASET_CHOICES = [
    "imdb_goodbad_allpara",
    "imdb_goodbad_withpara_final",
    "imdb_horribleincredible_allpara_final",
]

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset",
    choices=DATASET_CHOICES,
    default="imdb_goodbad_withpara_final",
)
parser.add_argument("--seed", type=int, default=0)
args = parser.parse_args()

# Seed the RNGs through the transformers `set_seed` helper.
set_seed(args.seed)

# Hub id of the embedding model; rebound to the loaded model further down.
# Alternative previously tried: 'intfloat/e5-mistral-7b-instruct'.
model = "Salesforce/SFR-Embedding-Mistral"

def last_token_pool(last_hidden_states: Tensor,
                    attention_mask: Tensor) -> Tensor:
    """Return, for each sequence, the hidden state of its last real token.

    Args:
        last_hidden_states: per-token hidden states indexed as
            ``[batch, position, ...]`` (presumably ``(batch, seq_len, hidden)``
            -- TODO confirm against the model output).
        attention_mask: 2-D ``(batch, seq_len)`` mask, presumably 1 for real
            tokens and 0 for padding.

    Returns:
        One row of ``last_hidden_states`` per sequence.

    When every row's mask ends in a real token, the batch is treated as
    left-padded (or unpadded) and the final position is used for all rows.
    Otherwise the batch is assumed to be right-padded, and row ``i`` is read
    at index ``attention_mask[i].sum() - 1``.
    """
    n_sequences = attention_mask.shape[0]
    ends_with_real_token = attention_mask[:, -1].sum() == n_sequences
    if ends_with_real_token:
        return last_hidden_states[:, -1]

    # Right padding: the last real token sits just before the first pad.
    last_positions = attention_mask.sum(dim=1) - 1
    row_ids = torch.arange(last_hidden_states.shape[0],
                           device=last_hidden_states.device)
    return last_hidden_states[row_ids, last_positions]


# `model` holds the hub id string at this point; keep it under a clearer name,
# then rebind `model` to the loaded network (used by the embedding loop below).
model_name = model
tokenizer = AutoTokenizer.from_pretrained(model_name)

# bf16 weights with FlashAttention-2 kernels, moved onto the GPU.
model = AutoModel.from_pretrained(
    model_name,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
).to('cuda')

dsdict = load_from_disk(args.dataset)

# Text columns to embed. 'text' is required; 'Tpara1'/'Spara1' are optional
# and are skipped for splits that do not have them.
TEXT_COLUMNS = ('text', 'Tpara1', 'Spara1')
OPTIONAL_COLUMNS = ('Tpara1', 'Spara1')
BATCH_SIZE = 512
# Token budget per input, including the EOS token appended after tokenizing.
MAX_LENGTH = 512

for dsk in dsdict:
    # Only splits whose name starts with 'S' or 'T' are embedded.
    if not dsk.startswith(('S', 'T')):
        continue
    ds = dsdict[dsk]
    print(ds)

    # Collect each column's strings in row order.
    queries = {col: [] for col in TEXT_COLUMNS}
    for example in ds:
        queries['text'].append(example['text'])
        for col in OPTIONAL_COLUMNS:
            if col in example:
                queries[col].append(example[col])

    all_embeds = {col: [] for col in TEXT_COLUMNS}
    with torch.no_grad():
        for col, texts in queries.items():
            if not texts:
                continue
            embeds = []
            # Stride through the texts so the number of batches is
            # ceil(len / BATCH_SIZE). The old `len // BATCH_SIZE + 1` produced
            # an empty trailing batch whenever len was a multiple of
            # BATCH_SIZE, which crashed in tokenizer.pad / `.to('cuda')`.
            for start in tqdm.tqdm(range(0, len(texts), BATCH_SIZE)):
                input_texts = texts[start:start + BATCH_SIZE]

                # Tokenize unpadded, reserving one slot for the EOS token.
                batch_dict = tokenizer(input_texts, max_length=MAX_LENGTH - 1,
                                       return_attention_mask=False,
                                       padding=False, truncation=True)
                # Append EOS so the pooled "last token" is always EOS.
                batch_dict['input_ids'] = [ids + [tokenizer.eos_token_id]
                                           for ids in batch_dict['input_ids']]
                batch_dict = tokenizer.pad(batch_dict, padding=True,
                                           return_attention_mask=True,
                                           return_tensors='pt')
                batch_dict = {name: tensor.to('cuda')
                              for name, tensor in batch_dict.items()}

                outputs = model(**batch_dict)
                embeddings = last_token_pool(
                    outputs.last_hidden_state, batch_dict['attention_mask']
                ).to(torch.float32).cpu()
                # L2-normalize each embedding row.
                embeddings = F.normalize(embeddings, p=2, dim=1)
                # One list of Python floats per input, ready for add_column.
                embeds.extend(embeddings.tolist())
            all_embeds[col] = embeds

    # Attach a "<column>_embeds" column for every column that was embedded.
    for col, embeds in all_embeds.items():
        if embeds:
            ds = ds.add_column(f"{col}_embeds", embeds)
    dsdict[dsk] = ds

# Write all splits (embedded and untouched) to the input path plus a
# model-specific suffix.
output_dir = args.dataset + "_sfrmistral"
dsdict.save_to_disk(output_dir)
