import torch
import transformers
import numpy as np
from torch.utils.data import DataLoader

from tqdm import tqdm

from llama3 import model, tokenizer, EOS
from data import get_loaders

@torch.no_grad()
def get_cache(model: transformers.PreTrainedModel, train_loader: DataLoader):
    """Collect (hidden state, next token id) pairs over a data loader.

    Every batch's ``"Text"`` entries get ``EOS`` appended, are tokenized with
    the module-level ``tokenizer`` and run through ``model``. For each
    non-padding position that is followed by another non-padding position,
    the vector at that position from the last entry of
    ``outputs.hidden_states`` is stored as a key and the id of the following
    token as its value.

    Args:
        model: Model called with ``output_hidden_states=True``; it is put
            into eval mode and gradients are disabled by the decorator.
        train_loader: Iterable of batches, each indexable by ``"Text"`` and
            yielding a sequence of strings.

    Returns:
        A ``(keys, vals)`` tuple of NumPy arrays: ``keys`` has one row per
        cached position, ``vals`` holds the matching next-token ids in the
        same order.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    key_chunks = []
    val_chunks = []

    model.eval()
    for entry in tqdm(train_loader, desc="Generating...", leave=True):
        # Build a new list: rebinding a loop variable would leave the batch
        # unchanged and EOS would never reach the tokenizer.
        texts = [text + EOS for text in entry["Text"]]

        inputs = tokenizer(texts, return_tensors="pt",
                           padding=True).to(model.device)

        outputs = model(**inputs, output_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]

        # Keep position i only when both i and i + 1 are real tokens. Using
        # the mask (rather than a length counted from index 0) gives the
        # right positions whichever side the tokenizer pads on.
        token_mask = inputs.attention_mask.bool()
        pair_mask = (token_mask[:, :-1] & token_mask[:, 1:]).cpu()

        # NumPy has no bfloat16 dtype, so Tensor.numpy() would raise on it.
        if hidden_states.dtype == torch.bfloat16:
            hidden_states = hidden_states.float()

        hidden_states = hidden_states.cpu()
        input_ids = inputs.input_ids.cpu()

        # Boolean indexing walks the batch row by row, so the pairs stay in
        # sequence order, then position order.
        key_chunks.append(hidden_states[:, :-1][pair_mask].numpy())
        val_chunks.append(input_ids[:, 1:][pair_mask].numpy())

    print(sum(len(chunk) for chunk in val_chunks))

    if not key_chunks:
        raise ValueError("train_loader yielded no batches; nothing to cache")

    return np.concatenate(key_chunks), np.concatenate(val_chunks)

def main():
    """Build the cache from the training loader and save it as .npy files.

    Prints the shape of the key array, then writes the keys to ``keys.npy``
    and the values to ``vals.npy`` in the current working directory.
    """
    # Only the first loader returned by get_loaders is used here.
    train_split, _valid_split = get_loaders(batch_size=4)

    cache_keys, cache_vals = get_cache(model, train_split)
    print(cache_keys.shape)

    for path, array in (("keys.npy", cache_keys), ("vals.npy", cache_vals)):
        np.save(path, array)

# Run the cache generation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()