import os
import time
import openai
from openai import OpenAI
import google.generativeai as genai
from datetime import datetime
from typing import List, Dict
from fastchat.conversation import get_conv_template
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    BitsAndBytesConfig,
    StoppingCriteriaList,
)

from vllm import LLM, SamplingParams

# Directory passed as ``cache_dir`` when loading Hugging Face tokenizers.
# "~" is expanded here so the path does not depend on the consuming library
# performing user-directory expansion itself.
PATH_HF_CACHE = os.path.expanduser("~/.cache")

# Seconds to sleep between successive retries of a failed API request.
API_TIMEOUTS = [1, 2, 4, 8, 16, 32]

# API key per provider, looked up by the API-backed model wrappers in this
# file ("openai", "google", "mistralai"). Keys are read from the environment;
# "xxx" is only a placeholder fallback so the lookup itself never fails.
API_KEYS = {
    "openai": os.environ.get("OPENAI_API_KEY", "xxx"),
    "google": os.environ.get("GOOGLE_API_KEY", "xxx"),
    "mistralai": os.environ.get("MISTRAL_API_KEY", "xxx"),
}

# Registry of supported models, keyed by the model id passed to the wrapper
# classes in this file. Fields of each entry:
#   "company":           provider / organisation label.
#   "model_class":       name of the wrapper class allowed to instantiate this
#                        entry (each wrapper asserts on it in its constructor).
#   "model_name":        identifier handed to the provider API or used as the
#                        Hugging Face / vLLM model path.
#   "8bit":              None for every entry; not read anywhere in this file.
#   "likelihood_access": stored on the instance as ``_likelihood_access``.
#   "endpoint":          OpenAI endpoint selector ("ChatCompletion"); None for
#                        all other providers.
#   "conv_template":     fastchat conversation template name (only present for
#                        entries that define it).
MODELS = {
    "openai/gpt-3.5-turbo": {
        "company": "openai",
        "model_class": "OpenAIModel",
        "model_name": "gpt-3.5-turbo",  # gpt-3.5-turbo-0125
        "8bit": None,
        "likelihood_access": False,
        "endpoint": "ChatCompletion",
    },
    "openai/gpt-4": {
        "company": "openai",
        "model_class": "OpenAIModel",
        "model_name": "gpt-4",
        "8bit": None,
        "likelihood_access": False,
        "endpoint": "ChatCompletion",
    },
    "openai/gpt-4-turbo": {
        "company": "openai",
        "model_class": "OpenAIModel",
        "model_name": "gpt-4-turbo-2024-04-09",
        "8bit": None,
        "likelihood_access": False,
        "endpoint": "ChatCompletion",
    },
    "openai/gpt-4-1106": {
        "company": "openai",
        "model_class": "OpenAIModel",
        "model_name": "gpt-4-1106-preview",  # gpt-4-turbo-2024-04-09
        "8bit": None,
        "likelihood_access": False,
        "endpoint": "ChatCompletion",
    },
    "openai/gpt-4-0125": {
        "company": "openai",
        "model_class": "OpenAIModel",
        "model_name": "gpt-4-0125-preview",
        "8bit": None,
        "likelihood_access": False,
        "endpoint": "ChatCompletion",
    },
    "google/gemini-pro": {
        "company": "google",
        "model_class": "GeminiModel",
        "model_name": "gemini-pro",  # gemini-1.5-flash-latest
        "8bit": None,
        "likelihood_access": False,
        "endpoint": None,
    },
    "meta/llama2-7b-chat": {
        "company": "meta",
        "model_class": "LlamaModel",
        "model_name": "meta-llama/Llama-2-7b-chat-hf",
        "8bit": None,
        "likelihood_access": True,
        "endpoint": None,
        "conv_template": "llama-2"
    },
    "meta/llama2-13b-chat": {
        "company": "meta",
        "model_class": "LlamaModel",
        "model_name": "meta-llama/Llama-2-13b-chat-hf",
        "8bit": None,
        "likelihood_access": True,
        "endpoint": None,
        "conv_template": "llama-2"
    },
    "meta/llama2-70b-chat": {
        "company": "meta",
        "model_class": "LlamaModel",
        "model_name": "meta-llama/Llama-2-70b-chat-hf",
        "8bit": None,
        "likelihood_access": True,
        "endpoint": None,
        "conv_template": "llama-2"
    },
    "mistralai/mistral-large": {
        "company": "mistralai",
        "model_class": "MistralLargeModel",
        "model_name": "mistral-large-latest",
        "8bit": None,
        "likelihood_access": True,
        "endpoint": None,
        "conv_template": "mistral"
    },
    "mistralai/mistral-8x7b-instruct": {
        "company": "mistralai",
        "model_class": "MistralModel",
        "model_name": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "8bit": None,
        "likelihood_access": True,
        "endpoint": None,
        "conv_template": "mistral"
    },
    "mistralai/mistral-7b-instruct": {
        "company": "mistralai",
        "model_class": "MistralModel",
        "model_name": "mistralai/Mistral-7B-Instruct-v0.2",
        "8bit": None,
        "likelihood_access": True,
        "endpoint": None,
        "conv_template": "mistral"
    },
    "qwen/qwen1.5-72b-chat": {
        "company": "alibaba",
        "model_class": "QwenModel",
        "model_name": "Qwen/Qwen1.5-72B-Chat",
        "8bit": None,
        "likelihood_access": True,
        "endpoint": None,
        "conv_template": "qwen-7b-chat"
    },
    "qwen/qwen1.5-7b-chat": {
        "company": "alibaba",
        "model_class": "QwenModel",
        "model_name": "Qwen/Qwen1.5-7B-Chat",
        "8bit": None,
        "likelihood_access": True,
        "endpoint": None,
        "conv_template": "qwen-7b-chat"
    },
    "lmsys/vicuna-7b-v1.5": {
        "company": "lmsys",
        "model_class": "VicunaModel",
        "model_name": "lmsys/vicuna-7b-v1.5",
        "8bit": None,
        "likelihood_access": True,
        "endpoint": None,
        "conv_template": "vicuna_v1.1",
    },
    "lmsys/vicuna-33b-v1.3": {
        "company": "lmsys",
        "model_class": "VicunaModel",
        "model_name": "lmsys/vicuna-33b-v1.3",
        "8bit": None,
        "likelihood_access": True,
        "endpoint": None,
        "conv_template": "vicuna_v1.1",
    },
    "allenai/tulu-2-dpo-7b": {
        "company": "allenai",
        "model_class": "TuluModel",
        "model_name": "allenai/tulu-2-dpo-7b",
        "8bit": None,
        "likelihood_access": True,
        "endpoint": None,
        "conv_template": "tulu",
    },
    "allenai/tulu-2-dpo-70b": {
        "company": "allenai",
        "model_class": "TuluModel",
        "model_name": "allenai/tulu-2-dpo-70b",
        "8bit": None,
        "likelihood_access": True,
        "endpoint": None,
        "conv_template": "tulu",
    },
    "thudm/chatglm3-6b": {
        "company": "THUDM",
        "model_class": "GLMModel",
        "model_name": "THUDM/chatglm3-6b",
        "8bit": None,
        "likelihood_access": True,
        "endpoint": None,
        "conv_template": "chatglm3",
    },
}

