import functools
import logging
import collections
import json

from prompt_compiler.retriever import retrieve_fn_dict, setup_bm25
from prompt_compiler.constrain_decoding import predict_program_with_earley_correction, predict_rules_with_earley_correction
from prompt_compiler.lark_utils import rulelist2larkstr, lark2bnf, linearize_tree, linearized_tree_to_program, counter2pred, gen_min_lark, bnf2lark, check_grammar_validity, decorate_grammar
from prompt_compiler.lark_utils import remove_lf_space as remove_lf_space_overnight
from prompt_compiler.earley_parser.parser import EarleyParser

from utils.corpora_feature import Corpora_Feature
from utils.utils import check_annotated_format

from nltk.tokenize import sent_tokenize, word_tokenize
from nltk import pos_tag

logger = logging.getLogger("global_logger")

class PromptCompiler:
    def __init__(self, dataset, prompt_mode, llm, temperature, retrieve_fn, batch_size, train_examples, prompt_template, global_parser,
                 global_rules, add_rule_instruction_flag, use_linearized_tree_flag, constrain_prog_gen_flag, use_oracle_rule_flag, lazy_constrain_flag, constrain_rule_gen_flag, kg_rule_flag,
                 seed, max_tokens, llm_cache_dir, freq_penalty, add_rule_list_flag, use_action_list_flag):
        """Store the run configuration and prepare the retriever and prompt templates.

        Besides copying its arguments onto the instance, construction:

        * looks up ``retrieve_fn`` in ``retrieve_fn_dict`` and binds ``batch_size``
          to it (plus a BM25 index built from ``train_examples`` when
          ``retrieve_fn`` is ``"bm25"``);
        * selects the template set named by ``prompt_template`` (``"std"`` or
          ``"wrule"``);
        * when ``add_rule_instruction_flag`` is set, fills the two
          rule-instruction slots of the selected template from ``global_rules``;
        * reads ``data/operation_extraction.txt`` (a relative path) into
          ``self.operation_extraction_prompt``.

        Raises:
            KeyError: if ``retrieve_fn`` or ``prompt_template`` is not a known key.
            AssertionError: if ``use_linearized_tree_flag`` and
                ``constrain_prog_gen_flag`` are both set.
        """
        # Resolve the retriever first so an unknown name fails immediately.
        base_retriever = retrieve_fn_dict[retrieve_fn]

        # Data, model and grammar handles.
        self.dataset = dataset
        self.prompt_mode = prompt_mode
        self.llm = llm
        self.batch_size = batch_size
        self.train_examples = train_examples
        self.global_parser = global_parser
        self.global_rules = global_rules

        # Sampling configuration forwarded to the LLM calls.
        self.temperature = temperature
        self.seed = seed
        self.max_tokens = max_tokens
        self.llm_cache_dir = llm_cache_dir
        self.freq_penalty = freq_penalty

        # Feature switches.
        self.add_rule_instruction_flag = add_rule_instruction_flag
        self.add_rule_list_flag = add_rule_list_flag
        self.use_action_list_flag = use_action_list_flag
        self.use_linearized_tree_flag = use_linearized_tree_flag
        self.constrain_prog_gen_flag = constrain_prog_gen_flag
        self.constrain_rule_gen_flag = constrain_rule_gen_flag
        self.use_oracle_rule_flag = use_oracle_rule_flag
        self.lazy_constrain_flag = lazy_constrain_flag
        self.kg_rule_flag = kg_rule_flag

        assert not (use_linearized_tree_flag and constrain_prog_gen_flag), \
            "linearized tree is not compatible with earley correction"

        # Pre-bind the arguments every retrieval call shares.
        retriever_kwargs = {"batch_size": batch_size}
        if retrieve_fn == "bm25":
            retriever_kwargs["bm25"] = setup_bm25(train_examples)
        self.retrieve_fn = functools.partial(base_retriever, **retriever_kwargs)

        self.DELIMITER = "\nProgram:\n"

        task_instruction = ("You are an expert programmer, and you need to write a program"
                            " for the given natural language query.\n")

        # Renderers for the "std" template set.
        def std_exemplar(ex):
            return f"query: {ex.source}\nprogram:\n{ex.target}\n\n"

        def std_prediction(ex):
            return f"query: {ex.source}\nprogram:\n"

        # Renderers for the "wrule" template set; DELIMITER is read at call time.
        def wrule_exemplar(ex):
            return f"Query:\n\"{ex.source}\"\nBNF:\n```BNF\n{ex.grammar}\n```{self.DELIMITER}```DSL\n{ex.target}\n```\n"

        def wrule_rule_exemplar(ex):
            return f"Query:\n\"{ex.source}\"\nBNF:\n```BNF\n{ex.grammar}\n```\n"

        def wrule_prediction(ex):
            return f"Query:\n\"{ex.source}\"\nBNF:\n"

        def wrule_prediction_given_rule(ex):
            return f"Query:\n\"{ex.source}\"\nBNF:\n```BNF\n{ex.grammar}\n```{self.DELIMITER}"

        self.prompt_templates = {
            "std": {
                "instruction": task_instruction,
                "rule_instruction": "",
                "rule_instruction_with_rule_list": "",
                "exemplar": std_exemplar,
                "prediction": std_prediction,
            },
            "wrule": {
                "instruction": task_instruction,
                "rule_instruction": "",
                "rule_instruction_with_rule_list": "",
                "exemplar": wrule_exemplar,
                "rule_exemplar": wrule_rule_exemplar,
                "prediction": wrule_prediction,
                "prediction_given_rule": wrule_prediction_given_rule,
            },
        }

        self.prompt_template = self.prompt_templates[prompt_template]
        if add_rule_instruction_flag:
            plain_text, text_with_rules = self.__construct_rule_instruction(global_rules, dataset)
            self.prompt_template["rule_instruction"] = plain_text
            self.prompt_template["rule_instruction_with_rule_list"] = text_with_rules

        # Prompt used by the operation-extraction LLM call.
        with open("data/operation_extraction.txt") as prompt_file:
            self.operation_extraction_prompt = prompt_file.read()

    def __construct_rule_instruction(self, rules, dataset):
        """Build the rule instruction text, without and with the BNF rule list.

        Args:
            rules: rule list handed to ``rulelist2larkstr``; only used when the
                rule list is appended.
            dataset: dataset name. ``"SMCalFlow"`` gets an instruction that asks
                for a full grammar and no rule list; every other name gets the
                "choose from the following BNF rules" instruction plus the list.

        Returns:
            ``(rule_instruction, rule_instruction_with_rule_list)``. Both are the
            same text when no rule list is appended.
        """
        if dataset in ("GeoQuery", "Overnight-Blk"):
            text = "First, you should write grammar rules by choosing from the following BNF rules. Then, you should write programs that conform to your predicted rules.\nYou will be provided several demonstration examples. Your output should strictly follow the demonstration example. \n"
            append_rule_list = True
        elif dataset == "SMCalFlow":
            text = "First, you should write a grammar that contains all the necessary BNF rules. Then, you should write programs that conform to your predicted rules.\nYou will be provided several demonstration examples. Your output should strictly follow the demonstration example. \n"
            append_rule_list = False
        else:
            # Unknown datasets are treated like the rule-list datasets
            # (this variant of the text has no trailing space before the newline).
            text = "First, you should write grammar rules by choosing from the following BNF rules. Then, you should write programs that conform to your predicted rules.\nYou will be provided several demonstration examples. Your output should strictly follow the demonstration example.\n"
            append_rule_list = True

        if not append_rule_list:
            return text, text

        bnf_rules = lark2bnf(rulelist2larkstr(rules))
        return text, f"{text}\n[BEGIN RULES]\n{bnf_rules}\n[END RULES]\n\n"

    def compile(self, input_example):
        """Build the few-shot prompt(s) for ``input_example`` and query the LLM.

        Dispatches on ``self.prompt_mode``:

        * ``"std"``: a single prompt maps the query directly to a program.
        * ``"rot"``: a BNF grammar is predicted first, then a program is
          predicted given that grammar.

        Args:
            input_example: object with ``source`` and ``target`` attributes. In
                ``"rot"`` mode its ``grammar`` attribute is overwritten.

        Returns:
            ``(prompts, prediction_counter, grammar_counter)``. In ``"std"`` mode
            ``prompts`` is one string and ``grammar_counter`` is ``None``. In
            ``"rot"`` mode ``prompts`` is ``[program_prompt, rule_prompt]`` and
            ``grammar_counter`` counts the predicted Lark grammar. Any other
            ``prompt_mode`` yields ``None``.

        Raises:
            Any exception raised by retrieval, the LLM, or grammar conversion
            propagates to the caller. The only handled failure is building the
            per-example parser in ``"rot"`` mode, which falls back to the
            global parser.
        """
        if self.prompt_mode == "std":
            return self._compile_std(input_example)
        if self.prompt_mode == "rot":
            return self._compile_rot(input_example)
        # Unknown modes have always fallen through without producing a result.
        return None

    def _generation_kwargs(self):
        """Return the keyword arguments shared by every generation call."""
        return {
            "max_tokens": self.max_tokens,
            "llm_cache_dir": self.llm_cache_dir,
            "freq_penalty": self.freq_penalty,
            "seed": self.seed,
        }

    @staticmethod
    def _quote_terminal(text):
        """Wrap ``text`` as a Lark string terminal whose content is the quoted text."""
        return "\"\\\"" + text + "\\\"\""

    @staticmethod
    def _override_rules(lark_grammar, rule_prefix, replacement_line):
        """Replace every grammar line that starts with ``rule_prefix``."""
        return "\n".join(
            replacement_line if line.startswith(rule_prefix) else line
            for line in lark_grammar.split("\n")
        )

    def _apply_kg_rules(self, source, lark_grammar):
        """Override the ``action_name`` and ``object`` rules from extraction results."""
        found, action = self.__operation_extraction(source)
        if found:
            lark_grammar = self._override_rules(
                lark_grammar, "action_name", "action_name : " + self._quote_terminal(action))

        found, entities = self.__entity_extraction(
            source,
            ["reagent", "container", "device", "time", "temperature", "mass", "speed", "concentration", "volume", "length"])
        if found:
            # Join the alternatives once: rebuilding them per matching line
            # would re-join the already joined string character by character.
            alternatives = " | ".join(self._quote_terminal(entity) for entity in entities)
            lark_grammar = self._override_rules(lark_grammar, "object", "object : " + alternatives)
        return lark_grammar

    def _compile_std(self, input_example):
        """Run direct query-to-program prompting (``prompt_mode == "std"``)."""
        template = self.prompt_template
        render_exemplar = template["exemplar"]
        render_query = template["prediction"]
        rule_key = "rule_instruction_with_rule_list" if self.add_rule_list_flag else "rule_instruction"
        prompt_parts = [template["instruction"], template[rule_key]]

        for exemplar in self.retrieve_fn(input_example, self.train_examples):
            if self.use_linearized_tree_flag and not hasattr(exemplar, "linearized"):
                # Training examples are mutated in place; the marker attribute
                # stops an already linearized target from being parsed again.
                exemplar.target = linearize_tree(self.global_parser.parse(exemplar.target))
                exemplar.linearized = True
            prompt_parts.append(render_exemplar(exemplar))
        prompt_parts.append(render_query(input_example))
        _prompt = "".join(prompt_parts)

        gen_kwargs = self._generation_kwargs()
        if self.constrain_prog_gen_flag:
            prediction = predict_program_with_earley_correction(
                self.llm, _prompt, self.global_parser, **gen_kwargs)
        else:
            responses = self.llm.sample_completions(
                _prompt, self.temperature, stop_token="\n\n", **gen_kwargs)
            assert len(responses) == 1
            prediction = responses[0].response_text

            if self.use_linearized_tree_flag:
                # Recover the original program from the linearized tree.
                logger.info("prediction before linearization: " + prediction)
                if self.dataset == "Overnight-Blk":
                    prediction = linearized_tree_to_program(prediction, delimiter=" ")
                    prediction = remove_lf_space_overnight(prediction)
                else:
                    prediction = linearized_tree_to_program(prediction)

        _counter = collections.Counter([prediction])
        logger.info("Summary:" + "-" * 80)
        logger.info(f"number of unique predictions from std prompt: {len(_counter)}")
        logger.info(f"frequency distribution of new predictions: {list(_counter.values())}")

        logger.info(f"    source:\n{input_example.source}")
        logger.info(f"prediction:\n{counter2pred(_counter)}")
        logger.info(f"    target:\n{input_example.target}")
        logger.info("-" * 80)
        return _prompt, _counter, None

    def _compile_rot(self, input_example):
        """Run grammar-then-program prompting (``prompt_mode == "rot"``).

        ``constrain_rule_gen_flag`` and ``constrain_prog_gen_flag`` switch the
        respective stage from plain sampling to Earley-corrected decoding, and
        ``kg_rule_flag`` rewrites the predicted grammar with extracted
        operations and entities before it is used to build the parser.
        """
        template = self.prompt_template
        render_exemplar = template["exemplar"]
        render_rule_query = template["prediction"]
        render_prog_query = template["prediction_given_rule"]

        exemplars = self.retrieve_fn(input_example, self.train_examples)
        instruction = template["instruction"]
        rule_key = "rule_instruction_with_rule_list" if self.add_rule_list_flag else "rule_instruction"
        # Only the rule prompt may carry the full rule list.
        rule_parts = [instruction, template[rule_key]]
        prog_parts = [instruction, template["rule_instruction"]]
        for exemplar in exemplars:
            exemplar.grammar = lark2bnf(gen_min_lark(exemplar.target, self.global_parser))
            rendered = render_exemplar(exemplar)
            rule_parts.append(rendered)
            prog_parts.append(rendered)
        fewshot_prog_prompt = "".join(prog_parts)
        prompt_for_rule = "".join(rule_parts) + render_rule_query(input_example)

        gen_kwargs = self._generation_kwargs()

        # Stage 1: predict the grammar.
        if self.constrain_rule_gen_flag:
            pred_bnf_grammar = predict_rules_with_earley_correction(
                self.llm, prompt_for_rule, self.global_rules, self.DELIMITER,
                **gen_kwargs, use_action_list_flag=self.use_action_list_flag)
        else:
            response = self.llm.sample_completions(
                prompt_for_rule, self.temperature, stop_token=self.DELIMITER, **gen_kwargs)[0]
            pred_bnf_grammar = response.response_text
        pred_lark_grammar = bnf2lark(pred_bnf_grammar)

        if self.kg_rule_flag:
            pred_lark_grammar = self._apply_kg_rules(input_example.source, pred_lark_grammar)
            logger.info(f"KG correction with grammar\n{pred_lark_grammar}")

        # Stage 2: predict the program given the (uncorrected BNF) grammar.
        input_example.grammar = pred_bnf_grammar
        prompt_for_prog = fewshot_prog_prompt + render_prog_query(input_example)

        if self.constrain_prog_gen_flag:
            try:
                logger.info(f"earley correction with grammar\n{pred_lark_grammar}")
                local_parser = EarleyParser(decorate_grammar(pred_lark_grammar),
                                            start=self.global_parser.option.start)
            except Exception as e:
                # The predicted grammar could not be turned into a parser, so
                # constrain with, and show the model, the global grammar instead.
                logger.info(f"failed to create parser due to {e}, reverting to global parser")
                local_parser = self.global_parser
                input_example.grammar = lark2bnf(rulelist2larkstr(self.global_rules))
                prompt_for_prog = fewshot_prog_prompt + render_prog_query(input_example)

            pred_program = predict_program_with_earley_correction(
                self.llm, prompt_for_prog, local_parser, **gen_kwargs)
        else:
            response = self.llm.sample_completions(
                prompt_for_prog, self.temperature, stop_token="\n\n", **gen_kwargs)[0]
            pred_program = response.response_text

        used_prompts = [prompt_for_prog, prompt_for_rule]
        _pred_counter = collections.Counter([pred_program])
        _grammar_counter = collections.Counter([pred_lark_grammar])

        logger.info("Summary:" + "-" * 80)
        logger.info(f"number of unique predictions: {len(_pred_counter)}")
        logger.info(f"frequency distribution of predictions: {list(_pred_counter.values())}")

        logger.info(f"    source:\n{input_example.source}")
        logger.info(f"prediction:\n{counter2pred(_pred_counter)}")
        logger.info(f"    target:\n{input_example.target}")
        logger.info(f"   grammar:\n{counter2pred(_grammar_counter)}")
        logger.info("-" * 80)

        return used_prompts, _pred_counter, _grammar_counter

    def __operation_extraction(self, sentense):
        """Ask the LLM which operation (action) ``sentense`` describes.

        The ``------`` placeholder in ``self.operation_extraction_prompt`` is
        replaced by the sentence and the result is sent to the LLM. The answer
        is accepted only when it tokenizes to exactly one token.

        Args:
            sentense: the natural-language sentence to analyse.

        Returns:
            ``(True, token)`` with the single token lower-cased, or
            ``(False, None)`` when the answer has zero or several tokens.
        """
        prompt = self.operation_extraction_prompt.replace("------", sentense)
        response = self.llm.sample_completions(prompt, self.temperature, stop_token="\n\n",
                                               max_tokens=self.max_tokens, llm_cache_dir=self.llm_cache_dir,
                                               freq_penalty=self.freq_penalty, seed=self.seed)[0]
        tokens = word_tokenize(response.response_text)
        # Only the token count and the token text are used, so the answer is
        # not POS-tagged: tagging yields one pair per token and the tags were
        # never read.
        if len(tokens) == 1:
            return True, tokens[0].lower()
        return False, None

    def __entity_extraction(self, sentense, datatype):
        """Annotate ``sentense`` and collect the spans labelled with a wanted type.

        Args:
            sentense: the natural-language sentence to annotate.
            datatype: collection of accepted labels; also passed to
                ``Corpora_Feature`` together with ``self.llm``.

        Returns:
            ``(True, spans)`` with the spans whose label is in ``datatype`` when
            ``check_annotated_format`` accepts the annotation, otherwise
            ``(False, [])``.
        """
        annotator = Corpora_Feature(datatype, self.llm)
        annotation = annotator.data_annotate(sentense)
        well_formed, spans, labels = check_annotated_format(annotation)
        if not well_formed:
            return False, []

        selected = []
        for span, label in zip(spans, labels):
            if label in datatype:
                selected.append(span)
        return True, selected