import os
import json
from typing import List, Dict, Union
from tqdm import tqdm, trange
import time
import numpy as np
import scipy
import scipy.optimize
from transformers import GPT2Tokenizer
from copy import deepcopy
import openai
import random
from query import query_wrapper


# Import-time setup. Indexing os.environ (rather than .get) means a missing
# OPENAI_API_KEY raises KeyError as soon as this module is imported.
openai.api_key = os.environ["OPENAI_API_KEY"]
# GPT-2 tokenizer used to count tokens (e.g. in get_avg_length).
GPT2TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2")


def assign_labels(ground_truth_labels, predicted_labels):
    """
    Map predicted cluster ids onto ground-truth label ids.

    ``m`` of the predicted ids (``m`` = number of ground-truth labels) are matched
    one-to-one with ground-truth labels via the Hungarian algorithm, so that the
    total number of samples whose labels agree is maximized. Any predicted id left
    unmatched (possible because there may be more predicted ids than ground-truth
    ids) is mapped to the ground-truth label it co-occurs with most often.

    Parameters
    ----------
    ground_truth_labels : sequence of int
        Non-negative integer ground-truth label of each sample.
    predicted_labels : sequence of int
        Non-negative integer predicted label of each sample. Must have the same
        length as ``ground_truth_labels``, and
        ``max(predicted_labels) >= max(ground_truth_labels)``.

    Returns
    -------
    Tuple[List[int], Dict[int, int]]
        The predicted labels translated into ground-truth ids, and the mapping
        from predicted id to ground-truth id.
    """
    assert len(ground_truth_labels) == len(predicted_labels)
    m = max(ground_truth_labels) + 1
    mp = max(predicted_labels) + 1
    assert mp >= m

    # cost_matrix[g, p] = -(number of samples with ground truth g and prediction p);
    # counts are negated so that minimizing cost maximizes agreement.
    cost_matrix = np.zeros((m, mp))
    for gt_label, pred_label in zip(ground_truth_labels, predicted_labels):
        cost_matrix[gt_label, pred_label] -= 1

    row_ind, col_ind = scipy.optimize.linear_sum_assignment(cost_matrix)
    mapping = dict(zip(col_ind.tolist(), row_ind.tolist()))

    for pred_label in range(mp):
        if pred_label not in mapping:
            # Counts are stored negated, so the ground-truth label that co-occurs
            # most often with this prediction is the *minimum* of its column.
            mapping[pred_label] = int(np.argmin(cost_matrix[:, pred_label]))
    return [mapping[p] for p in predicted_labels], mapping


def automap_cluster_to_descriptions(clustering_labels, validation_scores):
    """
    Assign each cluster a distinct description, maximizing the summed scores.

    Parameters
    ----------
    clustering_labels : sequence of int
        Non-negative cluster id of each sample.
    validation_scores : array of shape (num_samples, num_descriptions)
        Row ``i`` holds the score of every description for sample ``i``.

    Returns
    -------
    dict
        Maps each cluster id to the index of its assigned description (keys and
        values are the integers returned by ``linear_sum_assignment``).
    """
    num_samples = len(clustering_labels)
    assert num_samples == len(validation_scores)
    num_descriptions = validation_scores.shape[1]
    num_clusters = max(clustering_labels) + 1
    assert num_clusters <= num_descriptions

    # Sum the (negated) scores of every sample into its cluster's row, so the
    # minimum-cost assignment is the one with the highest total score.
    cost = np.zeros((num_clusters, num_descriptions))
    for label, scores in zip(clustering_labels, validation_scores):
        cost[label] -= scores

    rows, cols = scipy.optimize.linear_sum_assignment(cost)
    return dict(zip(rows, cols))


def get_only_groundtruth_num_descriptions(data_path):
    """
    Return the number of ground-truth classes stored in ``<data_path>/label_info.json``.

    Raises AssertionError if the loaded label information has no class names.
    """
    # Imported lazily so this module does not depend on ``data`` at import time.
    from data.data import LabelInformation

    label_info_file = os.path.join(data_path, "label_info.json")
    info = LabelInformation.load_data(label_info_file)
    assert info.class_names is not None
    return info.num_classes