def get_timestamp():
    """Return the current local time formatted as ``YYYY-MM-DD_HH:MM:SS``."""
    now = datetime.now()
    return f"{now:%Y-%m-%d_%H:%M:%S}"

def get_compute_capability():
    """Return the CUDA compute capability of the current device as a float.

    The (major, minor) pair reported by torch is rendered as ``major.minor``
    and converted to float (e.g. ``(8, 6)`` becomes ``8.6``).

    :raises ValueError: if CUDA is not available.
    """
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        return float(f"{major}.{minor}")

    raise ValueError("CUDA is not available on this device!")

def check_bf16_support():
    """Return True when the device's compute capability is at least 8.0.

    The threshold is what this module uses to choose bfloat16 over float16.
    Propagates the ValueError from ``get_compute_capability`` when CUDA is
    not available.
    """
    return get_compute_capability() >= 8.0

class LanguageModel:
    """Common base for all model wrappers.

    Resolves the registry entry for ``model_name`` and stores its basic
    settings on the instance. The two answer methods are interface stubs:
    they have no body beyond their docstring and therefore return ``None``
    unless a subclass overrides them.
    """

    def __init__(self, model_name, args):
        assert model_name in MODELS, f"Model {model_name} is not supported!"

        spec = MODELS[model_name]

        # Registry-derived settings
        self._model_id = model_name
        self._model_name = spec["model_name"]
        self._model_endpoint = spec["endpoint"]
        self._company = spec["company"]
        self._likelihood_access = spec["likelihood_access"]

        # Caller-supplied run configuration, kept as-is
        self.args = args

    def get_model_id(self):
        """Return the registry key this instance was created with."""
        return self._model_id

    def get_greedy_answer(
        self, prompt_base: str, prompt_system: str, max_tokens: int
    ) -> str:
        """
        Answer ``prompt_base`` with greedy decoding (interface stub).

        :param prompt_base:     base prompt
        :param prompt_system:   system instruction
        :param max_tokens:      max tokens in answer
        :return:                answer string
        """

    def get_top_p_answer(
        self,
        prompt_base: str,
        prompt_system: str,
        max_tokens: int,
        temperature: float,
        top_p: float,
    ) -> str:
        """
        Answer ``prompt_base`` with temperature / top-p sampling (interface stub).

        :param prompt_base:     base prompt
        :param prompt_system:   system instruction
        :param max_tokens:      max tokens in answer
        :param temperature:     sampling temperature
        :param top_p:           top_p parameter
        :return:                answer string
        """

### Closed-Source Models

