import random

from llm_interface.chatgpt import ChatGPT

import os
import logging

from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain.prompts import PromptTemplate
from langchain_openai import OpenAIEmbeddings
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from utils.data_loader import load_llm

from symbolic_compiler.compiler import SymbolicCompiler

# Module-wide logger, looked up by the fixed name "global_logger" rather than
# by this module's own name.
# NOTE(review): presumably handlers/levels for "global_logger" are configured
# elsewhere in the project — confirm.
logger = logging.getLogger("global_logger")

class LLM_RAG:
    """Query an LLM with a templated prompt, optionally augmented by retrieval.

    The prompt is rendered from a template file and always contains one worked
    example (a question with a ``<MASK>`` placeholder plus its answer). When
    ``RAG`` is not ``None``, the training text is indexed in a persisted Chroma
    vector store and the ``context_num`` most similar chunks are added to the
    prompt as context. For the datasets "BioEng", "Medical", "Ecology" and
    "Genetics" with ``RAG == "DSL"``, each retrieved chunk is additionally run
    through ``compiler`` and the compiled form is appended to it.
    """

    def __init__(self, dataset, train_dataset, engine, RAG, embedding_model, context_num, compiler, chunk_size):
        """Set up the LLM, the embedding model, the prompt template and (if RAG) the retriever.

        Args:
            dataset: Dataset name. "Synthesis" selects the Synthesis worked
                example; every other value uses the BioEng one.
            train_dataset: Text written to disk and indexed for retrieval
                (only used when ``RAG`` is not ``None``).
            engine: Identifier passed to ``load_llm``.
            RAG: ``None`` to disable retrieval; otherwise a string naming the
                retrieval mode ("DSL" enables compilation of retrieved chunks
                for some datasets).
            embedding_model: "text-embedding-ada-002" or "bge-large-en".
            context_num: Number of chunks to retrieve per query.
            compiler: Object whose ``compile(text)`` returns a 6-tuple with the
                prediction as first element.
            chunk_size: Chunk size handed to the text splitter; also part of
                the on-disk vector store path.

        Raises:
            ValueError: If ``embedding_model`` is not one of the supported names.
        """
        self.dataset = dataset
        self.train_dataset = train_dataset

        self.engine = engine
        self.RAG = RAG
        self.context_num = context_num
        self.compiler = compiler
        self.chunk_size = chunk_size

        # Build the embeddings first so an unsupported model name is rejected
        # before the (potentially expensive) LLM is loaded.
        self.embeddings = self.__build_embeddings(embedding_model)
        self.llm = load_llm(self.engine)

        if RAG is None:
            template_path = "data/prompt_without_context.txt"
        else:
            template_path = "data/prompt_with_context.txt"
        with open(template_path, "r") as file:
            self.template = file.read()
        if RAG is not None:
            # Creates self.vectorstore and self.retriever.
            self.__prepare_train_dataset()
        self.prompt_template = PromptTemplate.from_template(self.template)

        # NOTE: these strings are prompt content sent to the model verbatim;
        # do not "correct" their spelling.
        self.examples = {
            "Synthesis": {
                "Question": '<Add vessel="reactor_2" reagent="anyhydrouse DMF" volume="2 mL" /> <HeatChill vessel="reactor_2" temp="120 °C" time="24 h" stir="True" /> <HeatChillToTemp vessel="reactor_2" temp=<MASK> stir="True" /> <Add vessel="reactor_2" reagent="ethyl acetate" volume="100 mL" /> <Transfer from_vessel="reactor_2" to_vessel="separator" volume="120 mL" />',
                "Answer": "25 °C"
            },
            "BioEng": {
                "Question": '{"action": "spin", "output": "", "speed": ["500-1,000 x g"], "time": ["10 min"]} {"action": "add", "output": "", "volume": ["4 ml XBP buffer"], "reagent": ["supernatant"]} {"action": "add", "output": "flow-through", "reagent": ["<MASK>"], "container": ["exoEasy maxi spin column"]} {"action": "centrifuge", "output": "Qiazol", "volume": ["10 ml XWP"], "time": ["5 min"], "speed": ["5,000 x g"], "device": ["spin column"]} {"action": "add", "output": "", "reagent": ["Qiazol"]}',
                "Answer": "XBP mix sample"
            }
        }
        # Every dataset other than "Synthesis" shares the BioEng example.
        example_key = "Synthesis" if self.dataset == "Synthesis" else "BioEng"
        self.example = self.examples[example_key]

    @staticmethod
    def __build_embeddings(embedding_model):
        """Return the embedding object for ``embedding_model``.

        Raises:
            ValueError: If the model name is not supported.
        """
        if embedding_model == "text-embedding-ada-002":
            return OpenAIEmbeddings(model="text-embedding-ada-002")
        if embedding_model == "bge-large-en":
            return HuggingFaceBgeEmbeddings(
                model_name="BAAI/bge-large-en",
                model_kwargs={'device': 'cpu'},
                encode_kwargs={'normalize_embeddings': True},
            )
        # The original `raise "..."` raised a plain string, which Python 3
        # turns into an unrelated TypeError; raise a real exception instead.
        raise ValueError(f"wrong embedding model: {embedding_model!r}")

    def invoke(self, query, answer):
        """Render the prompt for ``query``, send it to the LLM and return its first completion.

        Args:
            query: The question inserted into the prompt and, when retrieval is
                enabled, used as the similarity-search query.
            answer: Unused; kept so existing callers keep working.

        Returns:
            A ``(response_text, context)`` tuple. ``context`` is ``[]`` without
            retrieval, the compiled chunks in the DSL branch, and the retrieved
            chunks (newlines replaced by spaces) otherwise.
        """
        example = self.__format_example()
        if self.RAG is None:
            text = self.prompt_template.format(question=query, example=example)
            context = []
        else:
            retrieved_docs = self.retriever.invoke(query)
            # Flatten each chunk to one line so chunks stay newline-separated.
            retrieved_strs = [doc.page_content.replace("\n", " ") for doc in retrieved_docs]
            if self.dataset in ("BioEng", "Medical", "Ecology", "Genetics") and self.RAG == "DSL":
                compiled_results = []
                for passage in retrieved_strs:
                    logger.info("Compiling: %s", passage)
                    result = self.__comiple_nl2dsl(passage)
                    compiled_results.append(result)
                    logger.info("Compiled result: %s", result)
                # Each context entry is the raw chunk followed by its compiled form.
                paired = [raw + "\n" + dsl for raw, dsl in zip(retrieved_strs, compiled_results)]
                context_text = self.__format_strs(paired)
                context = compiled_results
            else:
                context_text = self.__format_strs(retrieved_strs)
                context = retrieved_strs
            text = self.prompt_template.format(context=context_text, question=query, example=example)
        logger.info("prompt:")
        logger.info(text)
        responses = self.llm.sample_completions(text)
        return responses[0].response_text, context

    def __prepare_train_dataset(self):
        """Write the training text to disk and build or reopen the Chroma index.

        Sets ``self.vectorstore`` and ``self.retriever``. The index directory is
        keyed by dataset, chunk size and (for "Synthesis" only) the RAG mode.

        NOTE: if the index directory already exists it is reused as-is, so a
        changed ``train_dataset`` is not re-indexed until that directory is
        removed.
        """
        if self.dataset == "Synthesis":
            db_path = "llm_embedding/" + self.dataset + "_" + self.RAG + "_db" + "_" + str(self.chunk_size)
        else:
            db_path = "llm_embedding/" + self.dataset + "_db" + "_" + str(self.chunk_size)

        # Create only the parent directory: db_path itself must stay absent so
        # the existence check below still detects a missing index.
        os.makedirs(os.path.dirname(db_path), exist_ok=True)

        with open(db_path + ".txt", "w") as f:
            f.write(self.train_dataset)

        if not os.path.exists(db_path):
            docs = TextLoader(db_path + ".txt").load()
            text_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n"], chunk_size=self.chunk_size, chunk_overlap=0, add_start_index=True)
            splits = text_splitter.split_documents(docs)
            self.vectorstore = Chroma.from_documents(documents=splits, embedding=self.embeddings, persist_directory=db_path)
        else:
            self.vectorstore = Chroma(embedding_function=self.embeddings, persist_directory=db_path)
        self.retriever = self.vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": self.context_num})

    def __format_example(self):
        """Render the selected worked example as a fenced question followed by its answer."""
        return "Question:\n```\n" + self.example["Question"] + "\n```\n\nAnswer: " + self.example["Answer"]

    def __format_strs(self, strs):
        """Join the given strings with newlines."""
        return "\n".join(strs)

    def __comiple_nl2dsl(self, example):
        """Compile ``example`` with ``self.compiler`` and return the prediction on one line."""
        # compile() returns a 6-tuple; only the first element is used here.
        prediction, _, _, _, _, _ = self.compiler.compile(example)
        return prediction.replace("\n", " ")