import json
import requests
import openai
import asyncio
import backoff
import time
from requests.exceptions import Timeout
from typing import Any

@backoff.on_exception(backoff.expo, openai.error.RateLimitError)
def completions_with_backoff(**kwargs):
    """Call ``openai.Completion.create``, retrying on rate-limit errors.

    The decorator re-invokes this function with exponentially growing waits
    whenever ``openai.error.RateLimitError`` is raised.

    Args:
        **kwargs: Forwarded unchanged to ``openai.Completion.create``.
    Returns:
        Whatever ``openai.Completion.create`` returns.
    """
    completion = openai.Completion.create(**kwargs)
    return completion


@backoff.on_exception(backoff.expo, openai.error.RateLimitError)
def chat_completions_with_backoff(**kwargs):
    """Call ``openai.ChatCompletion.create``, retrying on rate-limit errors.

    The decorator re-invokes this function with exponentially growing waits
    whenever ``openai.error.RateLimitError`` is raised.

    Args:
        **kwargs: Forwarded unchanged to ``openai.ChatCompletion.create``.
    Returns:
        Whatever ``openai.ChatCompletion.create`` returns.
    """
    chat_completion = openai.ChatCompletion.create(**kwargs)
    return chat_completion


async def dispatch_openai_chat_requests(
    messages_list: list[list[dict[str,Any]]],
    model: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    stop_words: list[str]
) -> list[str]:
    """Issue one ChatCompletion request per conversation and await them all.

    Every conversation in ``messages_list`` becomes its own
    ``openai.ChatCompletion.acreate`` call; the calls are awaited together
    with ``asyncio.gather``.

    Args:
        messages_list: Conversations to send, one request per entry.
        model: Name passed to the API as ``engine``.
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens to generate per request.
        top_p: Nucleus-sampling parameter.
        stop_words: Stop sequences passed as ``stop``.
    Returns:
        The gathered results, one per entry of ``messages_list``, in input
        order. An empty ``messages_list`` yields an empty list.
    """
    def start_request(conversation):
        # The model name travels as `engine`, not `model`.
        return openai.ChatCompletion.acreate(
            engine=model,
            messages=conversation,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            stop=stop_words,
        )

    # All request coroutines are created up front, then awaited as a group.
    return await asyncio.gather(*map(start_request, messages_list))

async def dispatch_openai_prompt_requests(
    messages_list: list[list[dict[str,Any]]],
    model: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    stop_words: list[str]
) -> list[str]:
    """Issue one Completion request per prompt and await them all.

    Every entry of ``messages_list`` is sent as the ``prompt`` of its own
    ``openai.Completion.acreate`` call; the calls are awaited together with
    ``asyncio.gather``.

    Args:
        messages_list: Prompts to send, one request per entry.
        model: Name passed to the API as ``model``.
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens to generate per request.
        top_p: Nucleus-sampling parameter.
        stop_words: Stop sequences passed as ``stop``.
    Returns:
        The gathered results, one per entry of ``messages_list``, in input
        order. An empty ``messages_list`` yields an empty list.
    """
    def start_request(prompt):
        # Frequency and presence penalties are pinned to zero for every call.
        return openai.Completion.acreate(
            model=model,
            prompt=prompt,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            frequency_penalty=0.0,
            presence_penalty=0.0,
            stop=stop_words,
        )

    # All request coroutines are created up front, then awaited as a group.
    return await asyncio.gather(*map(start_request, messages_list))