class OpenAIModel(LanguageModel):
    """OpenAI API wrapper.

    Sends prompts one by one to the endpoint configured for the model in
    ``MODELS`` ("ChatCompletion" or "Completion") and collects the answers.
    """

    def __init__(self, model_name: str, args: dict):
        # The base class validates the model id and stores ``args``.
        super().__init__(model_name, args)
        assert MODELS[model_name]["model_class"] == "OpenAIModel", (
            f"Errorneous Model Instatiation for {model_name}"
        )
        self.api_key = API_KEYS["openai"]
        self.client = OpenAI(api_key=self.api_key)

    def _prompt_request(
        self,
        prompt_base: str,
        max_tokens: int,
        prompt_system: str = "",
        temperature: float = 1.0,
        top_p: float = 1.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        logprobs: int = 1,
        stop: List = None,
        echo: bool = False,
    ):
        """
        Send a single request to the OpenAI API, retrying with backoff.

        A failed request is retried after sleeping ``API_TIMEOUTS[i]`` seconds
        for the i-th failure. Once every backoff slot has been used, the error
        of the final attempt is re-raised.

        :param prompt_base:        user prompt
        :param max_tokens:         max tokens in answer
        :param prompt_system:      system instruction ("" to omit it)
        :param temperature:        sampling temperature
        :param top_p:              top_p parameter
        :param frequency_penalty:  forwarded to the API
        :param presence_penalty:   forwarded to the API
        :param logprobs:           forwarded to the Completion endpoint only
        :param stop:               stop sequences, Completion endpoint only;
                                   defaults to ["Human:", " AI:"]
        :param echo:               forwarded to the Completion endpoint only
        :return:                   raw API response object
        :raises ValueError:        if the configured endpoint is unknown
        """
        if stop is None:
            stop = ["Human:", " AI:"]

        # Validate up front: a configuration error must not enter the retry loop.
        if self._model_endpoint not in ("ChatCompletion", "Completion"):
            raise ValueError("Unknown Model Endpoint")

        attempt = 0
        while True:
            try:
                if self._model_endpoint == "ChatCompletion":
                    # Dialogue format
                    messages = []
                    if prompt_system != "":
                        messages.append({"role": "system", "content": f"{prompt_system}"})
                    messages.append({"role": "user", "content": f"{prompt_base}"})

                    return self.client.chat.completions.create(
                        model=self._model_name,
                        messages=messages,
                        temperature=temperature,
                        top_p=top_p,
                        max_tokens=max_tokens,
                        frequency_penalty=frequency_penalty,
                        presence_penalty=presence_penalty,
                    )

                # Completion endpoint, called through the v1 client object
                return self.client.completions.create(
                    model=self._model_name,
                    prompt=f"{prompt_system}{prompt_base}",
                    temperature=temperature,
                    max_tokens=max_tokens,
                    top_p=top_p,
                    frequency_penalty=frequency_penalty,
                    presence_penalty=presence_penalty,
                    logprobs=logprobs,
                    stop=stop,
                    echo=echo,
                )

            except Exception:
                # Backoff schedule exhausted: surface the real API error.
                if attempt >= len(API_TIMEOUTS):
                    raise
                time.sleep(API_TIMEOUTS[attempt])
                attempt += 1

    def get_greedy_answer(
        self, prompt_base: list, prompt_system: list, max_tokens: int
    ) -> List[Dict]:
        """
        Answer every prompt with temperature 0 (no sampling randomness).

        :param prompt_base:     list of base prompts
        :param prompt_system:   list of system instructions, aligned with prompt_base
        :param max_tokens:      max tokens in each answer
        :return:                list of result dicts, see ``get_top_p_answer``
        """
        return self.get_top_p_answer(
            prompt_base=prompt_base,
            prompt_system=prompt_system,
            max_tokens=max_tokens,
            temperature=0,  # without any other randomness
            top_p=1.0,
        )

    def get_top_p_answer(
        self,
        prompt_base: list,
        prompt_system: list,
        max_tokens: int,
        temperature: float,
        top_p: float,
    ) -> List[Dict]:
        """
        Answer every prompt using temperature / top-p sampling.

        :param prompt_base:     list of base prompts
        :param prompt_system:   list of system instructions, aligned with prompt_base
        :param max_tokens:      max tokens in each answer
        :param temperature:     sampling temperature
        :param top_p:           top_p parameter
        :return:                one dict per prompt with the keys "timestamp",
                                "answer_raw" and "answer"
        """
        results = []
        for index, base in enumerate(prompt_base):
            result = {"timestamp": get_timestamp()}

            response = self._prompt_request(
                prompt_base=base,
                prompt_system=prompt_system[index],
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                frequency_penalty=0.0,
                presence_penalty=0.0,
                logprobs=1,
                stop=["Human:", " AI:"],
                echo=False,
            )

            if self._model_endpoint == "ChatCompletion":
                completion = response.choices[0].message.content
            else:
                completion = response.choices[0].text

            # The API may return no text at all; treat that as an empty answer.
            completion = (completion or "").strip()

            result["answer_raw"] = completion
            result["answer"] = completion

            results.append(result)

        return results