# Chat message template: a fixed system turn followed by a user turn whose
# content is left as None. chat_gpt_wrapper deep-copies it and fills in the prompt.
DEFAULT_MESSAGE = [
    dict(role="system", content="You are a helpful assistant."),
    dict(role="user", content=None),
]


def chat_gpt_wrapper(**args) -> Union[None, List[str]]:
    """
    A wrapper for openai.ChatCompletion.create() that makes up to 10 attempts.

    Parameters
    ----------
    **args
        The arguments to pass to openai.ChatCompletion.create(), e.g. the model,
        temperature, etc. If ``messages`` is missing or None, ``prompt`` must be
        given instead: it becomes the user turn of a copy of ``DEFAULT_MESSAGE``
        and is removed from the arguments.

    Returns
    -------
    Union[None, List[str]]
        The message content of every returned choice, or None if all 10 attempts
        raised an exception.

    Raises
    ------
    KeyError
        If ``SUBSIDIZED_ORG`` is not set in the environment, or if neither
        ``messages`` nor ``prompt`` is given.
    """
    if args.get("messages") is None:
        args["messages"] = deepcopy(DEFAULT_MESSAGE)
        args["messages"][1]["content"] = args.pop("prompt")

    # All models use the same organization id.
    openai.organization = os.environ["SUBSIDIZED_ORG"]

    num_attempts = 10
    for attempt in range(num_attempts):
        try:
            responses = openai.ChatCompletion.create(**args)
            return [c.message.content for c in responses.choices]
        except Exception as e:
            # KeyboardInterrupt is not an Exception subclass, so Ctrl-C still
            # propagates with its original traceback.
            print(e)
            if attempt < num_attempts - 1:
                # Back off before retrying; there is no point sleeping after the last try.
                time.sleep(10)

    return None


def estimate_querying_cost(
    num_prompt_toks: int, num_completion_toks: int, model: str
) -> float:
    """
    Estimate the cost of one API request from its token counts (prices as of 2023-04-06).

    Parameters
    ----------
    num_prompt_toks : int
        Number of tokens in the prompt.
    num_completion_toks : int
        Number of tokens in the completion.
    model : str
        Model name: ``gpt-3.5-turbo``, ``gpt-4``, ``gpt-4-32k`` or any
        ``text-davinci-*`` model.

    Returns
    -------
    float
        Estimated cost (per-1K-token list prices, presumably USD).

    Raises
    ------
    ValueError
        If the model has no known price.
    """
    # (price per prompt token, price per completion token) for exactly-named models.
    per_token_prices = {
        "gpt-3.5-turbo": (0.002 / 1000, 0.002 / 1000),
        "gpt-4": (0.03 / 1000, 0.06 / 1000),
        "gpt-4-32k": (0.06 / 1000, 0.12 / 1000),
    }

    if model in per_token_prices:
        prompt_price, completion_price = per_token_prices[model]
    elif model.startswith("text-davinci-"):
        prompt_price = completion_price = 0.02 / 1000
    else:
        raise ValueError(f"Unknown model: {model}")

    return num_prompt_toks * prompt_price + num_completion_toks * completion_price


def gpt3wrapper(max_repeat=20, **arguments) -> Union[None, openai.Completion]:
    """
    A wrapper for openai.Completion.create() that makes up to ``max_repeat`` attempts.

    Parameters
    ----------
    max_repeat : int, optional
        The maximum number of attempts, by default 20.
    **arguments
        The arguments to pass to openai.Completion.create(), e.g. the prompt,
        model, temperature, etc.

    Returns
    -------
    Union[None, openai.Completion]
        The response from the API, or None if every attempt raised an exception.

    Raises
    ------
    KeyError
        If ``SUBSIDIZED_ORG`` is not set in the environment.
    """
    openai.organization = os.environ["SUBSIDIZED_ORG"]
    attempt = 0
    while attempt < max_repeat:
        try:
            return openai.Completion.create(**arguments)
        except Exception as e:
            # KeyboardInterrupt is not an Exception subclass, so Ctrl-C still
            # propagates with its original traceback.
            print(arguments["prompt"])
            print(e)
            attempt += 1
            if attempt < max_repeat:
                # Back off before retrying; skip the wait after the final attempt.
                print("now sleeping")
                time.sleep(30)
    return None


