import abc
import os
import pickle
import logging
from pathlib import Path
from typing import List

from prompt_compiler.data_structs.llm_response import LLMResponse
from utils.utils import str_to_identifier

# Named (not __name__-based) logger: logging.getLogger returns the same
# object for the same name, so any module asking for "global_logger"
# shares its level and handlers.
logger = logging.getLogger(name="global_logger")

class LargeLanguageModel(abc.ABC):
    """A pretrained large language model.

    Subclasses implement :meth:`get_id` and :meth:`_sample_completions`.
    The public :meth:`sample_completions` wraps the latter with an on-disk
    pickle cache keyed on the model id, the sampling configuration and the
    prompt.
    """

    @abc.abstractmethod
    def get_id(self) -> str:
        """Get a string identifier for this LLM.

        This identifier should include sufficient information so that
        querying the same model with the same prompt and same identifier
        should yield the same result (assuming temperature 0).
        """
        raise NotImplementedError("Override me!")

    @abc.abstractmethod
    def _sample_completions(self,
                            prompt: str,
                            temperature: float,
                            stop_token: str,
                            max_tokens: int,
                            freq_penalty: float,
                            num_completions: int = 1) -> List["LLMResponse"]:
        """This is the main method that subclasses must implement.

        This helper method is called by sample_completions(), which
        caches the prompts and responses to disk. Note that the argument
        order here (stop_token before max_tokens) differs from
        sample_completions().
        """
        raise NotImplementedError("Override me!")

    @staticmethod
    def _strip_code_fences(text: str) -> str:
        """Remove Markdown code-fence markers (```BNF, ```DSL, ```) from text."""
        # Applied in this exact order; the final bare "```" pass removes
        # any fence marker the more specific patterns did not match.
        for fence in ("```BNF\n", "```\n", "```DSL\n", "```"):
            text = text.replace(fence, "")
        return text

    @staticmethod
    def _write_cache(cache_filepath: Path, completions: List["LLMResponse"]) -> None:
        """Pickle completions to cache_filepath atomically.

        The data is written to a per-process temporary file that is then
        renamed over the target, so an interrupted write can never leave a
        truncated pickle at the cache path (which would make every later
        load of that entry fail).
        """
        tmp_filepath = cache_filepath.with_name(
            f"{cache_filepath.name}.{os.getpid()}.tmp")
        try:
            with open(tmp_filepath, "wb") as f:
                pickle.dump(completions, f)
            os.replace(tmp_filepath, cache_filepath)
        finally:
            # The temp file only survives if something failed before the rename.
            if os.path.exists(tmp_filepath):
                os.remove(tmp_filepath)

    def sample_completions(self,
                           prompt: str,
                           temperature: float = 0.0,
                           max_tokens: int = 512,
                           stop_token: str = "\n\n",
                           llm_cache_dir: str = "llm_cache",
                           freq_penalty: float = 0.0,
                           seed: int = 1,
                           num_completions: int = 1,
                           disable_cache: bool = False) -> List["LLMResponse"]:
        """Sample one or more completions from a prompt.

        Higher temperatures will increase the variance in the responses.
        The seed may not be used and the results may therefore not be
        reproducible for LLMs where we only have access through an API
        that does not expose the ability to set a random seed. Responses
        are saved to disk.

        Args:
            prompt: Text sent to the model.
            temperature: Sampling temperature. At 0.0 the seed is left out
                of the cache key.
            max_tokens: Maximum number of tokens to generate.
            stop_token: Stop string forwarded to the model.
            llm_cache_dir: Directory holding the pickle cache files.
            freq_penalty: Frequency penalty forwarded to the model.
            seed: Included in the cache key when temperature is non-zero.
            num_completions: Number of completions to request.
            disable_cache: If True, always query the model even when a
                cache entry exists; the fresh response still overwrites
                the cache file.

        Returns:
            The completions, with Markdown code fences stripped from each
            ``response_text``.
        """
        llm_id = self.get_id()
        prompt_id = str_to_identifier(prompt)
        escaped_stop_token = stop_token.replace("\n", "\\n")
        # If the temperature is 0, the seed does not matter.
        if temperature == 0.0:
            config_id = f"most_likely_{num_completions}_{escaped_stop_token}_{freq_penalty}"
        else:
            config_id = f"{temperature}_{seed}_{num_completions}_{escaped_stop_token}_{freq_penalty}"
        # NOTE(review): max_tokens is not part of the cache key, so queries
        # that differ only in max_tokens share one cache entry.
        cache_filepath = Path(llm_cache_dir) / f"{llm_id}_{config_id}_{prompt_id}.pkl"
        # The parent can lie below llm_cache_dir if llm_id or prompt_id
        # contains a path separator; parents=True also creates llm_cache_dir.
        cache_filepath.parent.mkdir(parents=True, exist_ok=True)

        if disable_cache or not cache_filepath.exists():
            logger.info("Querying LLM %s with new prompt.", llm_id)
            completions = self._sample_completions(
                prompt, temperature, stop_token, max_tokens, freq_penalty,
                num_completions)
            self._write_cache(cache_filepath, completions)
            logger.info("Saved LLM response to %s.", cache_filepath)

        # Fresh and cached calls both return the on-disk copy.
        # SECURITY: pickle.load can execute arbitrary code; llm_cache_dir
        # must point at a directory whose contents you trust.
        with open(cache_filepath, "rb") as f:
            completions = pickle.load(f)
        logger.info("Loaded LLM response from %s.", cache_filepath)
        logger.info("In this query %s:", cache_filepath)
        logger.info(prompt)
        logger.info("In this response %s:", cache_filepath)
        for completion in completions:
            # object.__setattr__ bypasses normal attribute assignment —
            # presumably LLMResponse is a frozen dataclass; confirm.
            object.__setattr__(completion, "response_text",
                               self._strip_code_fences(completion.response_text))
            logger.info(completion.response_text)
        return completions

    def greedy_completion(self,
                          prompt: str,
                          max_tokens: int,
                          stop_token: str,
                          llm_cache_dir: str,
                          freq_penalty: float,
                          seed: int) -> "LLMResponse":
        """Sample a greedy (temperature 0) completion from a prompt.

        Raises:
            RuntimeError: If the model does not return exactly one completion.
        """
        responses = self.sample_completions(
            prompt,
            temperature=0.0,
            max_tokens=max_tokens,
            stop_token=stop_token,
            llm_cache_dir=llm_cache_dir,
            freq_penalty=freq_penalty,
            seed=seed,
            num_completions=1,
        )
        # Explicit check instead of `assert`, which is stripped under -O.
        if len(responses) != 1:
            raise RuntimeError(
                f"Expected exactly one greedy completion, got {len(responses)}.")
        return responses[0]

    # Not marked @abc.abstractmethod, so subclasses are not required to
    # implement it; calling it on one that does not override it raises
    # NotImplementedError.
    def _sample_next_token_with_logit_bias(self, prompt, logit_bias, temperature):
        """Sample the next token from the model with a logit bias."""
        raise NotImplementedError("Override me!")