class GeminiModel(LanguageModel):
    """Google Gemini API wrapper.

    Sends prompts one by one through ``google.generativeai`` and collects the
    answers. Only the base prompt is sent; ``prompt_system`` is accepted for
    interface compatibility with the other wrappers but is not used.
    """

    def __init__(self, model_name: str, args: dict):
        # The base class validates the model id and stores ``args``.
        super().__init__(model_name, args)
        assert MODELS[model_name]["model_class"] == "GeminiModel", (
            f"Errorneous Model Instatiation for {model_name}"
        )
        self.api_key = API_KEYS["google"]
        genai.configure(api_key=self.api_key)
        self.model = genai.GenerativeModel(self._model_name)

    def _prompt_request(
        self,
        prompt_base: str,
        max_tokens: int,
        temperature: float = 1.0,
        top_p: float = None,
    ):
        """
        Send a single request to the Gemini API, retrying with backoff.

        A failed request is retried after sleeping ``API_TIMEOUTS[i]`` seconds
        for the i-th failure. Once every backoff slot has been used, the error
        of the final attempt is re-raised.

        :param prompt_base:   prompt text
        :param max_tokens:    max output tokens for the answer
        :param temperature:   sampling temperature
        :param top_p:         top_p parameter (None leaves it unset)
        :return:              raw API response object
        """
        generation_config = genai.types.GenerationConfig(
            max_output_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        safety_settings = [
            {
                "category": "HARM_CATEGORY_HARASSMENT",
                "threshold": "BLOCK_NONE",
                "probability": "NEGLIGIBLE"
            },
            {
                "category": "HARM_CATEGORY_HATE_SPEECH",
                "threshold": "BLOCK_NONE",
                "probability": "NEGLIGIBLE"
            },
            {
                "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "threshold": "BLOCK_NONE",
                "probability": "NEGLIGIBLE"
            },
            {
                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                "threshold": "BLOCK_NONE",
                "probability": "NEGLIGIBLE"
            },
        ]

        attempt = 0
        while True:
            try:
                return self.model.generate_content(
                    prompt_base,
                    generation_config=generation_config,
                    safety_settings=safety_settings,
                )
            except Exception:
                # Backoff schedule exhausted: surface the real API error.
                if attempt >= len(API_TIMEOUTS):
                    raise
                time.sleep(API_TIMEOUTS[attempt])
                attempt += 1

    def get_greedy_answer(
        self, prompt_base: list, prompt_system: list, max_tokens: int
    ) -> List[Dict]:
        """
        Answer every prompt with temperature 0 (no sampling randomness).

        :param prompt_base:     list of base prompts
        :param prompt_system:   list of system instructions (unused, see class docstring)
        :param max_tokens:      max tokens in each answer
        :return:                list of result dicts, see ``get_top_p_answer``
        """
        return self.get_top_p_answer(
            prompt_base=prompt_base,
            prompt_system=prompt_system,
            max_tokens=max_tokens,
            temperature=0,  # without any other randomness
            top_p=1.0,
        )

    def get_top_p_answer(
        self,
        prompt_base: list,
        prompt_system: list,
        max_tokens: int,
        temperature: float,
        top_p: float,
    ) -> List[Dict]:
        """
        Answer every prompt using temperature / top-p sampling.

        :param prompt_base:     list of base prompts
        :param prompt_system:   list of system instructions (unused, see class docstring)
        :param max_tokens:      max tokens in each answer
        :param temperature:     sampling temperature
        :param top_p:           top_p parameter
        :return:                one dict per prompt with the keys "timestamp",
                                "answer_raw" and "answer"; when no text can be
                                read from the response, the answer is the
                                literal "blocked by Gemini AI"
        """
        results = []
        for base in prompt_base:
            result = {"timestamp": get_timestamp()}

            response = self._prompt_request(
                prompt_base=base,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            )

            try:
                completion = response.text.strip()
            except Exception:
                # No readable text in the response: record a marker and log
                # the feedback the API attached to the prompt.
                completion = "blocked by Gemini AI"
                print(response.prompt_feedback)

            result["answer_raw"] = completion
            result["answer"] = completion

            results.append(result)

        return results

class MistralLargeModel(LanguageModel):
    """Mistral AI API wrapper.

    Sends prompts one by one to the Mistral chat API and collects the answers.
    """

    def __init__(self, model_name: str, args: dict):
        # The base class validates the model id and stores ``args``.
        super().__init__(model_name, args)
        assert MODELS[model_name]["model_class"] == "MistralLargeModel", (
            f"Errorneous Model Instatiation for {model_name}"
        )
        self.api_key = API_KEYS["mistralai"]
        self.client = MistralClient(api_key=self.api_key)

    def _prompt_request(
        self,
        prompt_base: str,
        max_tokens: int,
        prompt_system: str = "",
        temperature: float = 1.0,
        top_p: float = 1.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        logprobs: int = 1,
        stop: List = None,
        echo: bool = False,
    ):
        """
        Send a single chat request to the Mistral API, retrying with backoff.

        A failed request is retried after sleeping ``API_TIMEOUTS[i]`` seconds
        for the i-th failure. Once every backoff slot has been used, the error
        of the final attempt is re-raised.

        ``frequency_penalty``, ``presence_penalty``, ``logprobs``, ``stop`` and
        ``echo`` are accepted so the signature matches the OpenAI wrapper, but
        they are not forwarded to the API.

        :param prompt_base:     user prompt
        :param max_tokens:      max tokens in answer
        :param prompt_system:   system instruction ("" to omit it)
        :param temperature:     sampling temperature
        :param top_p:           top_p parameter
        :return:                raw API response object
        """
        messages = []
        if prompt_system != "":
            messages.append(ChatMessage(role="system", content=f"{prompt_system}"))
        messages.append(ChatMessage(role="user", content=f"{prompt_base}"))

        attempt = 0
        while True:
            try:
                return self.client.chat(
                    model=self._model_name,
                    messages=messages,
                    temperature=temperature,
                    top_p=top_p,
                    max_tokens=max_tokens,
                )
            except Exception:
                # Backoff schedule exhausted: surface the real API error.
                if attempt >= len(API_TIMEOUTS):
                    raise
                time.sleep(API_TIMEOUTS[attempt])
                attempt += 1

    def get_greedy_answer(
        self, prompt_base: list, prompt_system: list, max_tokens: int
    ) -> List[Dict]:
        """
        Answer every prompt with temperature 0 (no sampling randomness).

        :param prompt_base:     list of base prompts
        :param prompt_system:   list of system instructions, aligned with prompt_base
        :param max_tokens:      max tokens in each answer
        :return:                list of result dicts, see ``get_top_p_answer``
        """
        return self.get_top_p_answer(
            prompt_base=prompt_base,
            prompt_system=prompt_system,
            max_tokens=max_tokens,
            temperature=0,  # without any other randomness
            top_p=1.0,
        )

    def get_top_p_answer(
        self,
        prompt_base: list,
        prompt_system: list,
        max_tokens: int,
        temperature: float,
        top_p: float,
    ) -> List[Dict]:
        """
        Answer every prompt using temperature / top-p sampling.

        :param prompt_base:     list of base prompts
        :param prompt_system:   list of system instructions, aligned with prompt_base
        :param max_tokens:      max tokens in each answer
        :param temperature:     sampling temperature
        :param top_p:           top_p parameter
        :return:                one dict per prompt with the keys "timestamp",
                                "answer_raw" and "answer"
        """
        results = []
        for index, base in enumerate(prompt_base):
            result = {"timestamp": get_timestamp()}

            response = self._prompt_request(
                prompt_base=base,
                prompt_system=prompt_system[index],
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                frequency_penalty=0.0,
                presence_penalty=0.0,
                logprobs=1,
                stop=["Human:", " AI:"],
                echo=False,
            )

            # The API may return no text at all; treat that as an empty answer.
            completion = (response.choices[0].message.content or "").strip()

            result["answer_raw"] = completion
            result["answer"] = completion

            results.append(result)

        return results

### Open-Source Models
class LlamaModel(LanguageModel):
    """Meta LLAMA Model Wrapper --> Access through HuggingFace Model Hub

    Generation runs through a vLLM ``LLM`` engine; prompts are formatted with
    the fastchat conversation template named in the model's ``MODELS`` entry.
    """

    def __init__(self, model_name: str, args: dict):
        """
        :param model_name:  key into ``MODELS``
        :param args:        run configuration; ``args.gpu_num`` is read as an
                            attribute to cap the tensor-parallel size
        """
        super().__init__(model_name, args)
        assert MODELS[model_name]["model_class"] == "LlamaModel", (
            f"Errorneous Model Instatiation for {model_name}"
        )

        self._tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=self._model_name, cache_dir=PATH_HF_CACHE
        )
        if self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = "<unk>"
            self._tokenizer.pad_token_id = 0
        self._tokenizer.padding_side = "left"

        self._device = "cuda"
        device_count = torch.cuda.device_count()
        dtype = "bfloat16" if check_bf16_support() else "float16"
        self._model = LLM(
            model=self._model_name,
            # Never request more GPUs than are visible on this machine.
            tensor_parallel_size=min(device_count, self.args.gpu_num),
            dtype=dtype,
        )
        self.conv_template = MODELS[model_name]["conv_template"]

    def get_greedy_answer(
        self, prompt_base: list, prompt_system: list, max_tokens: int
    ) -> List[Dict]:
        """
        Answer every prompt with temperature 0 (greedy search).

        :param prompt_base:     list of base prompts
        :param prompt_system:   list of system instructions, aligned with prompt_base
        :param max_tokens:      max tokens in each answer
        :return:                list of result dicts, see ``get_top_p_answer``
        """
        return self.get_top_p_answer(
            prompt_base=prompt_base,
            prompt_system=prompt_system,
            max_tokens=max_tokens,
            temperature=0,
            top_p=1.0,
        )

    def get_top_p_answer(
        self,
        prompt_base: list,
        prompt_system: list,
        max_tokens: int,
        temperature: float,
        top_p: float,
    ) -> List[Dict]:
        """
        Answer a batch of prompts using temperature / top-p sampling.

        Each prompt is the system text directly followed by the base text,
        placed in the first role of the conversation template.

        :param prompt_base:     list of base prompts
        :param prompt_system:   list of system instructions, aligned with prompt_base
        :param max_tokens:      max tokens in each answer
        :param temperature:     sampling temperature
        :param top_p:           top_p parameter
        :return:                one dict per prompt with the keys "timestamp",
                                "answer_raw" (system + base + answer text) and
                                "answer"; an empty list for an empty batch
        """
        # Nothing to generate; also guarantees the loop below runs at least
        # once, so the template's stop settings are always available.
        if not prompt_base:
            return []

        prompts = []
        for index, base in enumerate(prompt_base):
            conv = get_conv_template(self.conv_template)
            conv.append_message(conv.roles[0], f"{prompt_system[index]}{base}")
            conv.append_message(conv.roles[1], None)
            prompts.append(conv.get_prompt())

        # Stop settings come from the template (taken from the last conversation built).
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            stop=conv.stop_str,
            stop_token_ids=conv.stop_token_ids,
        )

        responses = self._model.generate(prompts, sampling_params)

        results = []
        for index, response in enumerate(responses):
            answer = response.outputs[0].text.strip()
            results.append(
                {
                    "timestamp": get_timestamp(),
                    "answer_raw": f"{prompt_system[index]}{prompt_base[index]}{answer}",
                    "answer": answer,
                }
            )

        return results