def gpt3wrapper_texts(max_repeat=20, **arguments) -> Union[None, str, List[str]]:
    """
    A wrapper for openai.Completion.create() that returns only the generated text.

    Parameters
    ----------
    max_repeat : int, optional
        The maximum number of attempts, by default 20.
    **arguments
        The arguments to pass to openai.Completion.create(), e.g. the prompt,
        model, temperature, etc.

    Returns
    -------
    Union[None, str, List[str]]
        If the prompt is a list, the text of every choice, ordered by the choice's
        ``index``. Otherwise, the text of the first choice. None if the API call
        failed.
    """
    response = gpt3wrapper(max_repeat=max_repeat, **arguments)
    if response is None:
        return None
    if isinstance(arguments["prompt"], list):
        # With batched prompts, choices are matched back to prompts via their
        # ``index`` field instead of relying on positional order. The sort is
        # stable, so a response without indices keeps its original order.
        choices = sorted(response["choices"], key=lambda r: r.get("index", 0))
        return [r["text"] for r in choices]
    return response["choices"][0]["text"]


def gpt3wrapper_texts_batch_iter(max_repeat=20, bsize=20, verbose=False, **arguments):
    """
    Lazily query the completion API in batches of ``bsize`` prompts.

    Parameters
    ----------
    max_repeat : int, optional
        The maximum number of attempts per batch, by default 20.
    bsize : int, optional
        The number of prompts sent per request, by default 20.
    verbose : bool, optional
        Whether to show a progress bar over batches, by default False.
    **arguments
        The arguments to pass to gpt3wrapper. ``prompt`` must be a list of prompts.

    Yields
    ------
    Optional[str]
        The text of each returned choice, batch by batch, ordered within a batch
        by the choice's ``index``. When a batch fails, one None is yielded per
        prompt in that batch.
    """
    openai.api_key = os.environ["OPENAI_API_KEY"]

    # make sure the prompt is a list
    prompt = arguments["prompt"]
    assert isinstance(prompt, list)

    # Split off everything except the prompt once. Deep-copying ``arguments``
    # per batch would copy the entire prompt list every time (quadratic work).
    shared_arguments = {k: v for k, v in arguments.items() if k != "prompt"}

    # batch the prompts
    num_batches = (len(prompt) - 1) // bsize + 1
    iterator = trange(num_batches) if verbose else range(num_batches)
    for i in iterator:
        arg_copy = deepcopy(shared_arguments)
        arg_copy["prompt"] = prompt[i * bsize : (i + 1) * bsize]

        # make the API call
        response = gpt3wrapper(max_repeat=max_repeat, **arg_copy)

        if response is None:
            # Keep the output aligned with the input: one None per failed prompt.
            yield from [None] * len(arg_copy["prompt"])
        else:
            # Match choices back to prompts via ``index`` (stable sort keeps the
            # original order when indices are absent).
            choices = sorted(response["choices"], key=lambda r: r.get("index", 0))
            yield from (r["text"] for r in choices)


def parse_description_responses(response: str) -> List[str]:
    """
    Parse the description responses from the proposer model.

    Each description is expected on its own line, surrounded by double quotes. The
    text between the first and last quote of each line is extracted. On the first
    line, if the line does not contain exactly two quotes, everything up to the
    last quote is taken instead. Lines without any quote are ignored, and so are
    empty extractions.

    Parameters
    ----------
    response : str
        The raw response from the proposer model.

    Returns
    -------
    List[str]
        The extracted descriptions, stripped of surrounding whitespace.
    """
    descriptions = []
    for line_id, raw_line in enumerate(response.split("\n")):
        line = raw_line.strip()
        end = line.rfind('"')
        if end == -1:
            # No quote at all: there is no quoted description on this line.
            # (Slicing with end == -1 would return the line minus its last char.)
            continue
        if line_id == 0 and line.count('"') != 2:
            # Presumably the opening quote was part of the prompt, so the first
            # line only carries the closing quote.
            start = -1
        else:
            start = line.find('"')
        description = line[start + 1 : end].strip()
        if description:
            descriptions.append(description)

    return descriptions


# Alias of the module-level GPT2TOKENIZER, so the GPT-2 tokenizer is loaded once
# instead of a second time here.
GPT2_TOKENIZER = GPT2TOKENIZER
DESCRIPTION_ESTIMATED_LENGTH = (
    20  # estimated length of a description, in tokens, for cost estimation
)


