import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# --- Model geometry ----------------------------------------------------------
# Attention dimension (total width shared by all heads).
dimension = 4096
# How many attention heads to use.
attention_heads = 32
# How many dimensions fall to each attention head; derived, so it cannot drift
# out of sync with the two values above.
head_size = dimension // attention_heads

# --- Run configuration -------------------------------------------------------
# Short identifier for this model backend.
model_name = "glm3"
# Largest tokenized prompt length that get_att / get_ans will send to the model.
max_length = 4096
# NOTE(review): "tempreture" is a misspelling of "temperature"; the name is kept
# because other modules may import it. It is not referenced in this part of the
# file, so its exact role should be confirmed against its consumers.
tempreture = 1000
# NOTE(review): not referenced in this part of the file; presumably read by
# whichever code batches prompts for this backend.
model_batch_size = 1

# Device that input tensors are moved to before each forward pass.
device = torch.device("cuda")

# The checkpoint ships its own tokenizer/model code, hence trust_remote_code.
tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/chatglm3-6b",
    trust_remote_code=True,
)
# Pad on the left so that, in a padded batch, the last position of every row
# is still a real prompt token (get_ans reads the prediction at position -1).
tokenizer.padding_side = "left"
# Disabled alternative kept for reference:
# tokenizer.add_special_tokens({'pad_token': '<s>'})

model = AutoModelForCausalLM.from_pretrained(
    "THUDM/chatglm3-6b",
    trust_remote_code=True,
    device_map="auto",
)
# Convert the weights to half precision, then switch to inference mode.
model = model.half()
model.eval()


@torch.no_grad()
def get_att(prompts: str) -> "list[float] | torch.Tensor":
    """Return one hidden-state vector from the model's transformer for a prompt.

    The prompt is tokenized and passed through ``model.transformer`` (the
    language-model head is not applied). Despite the function's name, the
    value returned is a slice of ``last_hidden_state``, not attention weights.

    Args:
        prompts: Prompt text handed to the module-level tokenizer.

    Returns:
        The ``[-1, 0]`` slice of ``last_hidden_state`` as a plain Python list.
        If the tokenized prompt is longer than ``max_length`` the model is not
        run and the sentinel ``torch.Tensor([0])`` is returned instead; callers
        must be prepared for either type.
    """
    input_ids = tokenizer(prompts, return_tensors="pt", padding=True)["input_ids"]

    # Over-long prompts are never sent to the model; hand back the sentinel.
    if input_ids.shape[-1] > max_length:
        return torch.Tensor([0])

    outputs = model.transformer(input_ids=input_ids.to(device))

    # NOTE(review): this indexing assumes last_hidden_state is laid out as
    # (seq_len, batch, hidden), i.e. it selects the final token of the first
    # sequence -- confirm against the checkpoint's modeling code.
    return outputs.last_hidden_state[-1, 0].cpu().tolist()

# Lookup table used by get_ans: tokenizer vocabulary id -> 0-based answer index.
# Index 0 stands for option A, 1 for B, ... 6 for G (the same numbering the
# vocabulary scan under ``__main__`` prints). Several ids map to each index.
# NOTE(review): which spelling of the letter each id decodes to is not visible
# here; run this file as a script to print the decoded form of every key.
# Insertion order is preserved from the original literal.
choices = {
    # Run of seven consecutive ids.
    68: 0,
    69: 1,
    70: 2,
    71: 3,
    72: 4,
    73: 5,
    74: 6,
    # Ids in the 300-400 range.
    316: 0,
    319: 2,
    347: 1,
    367: 3,
    368: 5,
    397: 4,
    402: 6,
    # Ids in the 30900 range.
    30938: 0,
    30942: 2,
    30949: 1,
    30950: 4,
    30952: 3,
    30960: 5,
    30964: 6,
}

@torch.no_grad()
def get_length(prompt: str) -> int:
    """Return the tokenized length of ``prompt``.

    The length is the size of the last dimension of the ``input_ids`` tensor
    produced by the module-level tokenizer, using the same tokenizer arguments
    as get_att / get_ans so the value is comparable with ``max_length``.
    """
    encoded = tokenizer(prompt, return_tensors="pt", padding=True)
    token_ids = encoded["input_ids"]
    return token_ids.shape[-1]

@torch.no_grad()
def get_ans(prompts: str, full_generate: bool = True) -> tuple:
    """Predict a multiple-choice answer from the model's next-token prediction.

    The prompt is tokenized, run through the model once, and the highest-scoring
    token id at the last position of the first sequence is looked up in
    ``choices``.

    Args:
        prompts: Prompt text handed to the module-level tokenizer.
        full_generate: Accepted for interface compatibility and ignored; this
            model does not support full generation.

    Returns:
        A 2-tuple ``(answer, human)``:

        * ``(index, 0)`` when the predicted token id is a key of ``choices``;
          ``index`` is the mapped 0-based answer index.
        * ``(-1, 1)`` when the predicted token id is not in ``choices``.
        * ``(0, -1)`` when the tokenized prompt is longer than ``max_length``;
          the model is not run in that case.
    """
    input_ids = tokenizer(prompts, return_tensors="pt", padding=True)["input_ids"]

    if input_ids.shape[-1] > max_length:
        return 0, -1

    logits = model(input_ids.to(device))["logits"]

    # Select the one row that is needed *before* leaving the GPU: only a single
    # vocab-sized vector is copied and reduced, instead of moving the whole
    # logits tensor to the CPU and taking argmax at every position.
    # NOTE(review): [0, -1] assumes logits are (batch, seq_len, vocab) -- confirm.
    option = logits[0, -1].cpu().argmax(dim=-1).item()

    answer = choices.get(option)
    if answer is None:
        # Prediction is not one of the known answer tokens.
        return -1, 1

    return answer, 0



if __name__ == "__main__":
    # Developer utility: inspect the ``choices`` table and search the
    # vocabulary for token ids that decode to an answer letter.

    # Sample prompts for manual smoke checks (see the disabled calls below).
    text = '''What's the right answer? Output A B C or D.
A. The wrong answer.
B. The right answer.
C. The wrong answer.
D. The wrong answer.
Answer:
'''
    text2 = '''What's the best answer you think? Output A B C or D.
A. The wrong answer.
B. The right answer.
C. The wrong answer.
D. The wrong answer.
Answer:
'''
    # Disabled manual checks:
    # print(get_att(text2))
    # print(get_ans(text2))
    # print(model)

    # 1) Show what every id currently in ``choices`` decodes to.
    for token_id, answer_index in choices.items():
        decoded = tokenizer.decode(token_id)
        print("Key: {} Val: {} Chr: {}".format(token_id, answer_index, decoded))

    # 2) Scan token ids 0..49999 for answer letters. Each letter is recognised
    #    bare and with one leading space; "A." is additionally accepted for A
    #    only. Hits are printed as "id: index, " on one line -- presumably so
    #    the output can be pasted into the ``choices`` literal.
    spelling_to_index = {"A.": 0}
    for answer_index, letter in enumerate("ABCDEFG"):
        spelling_to_index[letter] = answer_index
        spelling_to_index[" " + letter] = answer_index

    for token_id in range(50000):
        answer_index = spelling_to_index.get(tokenizer.decode([token_id]))
        if answer_index is not None:
            print(f"{token_id}: {answer_index}, ", end='')

    