class MistralModel(LanguageModel):
    """Mistral AI Model Wrapper --> Access through HuggingFace Model Hub

    Generation runs through a vLLM ``LLM`` engine; prompts are formatted with
    the fastchat conversation template named in the model's ``MODELS`` entry.
    """

    def __init__(self, model_name: str, args: dict):
        """
        :param model_name:  key into ``MODELS``
        :param args:        run configuration; ``args.gpu_num`` is read as an
                            attribute to cap the tensor-parallel size
        """
        super().__init__(model_name, args)
        assert MODELS[model_name]["model_class"] == "MistralModel", (
            f"Errorneous Model Instatiation for {model_name}"
        )

        self._tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=self._model_name, cache_dir=PATH_HF_CACHE
        )
        if self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = "<unk>"
            self._tokenizer.pad_token_id = 0
        self._tokenizer.padding_side = "left"

        self._device = "cuda"
        device_count = torch.cuda.device_count()
        # Compare case-insensitively: the registered name is spelled
        # "Mixtral-8x7B", which a lowercase-only substring test never matches.
        if "8x7b" in self._model_name.lower():
            assert check_bf16_support(), "Mistral AI Model requires BF16 support!"
            dtype = "bfloat16"
        else:
            dtype = "bfloat16" if check_bf16_support() else "float16"
        self._model = LLM(
            model=self._model_name,
            # Never request more GPUs than are visible on this machine.
            tensor_parallel_size=min(device_count, self.args.gpu_num),
            dtype=dtype,
        )
        # Every MistralModel entry in MODELS registers the "mistral" template.
        self.conv_template = MODELS[model_name]["conv_template"]

    def get_greedy_answer(
        self, prompt_base: list, prompt_system: list, max_tokens: int
    ) -> List[Dict]:
        """
        Answer every prompt with temperature 0 (greedy search).

        :param prompt_base:     list of base prompts
        :param prompt_system:   list of system instructions, aligned with prompt_base
        :param max_tokens:      max tokens in each answer
        :return:                list of result dicts, see ``get_top_p_answer``
        """
        return self.get_top_p_answer(
            prompt_base=prompt_base,
            prompt_system=prompt_system,
            max_tokens=max_tokens,
            temperature=0,
            top_p=1.0,
        )

    def get_top_p_answer(
        self,
        prompt_base: list,
        prompt_system: list,
        max_tokens: int,
        temperature: float,
        top_p: float,
    ) -> List[Dict]:
        """
        Answer a batch of prompts using temperature / top-p sampling.

        Each prompt is the system text directly followed by the base text,
        placed in the first role of the conversation template.

        :param prompt_base:     list of base prompts
        :param prompt_system:   list of system instructions, aligned with prompt_base
        :param max_tokens:      max tokens in each answer
        :param temperature:     sampling temperature
        :param top_p:           top_p parameter
        :return:                one dict per prompt with the keys "timestamp",
                                "answer_raw" (system + base + answer text) and
                                "answer"; an empty list for an empty batch
        """
        # Nothing to generate; also guarantees the loop below runs at least
        # once, so the template's stop settings are always available.
        if not prompt_base:
            return []

        prompts = []
        for index, base in enumerate(prompt_base):
            conv = get_conv_template(self.conv_template)
            conv.append_message(conv.roles[0], f"{prompt_system[index]}{base}")
            conv.append_message(conv.roles[1], None)
            prompts.append(conv.get_prompt())

        # Stop settings come from the template (taken from the last conversation built).
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            stop=conv.stop_str,
            stop_token_ids=conv.stop_token_ids,
        )

        responses = self._model.generate(prompts, sampling_params)

        results = []
        for index, response in enumerate(responses):
            answer = response.outputs[0].text.strip()
            results.append(
                {
                    "timestamp": get_timestamp(),
                    "answer_raw": f"{prompt_system[index]}{prompt_base[index]}{answer}",
                    "answer": answer,
                }
            )

        return results

