import torch
import transformers
import numpy as np
from torch.utils.data import DataLoader

from tqdm import tqdm
from math import e

from llama3 import model, tokenizer, EOS
from data import get_loaders

import faiss

class rtd:
    """Per-head nearest-neighbour datastore used to adjust a next-token distribution.

    The key vectors are split along their last axis into ``heads`` equal
    slices of ``dim // heads`` features, and each slice gets its own FAISS
    flat L2 index.  ``vals`` is looked up with the neighbour positions FAISS
    returns, and the looked-up entries are used as column indices into the
    probability rows handed to :meth:`search`.

    NOTE(review): if ``dim`` is not a multiple of ``heads`` the trailing
    ``dim % heads`` features are never indexed -- confirm this is intended.
    """

    def __init__(self, keys, vals, use_gpu=True, heads=32) -> None:
        """Build one FAISS index per head and load the key slices into them.

        Args:
            keys: NumPy array whose last axis is the feature dimension.
            vals: Array indexed by datastore position; presumably token ids
                (they are used as vocabulary columns in ``search``).
            use_gpu: Build ``GpuIndexFlatL2`` indexes instead of CPU ones.
            heads: Number of feature slices / indexes.
        """
        self.val = vals
        self.dim = keys.shape[-1]
        self.heads = heads
        self.head_dim = self.dim // heads
        if use_gpu:
            # GPU resources only exist in GPU builds of FAISS, so they must
            # not be touched on the CPU path.
            self.res = faiss.StandardGpuResources()
            self.key = [faiss.GpuIndexFlatL2(self.res, self.head_dim) for _ in tqdm(range(self.heads),
                desc="Allocating GPU index libraries")]
        else:
            self.res = None
            self.key = [faiss.IndexFlatL2(self.head_dim) for _ in range(self.heads)]

        for i in tqdm(range(self.heads), desc="Adding Indexes", leave=False):
            lower = i * self.head_dim
            upper = (i + 1) * self.head_dim
            # FAISS expects C-contiguous float32 rows.
            self.key[i].add(np.ascontiguousarray(keys[..., lower:upper], dtype=np.float32))

    def search(self, hs, logits_base, k: int, lambd: float, tempreture: float):
        """Mix retrieved-neighbour probabilities into ``logits_base``.

        Args:
            hs: 2-D tensor of query vectors, one row per prediction position.
            logits_base: 2-D tensor of base probabilities, one row per query.
                It is not modified.
            k: Number of neighbours fetched from every head's index.
            lambd: Interpolation weight given to the retrieved distribution.
            tempreture: Divisor applied to the distances before the softmax.

        Returns:
            ``(mixed, retrieved)`` where ``retrieved`` is the head-averaged
            neighbour distribution and
            ``mixed = (1 - lambd) * logits_base + lambd * retrieved``.
            Both share the device and dtype of ``logits_base``.
        """
        # The retrieved distribution is built separately from the base
        # probabilities, so it is independent of ``lambd`` (no division by
        # ``lambd``, hence no NaN at ``lambd == 0``) and cannot alias them.
        # Accumulation happens in float32 on the CPU, where FAISS results live.
        retrieved = torch.zeros(tuple(logits_base.shape), dtype=torch.float32)
        head_weight = 1.0 / self.heads
        stored_vals = np.asarray(self.val)

        for i in tqdm(range(self.heads), leave=False):
            lower = i * self.head_dim
            upper = (i + 1) * self.head_dim
            # Cast to float32 before leaving torch: NumPy has no bfloat16.
            query = hs[..., lower:upper].detach().float().cpu().numpy()
            dis, ind = self.key[i].search(np.ascontiguousarray(query, dtype=np.float32), k)

            # Flat L2 indexes report distances (smaller = closer), so they are
            # negated to give the nearest neighbours the largest weight.
            weights = torch.nn.functional.softmax(
                -torch.as_tensor(dis, dtype=torch.float32) / tempreture, dim=-1)

            targets = torch.from_numpy(
                stored_vals[ind].reshape(ind.shape).astype(np.int64))
            # scatter_add_ accumulates repeated targets within a row, exactly
            # like adding the neighbours one at a time.
            retrieved.scatter_add_(-1, targets, weights * head_weight)

        retrieved = retrieved.to(device=logits_base.device, dtype=logits_base.dtype)
        mixed = logits_base * (1 - lambd) + retrieved * lambd
        return mixed, retrieved

# Module-level store filled by `evaluate`: one retrieval result per validation
# batch, appended in loader order.  While it is empty `evaluate` runs the
# retrieval search; once non-empty its entries are reused by batch index.
cached = []

