import fastchat.conversation
from transformers import AutoTokenizer
import uuid

# Memo of tokenizers already loaded by get_tokenizer(), keyed by the
# model_name argument it was called with (not the redirected tokenizer name).
TOKENIZER_CACHE = {}

def get_tokenizer(model_name):
    """Load the tokenizer for ``model_name``, memoising it in ``TOKENIZER_CACHE``.

    Any name containing ``"llama"`` or ``"vicuna"`` (case-insensitive) is
    redirected to the ``hf-internal-testing/llama-tokenizer`` repository.
    All other names are passed to ``AutoTokenizer.from_pretrained`` unchanged.

    Args:
        model_name: Model identifier. It is also the cache key, so later
            lookups use the original (un-redirected) name.

    Returns:
        The tokenizer object produced by ``AutoTokenizer.from_pretrained``.
    """
    if model_name in TOKENIZER_CACHE:
        return TOKENIZER_CACHE[model_name]

    # Compare both family names case-insensitively so that e.g. "Vicuna-7b"
    # is redirected just like "Llama-2-7b".
    lowered = model_name.lower()
    if "llama" in lowered or "vicuna" in lowered:
        # NOTE(review): presumably one public LLaMA tokenizer stands in for
        # these checkpoints to avoid per-model downloads — confirm.
        tokenizer_name = "hf-internal-testing/llama-tokenizer"
    else:
        tokenizer_name = model_name

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    TOKENIZER_CACHE[model_name] = tokenizer
    return tokenizer

def get_conv_template(model_name):
    """Return the FastChat conversation template matching ``model_name``.

    Matching is a case-sensitive substring test, tried in order:
    ``"Llama-2"`` selects the ``"llama-2"`` template, then ``"vicuna"``
    selects ``"vicuna_v1.1"``. The first hit wins.

    Raises:
        NotImplementedError: if none of the known substrings occur.
    """
    # Ordered (substring, FastChat template name) pairs; order matters when
    # a name contains more than one substring.
    known_families = (
        ("Llama-2", "llama-2"),
        ("vicuna", "vicuna_v1.1"),
    )
    for needle, template_name in known_families:
        if needle in model_name:
            return fastchat.conversation.get_conv_template(template_name)
    raise NotImplementedError(f"unknown {model_name=}")


def get_attack_template(command, response, conv_template, tokenizer):
    """Tokenize a chat prompt split around an insertion slot.

    A unique uuid4 string marks the slot. It follows ``command``, or stands
    alone when ``command`` is falsy. A user turn carrying that text and an
    empty assistant turn are appended to ``conv_template.copy()``, so the
    caller's template is not appended to (assuming ``copy()`` returns an
    independent object). The rendered prompt is then cut at the marker.
    NOTE(review): the slot is presumably where optimised attack tokens go —
    confirm with callers.

    Returns:
        A tuple ``(prefix_ids, suffix_ids, target_ids)``:
        * ``prefix_ids``: prompt text before the slot, encoded with the
          tokenizer's default special-token handling;
        * ``suffix_ids``: prompt text after the slot, without special tokens;
        * ``target_ids``: ``response`` formatted as text, without special
          tokens.

    Raises:
        ValueError: if the rendered prompt does not contain the marker
            exactly once, because the two-way unpacking of ``split`` fails.
    """
    marker = str(uuid.uuid4())
    if command:
        user_text = f"{command} {marker}"
    else:
        user_text = marker
    target_text = f"{response}"

    dialogue = conv_template.copy()
    user_role = dialogue.roles[0]
    assistant_role = dialogue.roles[1]
    dialogue.append_message(user_role, user_text)
    # Empty assistant turn so the rendered prompt ends where the reply begins.
    dialogue.append_message(assistant_role, "")

    rendered = dialogue.get_prompt()
    head, tail = rendered.split(marker)

    prefix_ids = tokenizer.encode(head)
    suffix_ids = tokenizer.encode(tail, add_special_tokens=False)
    target_ids = tokenizer.encode(target_text, add_special_tokens=False)
    return prefix_ids, suffix_ids, target_ids

def get_raw_embedding_table(model):
    """Return the ``"weight"`` entry registered on the model's input-embedding module.

    The value is read directly from that module's ``_parameters`` mapping,
    not through attribute access. A ``KeyError`` propagates if no
    ``"weight"`` entry is registered there.
    """
    embedding_module = model.get_input_embeddings()
    registered = embedding_module._parameters
    return registered["weight"]