class QwenModel(LanguageModel):
    """Qwen AI Model Wrapper --> Access through HuggingFace Model Hub

    Generation runs through a vLLM ``LLM`` engine; prompts are formatted with
    the fastchat conversation template named in the model's ``MODELS`` entry.
    """

    def __init__(self, model_name: str, args: dict):
        """
        :param model_name:  key into ``MODELS``
        :param args:        run configuration; ``args.gpu_num`` is read as an
                            attribute to cap the tensor-parallel size
        """
        super().__init__(model_name, args)
        assert MODELS[model_name]["model_class"] == "QwenModel", (
            f"Errorneous Model Instatiation for {model_name}"
        )

        self._tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=self._model_name, cache_dir=PATH_HF_CACHE
        )
        if self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = "<unk>"
            self._tokenizer.pad_token_id = 0
        self._tokenizer.padding_side = "left"

        self._device = "cuda"
        device_count = torch.cuda.device_count()
        dtype = "bfloat16" if check_bf16_support() else "float16"
        self._model = LLM(
            model=self._model_name,
            # Never request more GPUs than are visible on this machine.
            tensor_parallel_size=min(device_count, self.args.gpu_num),
            dtype=dtype,
        )
        self.conv_template = MODELS[model_name]["conv_template"]

    def get_greedy_answer(
        self, prompt_base: list, prompt_system: list, max_tokens: int
    ) -> List[Dict]:
        """
        Answer every prompt with temperature 0 (greedy search).

        :param prompt_base:     list of base prompts
        :param prompt_system:   list of system instructions, aligned with prompt_base
        :param max_tokens:      max tokens in each answer
        :return:                list of result dicts, see ``get_top_p_answer``
        """
        return self.get_top_p_answer(
            prompt_base=prompt_base,
            prompt_system=prompt_system,
            max_tokens=max_tokens,
            temperature=0,
            top_p=1.0,
        )

    def get_top_p_answer(
        self,
        prompt_base: list,
        prompt_system: list,
        max_tokens: int,
        temperature: float,
        top_p: float,
    ) -> List[Dict]:
        """
        Answer a batch of prompts using temperature / top-p sampling.

        Each prompt is the system text directly followed by the base text,
        placed in the first role of the conversation template.

        :param prompt_base:     list of base prompts
        :param prompt_system:   list of system instructions, aligned with prompt_base
        :param max_tokens:      max tokens in each answer
        :param temperature:     sampling temperature
        :param top_p:           top_p parameter
        :return:                one dict per prompt with the keys "timestamp",
                                "answer_raw" (system + base + answer text) and
                                "answer"; an empty list for an empty batch
        """
        # Nothing to generate; also guarantees the loop below runs at least
        # once, so the template's stop settings are always available.
        if not prompt_base:
            return []

        prompts = []
        for index, base in enumerate(prompt_base):
            conv = get_conv_template(self.conv_template)
            conv.append_message(conv.roles[0], f"{prompt_system[index]}{base}")
            conv.append_message(conv.roles[1], None)
            prompts.append(conv.get_prompt())

        # Stop settings come from the template (taken from the last conversation built).
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            stop=conv.stop_str,
            stop_token_ids=conv.stop_token_ids,
        )

        responses = self._model.generate(prompts, sampling_params)

        results = []
        for index, response in enumerate(responses):
            answer = response.outputs[0].text.strip()
            results.append(
                {
                    "timestamp": get_timestamp(),
                    "answer_raw": f"{prompt_system[index]}{prompt_base[index]}{answer}",
                    "answer": answer,
                }
            )

        return results