def estimate_proposer_cost(
    prompt: str, model: str, num_descriptions_per_prompt: int
) -> float:
    """
    Estimate the cost of one proposer call, using API prices as of 2023-04-06.

    The prompt length is measured with the GPT-2 tokenizer. The completion is
    assumed to contain ``num_descriptions_per_prompt`` descriptions of
    ``DESCRIPTION_ESTIMATED_LENGTH`` tokens each.

    Parameters
    ----------
    prompt : str
        The prompt to be sent to the model.
    model : str
        The model to be used (must be known to ``estimate_querying_cost``).
    num_descriptions_per_prompt : int
        The number of descriptions the model should suggest.

    Returns
    -------
    float
        The estimated cost of the call.
    """
    prompt_token_count = len(GPT2_TOKENIZER.encode(prompt))
    expected_completion_tokens = (
        num_descriptions_per_prompt * DESCRIPTION_ESTIMATED_LENGTH
    )
    return estimate_querying_cost(
        num_prompt_toks=prompt_token_count,
        num_completion_toks=expected_completion_tokens,
        model=model,
    )


def get_context_length(model: str) -> int:
    """
    Return the context length (in tokens) assumed for ``model``.

    Parameters
    ----------
    model : str
        An API model name. Any name containing ``"claude"`` is accepted.

    Returns
    -------
    int
        The context length.

    Raises
    ------
    ValueError
        If the model is not recognized.
    """
    exact_lengths = {
        "text-davinci-002": 4096,
        "text-davinci-003": 4096,
        "gpt-4": 8000,
        "gpt-4-32k": 32000,
        "gpt-3.5-turbo": 4096,
    }
    if model in exact_lengths:
        return exact_lengths[model]
    if "claude" in model:
        return 4096
    raise ValueError(f"Unknown model {model}")


def get_avg_length(texts: List[str], max_num_samples=500) -> float:
    """
    Estimate the mean GPT-2 token count of ``texts``.

    When there are more than ``max_num_samples`` texts, the mean is taken over a
    random sample of that size (drawn with the global ``random`` state).
    Otherwise every text is used.

    Parameters
    ----------
    texts : List[str]
        Texts whose token lengths are averaged.
    max_num_samples : int
        Upper bound on how many texts are tokenized.

    Returns
    -------
    float
        Mean number of tokens per text. For an empty list this is ``nan``, as
        returned by ``np.mean``.
    """
    needs_sampling = len(texts) > max_num_samples
    subset = random.sample(texts, max_num_samples) if needs_sampling else texts
    token_counts = [len(GPT2TOKENIZER.encode(text)) for text in subset]
    return np.mean(token_counts)


# Prompt template for pairwise similarity judgments, read once at import time.
# The path is relative, so it is resolved against the current working directory.
SIMILARITY_TEMPLATE_PATH = "templates/similarity_judgment.txt"
# Explicit encoding: the default is locale-dependent and may not be UTF-8.
with open(SIMILARITY_TEMPLATE_PATH, encoding="utf-8") as f:
    SIMILARITY_TEMPLATE = f.read()


def get_similarity_scores(dicts, model="claude-v1.3"):
    """
    Ask ``model`` to judge each pair described in ``dicts``.

    Each dict supplies the fields of ``SIMILARITY_TEMPLATE`` (e.g. ``text_a`` and
    ``text_b``). The filled prompts are sent through ``query_wrapper`` with 10
    processes, temperature 0 and a progress bar.

    Returns
    -------
    list of str
        The responses with surrounding whitespace stripped (presumably one per
        prompt, in order — depends on ``query_wrapper``).
    """
    filled_prompts = [SIMILARITY_TEMPLATE.format(**fields) for fields in dicts]
    raw_responses = query_wrapper(
        prompts=filled_prompts,
        model=model,
        num_processes=10,
        temperature=0.0,
        progress_bar=True,
    )
    return [response.strip() for response in raw_responses]


if __name__ == "__main__":
    # Manual smoke test: score three description pairs with the default model.
    example_pairs = [
        dict(text_a="has a genre of poem", text_b="has a poetic style"),
        dict(text_a="is sports related", text_b="is about sports"),
        dict(text_a="is a language of korean", text_b="is a type of cuisine"),
    ]
    similarity_scores = get_similarity_scores(example_pairs)
    print(similarity_scores)
