import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Model hyper-parameters and run settings.
# NOTE(review): in the portion of the file reviewed, only `max_length` is read
# (by get_att and get_ans); the other names may be used by importing modules.
dimension = 4096                            # Attention dimension
attention_heads = 32                        # Number of attention heads
head_size = dimension // attention_heads    # Number of dimensions per attention head (4096 // 32 = 128)
# Longest prompt, in tokens, that get_att / get_ans will process.
max_length = 4096
model_name = "LLaMA3-8B"
# NOTE(review): presumably a sampling/softmax temperature; the identifier is
# misspelled but is kept as-is because other modules may reference it.
tempreture = 100

# Module-level setup: executed on import, not only when run as a script.
device = torch.device('cuda')
# NOTE(review): the leading "/" makes this look like an absolute local path
# rather than a Hub repository id — confirm against the deployment layout.
tokenizer = AutoTokenizer.from_pretrained("/meta-llama/Meta-Llama-3-8B")
# Weights are cast to half precision and then moved to the CUDA device.
model = AutoModelForCausalLM.from_pretrained("/meta-llama/Meta-Llama-3-8B").half().to(device)
# Switch the module to evaluation mode.
model.eval()

@torch.no_grad()
def get_att(prompt: str) -> list:
    """Return the base model's output vector at the last position of *prompt*.

    The prompt is tokenized and fed to ``model.model`` (the inner model, not
    the causal-LM wrapper). From the returned tuple the first element is
    taken, then batch item 0, then the final sequence position, and that
    vector is converted to a plain Python list.
    NOTE(review): presumably this is the final hidden state of the last
    token — confirm against the model's output layout.

    Args:
        prompt: Text to encode.

    Returns:
        The selected vector as a list of floats, or ``[]`` when the prompt
        tokenizes to more than ``max_length`` tokens.
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)

    token_count = input_ids.shape[-1]
    if token_count > max_length:
        return []

    outputs = model.model(input_ids=input_ids, return_dict=False)
    # outputs[0] -> first returned tensor; [0] -> only batch item;
    # [-1] -> last sequence position.
    last_position = outputs[0][0][-1]
    return last_position.cpu().tolist()

# Maps a predicted token id to a zero-based answer option index (0..6).
# The entries have the same "id: index" shape that the vocabulary scan in the
# __main__ block prints for the tokens 'A'/' A'/'A.' (0) through 'G'/' G' (6).
# NOTE(review): presumably the first row holds the bare letters and the second
# row the space-prefixed variants for this tokenizer — re-run the scan to
# confirm, and regenerate the table if the tokenizer changes.
choices = {32: 0, 33: 1, 34: 2, 35: 3, 36: 4, 37: 5, 38: 6, 
           356: 2, 362: 0, 423: 3, 426: 1, 435: 5, 469: 4, 480: 6,}

@torch.no_grad()
def get_length(prompt: str) -> int:
    """Return how many token ids the tokenizer produces for *prompt*.

    Args:
        prompt: Text to encode.

    Returns:
        The size of the last dimension of the encoded ``input_ids`` tensor.
    """
    token_ids = tokenizer(prompt, return_tensors="pt").input_ids
    return token_ids.shape[-1]

@torch.no_grad()
def get_ans(prompt: str, full_generate=False):
    """Predict the answer option the model picks for a multiple-choice prompt.

    The next token after *prompt* is obtained either through
    ``model.generate`` (``full_generate=True``) or by taking the argmax of the
    logits at the last position of a single forward pass (default). The token
    id is then translated to an option index through ``choices``.

    Args:
        prompt: The full prompt text, ending where the answer letter is
            expected.
        full_generate: When true, use ``model.generate`` (and therefore the
            model's generation settings) instead of a plain argmax.

    Returns:
        A tuple ``(answer, human)``:
          * ``(index, 0)`` when the predicted token id is a key of
            ``choices``; ``index`` is the mapped option index.
          * ``(-1, 1)`` when the predicted token id is not in ``choices``.
          * ``(0, -1)`` when the prompt is longer than ``max_length`` tokens;
            the model is not run in that case.
        NOTE(review): ``human`` presumably flags samples needing manual
        inspection — confirm with the callers.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    base_length = inputs.input_ids.shape[-1]

    if base_length > max_length:
        return 0, -1

    input_ids = inputs.input_ids.to(device)

    if full_generate:
        # Only the first newly generated token is consumed below, so request
        # exactly one. The previous max_length=4096 kept generating tokens
        # that were thrown away and left no room for a new token when the
        # prompt was already 4096 tokens long. `return_dict` is a forward()
        # argument, not a generate() option, so it is not passed here;
        # generate() returns the sequences tensor by default.
        generated = model.generate(input_ids, max_new_tokens=1)
        # generated[0] is prompt + continuation; skip the prompt tokens.
        option = generated[0][base_length:].cpu().tolist()[0]
    else:
        # First tuple element holds the logits; take batch item 0, last
        # position, and pick the highest-scoring token id.
        logits = model(input_ids, return_dict=False)[0]
        option = logits[0][-1].cpu().argmax().item()

    answer = choices.get(option)
    if answer is None:
        # Predicted token is not one of the known option tokens.
        return -1, 1

    return answer, 0

if __name__ == "__main__":
    # Manual smoke test: generate from a sample prompt, exercise the helpers,
    # show what the `choices` token ids decode to, and scan the vocabulary
    # for option-letter tokens.

    # `text` is only referenced by the commented-out snippet further down.
    text = '''What's the right answer? Output A B C or D.
A. The wrong answer.
B. The right answer.
C. The wrong answer.
D. The wrong answer.
Answer:
'''
    text2 = '''Choose the best answer of the question. Output A, B, C, etc.
Question: What's the right answer?
Options:
A. The wrong answer.
B. The right answer.
C. The wrong answer.
D. The wrong answer.
The answer is:'''

    # Debug snippet kept for reference: per-position argmax over the logits.
    # input_ids = tokenizer(text, return_tensors="pt")["input_ids"]
    # print(input_ids)
    # output = model(input_ids.to(device))["logits"][0].cpu()
    # o_tokens = output.argmax(dim=-1)
    # print(o_tokens.shape)
    # print(tokenizer.decode(o_tokens))

    # Generate with the model's default settings and print the decoded text.
    prompt_ids = tokenizer(text2, return_tensors="pt").input_ids.to("cuda")
    generated = model.generate(prompt_ids)
    print(tokenizer.decode(generated[0].tolist()))

    print(get_ans(text2))
    print(len(get_att(text2)))

    # Show which string each token id in `choices` decodes to.
    for token_id, option_index in choices.items():
        decoded = tokenizer.decode(token_id)
        print("Key: {} Val: {} Chr: {}".format(token_id, option_index, decoded))

    # Scan the first 60000 token ids and print, in dict-literal form, every
    # id that decodes to an option letter. A token equals at most one of the
    # candidate strings, so the branches are mutually exclusive.
    for i in range(60000):
        token = tokenizer.decode([i])
        if token in ('A', ' A', 'A.'):
            print(f"{i}: 0, ", end='')
        elif token in ('B', ' B'):
            print(f"{i}: 1, ", end='')
        elif token in ('C', ' C'):
            print(f"{i}: 2, ", end='')
        elif token in ('D', ' D'):
            print(f"{i}: 3, ", end='')
        elif token in ('E', ' E'):
            print(f"{i}: 4, ", end='')
        elif token in ('F', ' F'):
            print(f"{i}: 5, ", end='')
        elif token in ('G', ' G'):
            print(f"{i}: 6, ", end='')
    print()

    print(model)