class VicunaModel(LanguageModel):
    """Vicuna AI Model Wrapper --> Access through HuggingFace Model Hub

    Generation runs through a vLLM ``LLM`` engine; prompts are formatted with
    the fastchat conversation template named in the model's ``MODELS`` entry.
    """

    def __init__(self, model_name: str, args: dict):
        """
        :param model_name:  key into ``MODELS``
        :param args:        run configuration; ``args.gpu_num`` is read as an
                            attribute to cap the tensor-parallel size
        """
        super().__init__(model_name, args)
        assert MODELS[model_name]["model_class"] == "VicunaModel", (
            f"Errorneous Model Instatiation for {model_name}"
        )

        self._tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=self._model_name, cache_dir=PATH_HF_CACHE
        )
        if self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = "<unk>"
            self._tokenizer.pad_token_id = 0
        self._tokenizer.padding_side = "left"

        self._device = "cuda"
        device_count = torch.cuda.device_count()
        dtype = "bfloat16" if check_bf16_support() else "float16"
        self._model = LLM(
            model=self._model_name,
            # Never request more GPUs than are visible on this machine.
            tensor_parallel_size=min(device_count, self.args.gpu_num),
            dtype=dtype,
        )
        self.conv_template = MODELS[model_name]["conv_template"]

    def get_greedy_answer(
        self, prompt_base: list, prompt_system: list, max_tokens: int
    ) -> List[Dict]:
        """
        Answer every prompt with temperature 0 (greedy search).

        :param prompt_base:     list of base prompts
        :param prompt_system:   list of system instructions, aligned with prompt_base
        :param max_tokens:      max tokens in each answer
        :return:                list of result dicts, see ``get_top_p_answer``
        """
        return self.get_top_p_answer(
            prompt_base=prompt_base,
            prompt_system=prompt_system,
            max_tokens=max_tokens,
            temperature=0,
            top_p=1.0,
        )

    def get_top_p_answer(
        self,
        prompt_base: list,
        prompt_system: list,
        max_tokens: int,
        temperature: float,
        top_p: float,
    ) -> List[Dict]:
        """
        Answer a batch of prompts using temperature / top-p sampling.

        Each prompt is the system text directly followed by the base text,
        placed in the first role of the conversation template.

        :param prompt_base:     list of base prompts
        :param prompt_system:   list of system instructions, aligned with prompt_base
        :param max_tokens:      max tokens in each answer
        :param temperature:     sampling temperature
        :param top_p:           top_p parameter
        :return:                one dict per prompt with the keys "timestamp",
                                "answer_raw" (system + base + answer text) and
                                "answer"; an empty list for an empty batch
        """
        # Nothing to generate; also guarantees the loop below runs at least
        # once, so the template's stop settings are always available.
        if not prompt_base:
            return []

        prompts = []
        for index, base in enumerate(prompt_base):
            conv = get_conv_template(self.conv_template)
            conv.append_message(conv.roles[0], f"{prompt_system[index]}{base}")
            conv.append_message(conv.roles[1], None)
            prompts.append(conv.get_prompt())

        # Stop settings come from the template (taken from the last conversation built).
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            stop=conv.stop_str,
            stop_token_ids=conv.stop_token_ids,
        )

        responses = self._model.generate(prompts, sampling_params)

        results = []
        for index, response in enumerate(responses):
            answer = response.outputs[0].text.strip()
            results.append(
                {
                    "timestamp": get_timestamp(),
                    "answer_raw": f"{prompt_system[index]}{prompt_base[index]}{answer}",
                    "answer": answer,
                }
            )

        return results