class LLMModel:
    """Client wrapper for querying a language model by name.

    Three access paths are offered, all targeting ``self.model_name``:

    * raw HTTP POSTs with retries (``predict``, ``call_gpt_with_url``,
      ``call_llama_with_url``);
    * a single chat request through the ``openai`` SDK (``chat_generate``);
    * concurrent batches through the SDK's async API (``batch_*`` methods).
    """

    def __init__(self, API_KEY, model_name, stop_words, max_new_tokens) -> None:
        """Store generation settings and configure the global ``openai`` module.

        Args:
            API_KEY: Key assigned to ``openai.api_key``.
            model_name: Model identifier; also selects the HTTP endpoint and
                the batch dispatch path.
            stop_words: Stop sequences forwarded to the SDK-based calls.
            max_new_tokens: Token limit forwarded to the SDK-based calls.
        """
        openai.api_key = API_KEY
        self.model_name = model_name
        self.max_new_tokens = max_new_tokens
        self.stop_words = stop_words

        # These are module-level settings: they apply to every user of
        # `openai` in the process, not only to this instance.
        openai.api_type = "azure"
        openai.api_base = "http_url"  # placeholder: replace with your own endpoint URL
        openai.api_version = "2023-05-15"

    def _endpoint_url(self):
        """Return the HTTP endpoint used for ``self.model_name``."""
        if self.model_name == 'text-davinci-003':
            return "gpt_3.5_url"  # gpt3.5 url
        if self.model_name.startswith('llama-2'):
            return 'llama2_70b_url'  # llama2 url
        return "gpt4_url"  # gpt4 url

    def _resolve_engine_name(self):
        """Return the engine name to use for batched chat requests.

        ``gpt-3.5-turbo`` maps to the ``play35`` engine (and records it on
        ``self.engine_name``). Otherwise an ``engine_name`` already present on
        the instance is used, falling back to ``self.model_name`` so that
        models without an explicit mapping do not fail on a missing attribute.
        """
        if self.model_name == 'gpt-3.5-turbo':
            self.engine_name = 'play35'
        return getattr(self, 'engine_name', self.model_name)

    def _post_chat_request(self, url, messages, max_retries, delay_between_retries,
                           n, temperature, top_p, max_tokens, log_messages):
        """POST a chat payload to ``url``, retrying until an HTTP 200 arrives.

        Args:
            url: Endpoint to POST to.
            messages: Chat messages placed in the payload's ``messages`` field.
            max_retries: Maximum number of attempts.
            delay_between_retries: Seconds to sleep between attempts.
            n, temperature, top_p, max_tokens: Sampling options for the payload.
            log_messages: ``(status_template, timeout_message, error_template)``
                printed on a non-200 status, a timeout, or another request
                error. The templates use ``{status_code}`` and ``{error}``.
        Returns:
            The first response whose status code is 200.
        Raises:
            Exception: If no attempt returned HTTP 200.
        """
        payload = json.dumps({
            "model": self.model_name,
            "messages": messages,
            "n": n,
            "temperature": temperature,
            "top_p": top_p,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "max_tokens": max_tokens,
            "stream": False,
            "stop": None,
        })
        headers = {'Content-Type': 'application/json'}
        status_template, timeout_message, error_template = log_messages

        retries = 0
        while retries < max_retries:
            try:
                response = requests.request("POST", url, headers=headers, data=payload, timeout=300)
                if response.status_code == 200:
                    return response
                print(status_template.format(status_code=response.status_code))
            except Timeout:
                # Must precede RequestException: Timeout is one of its subclasses.
                print(timeout_message)
            except requests.RequestException as e:
                print(error_template.format(error=e))
            retries += 1
            # Only wait when another attempt will follow; sleeping right
            # before raising would just delay the failure.
            if retries < max_retries:
                time.sleep(delay_between_retries)
        raise Exception(f"http api call failed, already re-try {max_retries} times。")

    def predict(self, messages, max_retries = 100, delay_between_retries=6, n = 1, temperature = 0.0, top_p = 1.0, max_tokens = 800):
        """Query the model over HTTP and return the generated text.

        Args:
            messages: Chat messages to send.
            max_retries: Maximum number of HTTP attempts.
            delay_between_retries: Seconds to sleep between attempts.
            n, temperature, top_p, max_tokens: Sampling options for the request.
        Returns:
            The ``output`` field of the JSON response for ``llama-2*`` models,
            otherwise the content of the first choice's message.
        Raises:
            Exception: If every HTTP attempt failed.
        """
        # Endpoint selection happens inside call_gpt_with_url.
        response = self.call_gpt_with_url(messages, max_retries, delay_between_retries, n, temperature, top_p, max_tokens)
        response_json = json.loads(response.text)
        if self.model_name.startswith('llama-2'):
            return response_json['output']
        return response_json['choices'][0]['message']['content']

    def call_gpt_with_url(self, messages, max_retries = 100, delay_between_retries=6, n = 1, temperature = 0.7, top_p = 0.95, max_tokens = 800):
        """POST ``messages`` to the endpoint chosen from ``self.model_name``.

        Args:
            messages: Chat messages to send.
            max_retries: Maximum number of HTTP attempts.
            delay_between_retries: Seconds to sleep between attempts.
            n, temperature, top_p, max_tokens: Sampling options for the request.
        Returns:
            The first response whose status code is 200.
        Raises:
            Exception: If no attempt returned HTTP 200.
        """
        return self._post_chat_request(
            self._endpoint_url(), messages, max_retries, delay_between_retries,
            n, temperature, top_p, max_tokens,
            log_messages=(
                "http api call failed:, status code: {status_code}. re-trying...",
                "http api call timeout, re-trying...",
                "http api call failed: {error}. re-rrying ...",
            ),
        )

    def call_llama_with_url(self, messages, max_retries = 100, delay_between_retries=6, n = 1, temperature = 0.7, top_p = 0.95, max_tokens = 800):
        """POST ``messages`` to the fixed llama2 endpoint.

        Args:
            messages: Chat messages to send.
            max_retries: Maximum number of HTTP attempts.
            delay_between_retries: Seconds to sleep between attempts.
            n, temperature, top_p, max_tokens: Sampling options for the request.
        Returns:
            The first response whose status code is 200.
        Raises:
            Exception: If no attempt returned HTTP 200.
        """
        return self._post_chat_request(
            "llama2_url", messages, max_retries, delay_between_retries,
            n, temperature, top_p, max_tokens,
            log_messages=(
                "http api call failed, already re-try, status code: {status_code}. re-trying...",
                "http api call failed, already re-try...",
                "http api call failed: {error}. re-trying...",
            ),
        )

    # used for chat-gpt and gpt-4
    def chat_generate(self, input_string, temperature = 0.0):
        """Send ``input_string`` as a single user message and return the reply.

        NOTE(review): ``__init__`` sets ``openai.api_type`` to "azure", yet this
        call passes ``model=`` whereas the batched chat path passes ``engine=``.
        Confirm the SDK accepts ``model=`` for the configured endpoint.

        Args:
            input_string: Content of the user message.
            temperature: Sampling temperature.
        Returns:
            The stripped content of the first choice's message.
        """
        response = chat_completions_with_backoff(
            model=self.model_name,
            messages=[
                {"role": "user", "content": input_string}
            ],
            max_tokens=self.max_new_tokens,
            temperature=temperature,
            top_p=1.0,
            stop=self.stop_words,
        )
        return response['choices'][0]['message']['content'].strip()

    def batch_generate_with_openai(self, messages_list, temperature = 0.0):
        """Route a batch to the prompt or chat path based on the model name.

        Args:
            messages_list: Input strings, one request per entry.
            temperature: Sampling temperature.
        Returns:
            The list produced by ``batch_prompt_generate`` or
            ``batch_chat_generate``.
        Raises:
            Exception: If ``self.model_name`` belongs to neither family.
        """
        if self.model_name in ['text-davinci-002', 'code-davinci-002', 'text-davinci-003']:
            return self.batch_prompt_generate(messages_list, temperature)
        if self.model_name in ['gpt-4', 'gpt-3.5-turbo']:
            return self.batch_chat_generate(messages_list, temperature)
        raise Exception("Model name not recognized")

    def batch_prompt_generate(self, prompt_list, temperature = 0.0):
        """Run all prompts concurrently through the Completion API.

        Args:
            prompt_list: Prompts to complete, one request per entry.
            temperature: Sampling temperature.
        Returns:
            The stripped text of the first choice of each response.
        """
        predictions = asyncio.run(
            dispatch_openai_prompt_requests(
                prompt_list, self.model_name, temperature, self.max_new_tokens, 1.0, self.stop_words
            )
        )
        return [x['choices'][0]['text'].strip() for x in predictions]

    def batch_chat_generate(self, messages_list, temperature = 0.0):
        """Run all inputs concurrently through the ChatCompletion API.

        Each input string is wrapped as a single user message. This is a
        best-effort call: if dispatching fails, the error is printed and an
        empty list is returned, so the result may be shorter than the input.

        Args:
            messages_list: Input strings, one request per entry.
            temperature: Sampling temperature.
        Returns:
            The stripped content of the first choice of each response, or an
            empty list if the batch failed.
        """
        # Resolved before the try block so a model without an explicit engine
        # mapping no longer turns every call into a swallowed AttributeError.
        engine_name = self._resolve_engine_name()

        open_ai_messages_list = [
            [{"role": "user", "content": message}] for message in messages_list
        ]

        try:
            predictions = asyncio.run(
                dispatch_openai_chat_requests(
                    open_ai_messages_list, engine_name, temperature, self.max_new_tokens, 1.0, self.stop_words
                )
            )
        except Exception as e:
            print(f"('******************Error******************): {e}")
            time.sleep(2)
            predictions = []

        return [x['choices'][0]['message']['content'].strip() for x in predictions]