@torch.no_grad()
def evaluate(model: transformers.PreTrainedModel, valid_loader: DataLoader, rtd_lib: rtd, 
             k: int, lambd: float, tempreture: float):
    """Compare language-model loss before and after retrieval interpolation.

    Every batch is scored twice: once with the model's own loss and once
    with the next-token probabilities mixed with the distribution returned
    by ``rtd_lib.search``.  The retrieved distributions are stored in the
    module-level ``cached`` list on the first run and reused (re-weighted
    with the current ``lambd``) on later runs, which therefore skip the
    search.

    NOTE(review): reuse looks entries up by batch index, so it assumes the
    loader yields the same batches in the same order on every call, and it
    ignores changes to ``k`` and ``tempreture`` -- confirm with callers.

    Args:
        model: Causal language model to evaluate.
        valid_loader: Loader whose batches expose a ``"Text"`` list of strings.
        rtd_lib: Retrieval datastore queried with the last hidden states.
        k: Number of neighbours per query.
        lambd: Weight of the retrieved distribution in the mixture.
        tempreture: Softmax temperature forwarded to ``rtd_lib.search``.

    Returns:
        ``(base_loss, base_ppl, mixed_loss, mixed_ppl)``, each averaged over
        the batches of ``valid_loader``.

    Raises:
        ValueError: If ``valid_loader`` yields no batches.
    """
    model.eval()

    loss_func = torch.nn.NLLLoss()
    preloss, posloss, preppl, posppl, batches = 0., 0., 0., 0., 0

    # Decided once up front so that entries appended during this run are not
    # mistaken for a pre-existing cache.
    use_cache = bool(cached)

    for index, entry in tqdm(enumerate(valid_loader), 
                             desc="Evaluating...", 
                             total=len(valid_loader), 
                             leave=True):

        # Build a new list: rebinding the loop variable would leave the
        # original strings without their EOS marker.
        texts = [text + EOS for text in entry["Text"]]
        batches += 1

        inputs = tokenizer(texts, return_tensors="pt", 
                           padding=True).to(model.device)
        attention = inputs.attention_mask.bool()

        # -100 is the label value the model's loss ignores; without it the
        # padding positions would be scored as if they were real targets.
        labels = inputs.input_ids.masked_fill(~attention, -100)

        outputs = model(**inputs, output_hidden_states=True, labels=labels)

        preloss += outputs.loss
        preppl += torch.exp(outputs.loss)

        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)

        # Position t predicts token t + 1; keep the pairs where both are real
        # tokens.  Using the mask (rather than lengths) works for left- and
        # right-padded batches alike.  Row-major boolean indexing yields the
        # rows in the same order as concatenating the sequences one by one.
        attention = attention.cpu()
        valid = attention[:, :-1] & attention[:, 1:]
        hs = outputs.hidden_states[-1].cpu()[:, :-1][valid]
        lg = probs.cpu()[:, :-1][valid]
        gd = inputs.input_ids.cpu()[:, 1:][valid]

        if use_cache:
            lg = (1 - lambd) * lg + (lambd) * cached[index]
        else:
            lg, cached_diff = rtd_lib.search(hs, lg, k, lambd, tempreture)
            cached.append(cached_diff)

        loss = loss_func(torch.log(lg), gd)
        posloss += loss
        posppl += torch.exp(loss)

    if batches == 0:
        raise ValueError("valid_loader yielded no batches to evaluate")

    # Each accumulated term is already a per-batch mean, so the divisor is
    # the number of batches, not the number of texts.
    return preloss / batches, preppl / batches, posloss / batches, posppl / batches

@torch.no_grad()
def generate(prompt: str, rtd_lib: rtd, 
             k: int, lambd: float, tempreture: float,
             max_new_tokens: int = 128) -> str:
    """Greedily decode a continuation of ``prompt`` with retrieval mixing.

    At each step the model's next-token probabilities are, unless
    ``lambd == 0``, interpolated with the distribution returned by
    ``rtd_lib.search`` for the last hidden state; the arg-max token is then
    emitted.  Decoding stops at the tokenizer's EOS token or after
    ``max_new_tokens`` tokens.

    Args:
        prompt: Text to continue.
        rtd_lib: Retrieval datastore queried at every step.
        k: Number of neighbours per query.
        lambd: Weight of the retrieved distribution; ``0`` skips retrieval.
        tempreture: Softmax temperature forwarded to ``rtd_lib.search``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation (also printed to stdout).
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # use_cache is requested explicitly so past_key_values is populated
    # regardless of the model config's default.
    outputs = model(**inputs, output_hidden_states=True, use_cache=True)

    response = []
    while len(response) < max_new_tokens:
        next_token_probs = torch.nn.functional.softmax(outputs.logits[..., -1, :], dim=-1)
        if lambd != 0:
            last_hidden_state = outputs.hidden_states[-1][..., -1, :]
            next_token_probs, _ = rtd_lib.search(last_hidden_state, next_token_probs,
                                                 k, lambd, tempreture)

        next_token = next_token_probs.argmax(dim=-1, keepdim=True)
        token_id = next_token.item()
        if token_id == tokenizer.eos_token_id:
            break
        response.append(token_id)

        if len(response) == max_new_tokens:
            # Budget spent: skip the forward pass whose output would go unused.
            break

        # Only the new token is fed; earlier context comes from the KV cache.
        outputs = model(next_token, past_key_values=outputs.past_key_values,
                        output_hidden_states=True, use_cache=True)

    text = tokenizer.decode(response)
    print(text)
    return text

        

def main():
    """Load the stored datastore arrays and print the validation metrics.

    Reads ``vals.npy`` and ``keys.npy`` from the working directory, builds a
    CPU datastore with 32 heads and runs ``evaluate`` on the validation
    loader with 32 neighbours, an interpolation weight of 0.85 and a
    temperature of 5.
    """
    stored_vals = np.load("vals.npy")
    stored_keys = np.load("keys.npy")
    datastore = rtd(stored_keys, stored_vals, False, 32)

    # Only the second loader returned by get_loaders is needed here.
    _, validloader = get_loaders(4)

    metrics = evaluate(model, validloader, datastore, 32, .85, 5)
    print(metrics)
    

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()