class TuluModel(LanguageModel):
    """Tulu AI Model Wrapper --> Access through HuggingFace Model Hub

    Generation runs through a vLLM ``LLM`` engine; prompts are formatted with
    the fastchat conversation template named in the model's ``MODELS`` entry.
    """

    def __init__(self, model_name: str, args: dict):
        """
        :param model_name:  key into ``MODELS``
        :param args:        run configuration; ``args.gpu_num`` is read as an
                            attribute to cap the tensor-parallel size
        """
        super().__init__(model_name, args)
        assert MODELS[model_name]["model_class"] == "TuluModel", (
            f"Errorneous Model Instatiation for {model_name}"
        )

        self._tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=self._model_name, cache_dir=PATH_HF_CACHE
        )
        if self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = "<unk>"
            self._tokenizer.pad_token_id = 0
        self._tokenizer.padding_side = "left"

        self._device = "cuda"
        device_count = torch.cuda.device_count()
        dtype = "bfloat16" if check_bf16_support() else "float16"
        self._model = LLM(
            model=self._model_name,
            # Never request more GPUs than are visible on this machine.
            tensor_parallel_size=min(device_count, self.args.gpu_num),
            dtype=dtype,
        )
        self.conv_template = MODELS[model_name]["conv_template"]

    def get_greedy_answer(
        self, prompt_base: list, prompt_system: list, max_tokens: int
    ) -> List[Dict]:
        """
        Answer every prompt with temperature 0 (greedy search).

        :param prompt_base:     list of base prompts
        :param prompt_system:   list of system instructions, aligned with prompt_base
        :param max_tokens:      max tokens in each answer
        :return:                list of result dicts, see ``get_top_p_answer``
        """
        return self.get_top_p_answer(
            prompt_base=prompt_base,
            prompt_system=prompt_system,
            max_tokens=max_tokens,
            temperature=0,
            top_p=1.0,
        )

    def get_top_p_answer(
        self,
        prompt_base: list,
        prompt_system: list,
        max_tokens: int,
        temperature: float,
        top_p: float,
    ) -> List[Dict]:
        """
        Answer a batch of prompts using temperature / top-p sampling.

        Each prompt is the system text directly followed by the base text,
        placed in the first role of the conversation template.

        :param prompt_base:     list of base prompts
        :param prompt_system:   list of system instructions, aligned with prompt_base
        :param max_tokens:      max tokens in each answer
        :param temperature:     sampling temperature
        :param top_p:           top_p parameter
        :return:                one dict per prompt with the keys "timestamp",
                                "answer_raw" (system + base + answer text) and
                                "answer"; an empty list for an empty batch
        """
        # Nothing to generate; also guarantees the loop below runs at least
        # once, so the template's stop settings are always available.
        if not prompt_base:
            return []

        prompts = []
        for index, base in enumerate(prompt_base):
            conv = get_conv_template(self.conv_template)
            conv.append_message(conv.roles[0], f"{prompt_system[index]}{base}")
            conv.append_message(conv.roles[1], None)
            prompts.append(conv.get_prompt())

        # Stop settings come from the template (taken from the last conversation built).
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            stop=conv.stop_str,
            stop_token_ids=conv.stop_token_ids,
        )

        responses = self._model.generate(prompts, sampling_params)

        results = []
        for index, response in enumerate(responses):
            answer = response.outputs[0].text.strip()
            results.append(
                {
                    "timestamp": get_timestamp(),
                    "answer_raw": f"{prompt_system[index]}{prompt_base[index]}{answer}",
                    "answer": answer,
                }
            )

        return results

class GLMModel(LanguageModel):
    """THUDM GLM Model Wrapper --> Access through HuggingFace Model Hub

    Generation runs through a vLLM ``LLM`` engine; prompts are formatted with
    the fastchat conversation template named in the model's ``MODELS`` entry.
    """

    def __init__(self, model_name: str, args: dict):
        """
        :param model_name:  key into ``MODELS``
        :param args:        run configuration; ``args.gpu_num`` is read as an
                            attribute to cap the tensor-parallel size
        """
        super().__init__(model_name, args)
        assert MODELS[model_name]["model_class"] == "GLMModel", (
            f"Errorneous Model Instatiation for {model_name}"
        )

        self._tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=self._model_name, cache_dir=PATH_HF_CACHE
        )
        if self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = "<unk>"
            self._tokenizer.pad_token_id = 0
        self._tokenizer.padding_side = "left"

        self._device = "cuda"
        device_count = torch.cuda.device_count()
        dtype = "bfloat16" if check_bf16_support() else "float16"
        self._model = LLM(
            model=self._model_name,
            # Never request more GPUs than are visible on this machine.
            tensor_parallel_size=min(device_count, self.args.gpu_num),
            dtype=dtype,
        )
        self.conv_template = MODELS[model_name]["conv_template"]

    def get_greedy_answer(
        self, prompt_base: list, prompt_system: list, max_tokens: int
    ) -> List[Dict]:
        """
        Answer every prompt with temperature 0 (greedy search).

        :param prompt_base:     list of base prompts
        :param prompt_system:   list of system instructions, aligned with prompt_base
        :param max_tokens:      max tokens in each answer
        :return:                list of result dicts, see ``get_top_p_answer``
        """
        return self.get_top_p_answer(
            prompt_base=prompt_base,
            prompt_system=prompt_system,
            max_tokens=max_tokens,
            temperature=0,
            top_p=1.0,
        )

    def get_top_p_answer(
        self,
        prompt_base: list,
        prompt_system: list,
        max_tokens: int,
        temperature: float,
        top_p: float,
    ) -> List[Dict]:
        """
        Answer a batch of prompts using temperature / top-p sampling.

        Each prompt is the system text directly followed by the base text,
        placed in the first role of the conversation template.

        :param prompt_base:     list of base prompts
        :param prompt_system:   list of system instructions, aligned with prompt_base
        :param max_tokens:      max tokens in each answer
        :param temperature:     sampling temperature
        :param top_p:           top_p parameter
        :return:                one dict per prompt with the keys "timestamp",
                                "answer_raw" (system + base + answer text) and
                                "answer"; an empty list for an empty batch
        """
        # Nothing to generate; also guarantees the loop below runs at least
        # once, so the template's stop settings are always available.
        if not prompt_base:
            return []

        prompts = []
        for index, base in enumerate(prompt_base):
            conv = get_conv_template(self.conv_template)
            conv.append_message(conv.roles[0], f"{prompt_system[index]}{base}")
            conv.append_message(conv.roles[1], None)
            prompts.append(conv.get_prompt())

        # Stop settings come from the template (taken from the last conversation built).
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            stop=conv.stop_str,
            stop_token_ids=conv.stop_token_ids,
        )

        responses = self._model.generate(prompts, sampling_params)

        results = []
        for index, response in enumerate(responses):
            answer = response.outputs[0].text.strip()
            results.append(
                {
                    "timestamp": get_timestamp(),
                    "answer_raw": f"{prompt_system[index]}{prompt_base[index]}{answer}",
                    "answer": answer,
                }
            )

        return results