import json
import os
import random
import re
from collections import defaultdict

import numpy as np
import pandas as pd
from openai import OpenAI
from sklearn.cluster import KMeans, AgglomerativeClustering
from tqdm import tqdm

from prompt_text import prompts, GEval_prompts
from extract_examples import *
from utils import your_openai_key, create_model, map_answer_to_label, compute_cosine_similarity, set_random_seed, get_text_embeddings

# Relative locations of the input datasets and of every generated artifact.
ROOT_PATH, RESULT_PATH = "../data", "../results"

# Seed once at import time via the project helper; presumably this seeds the
# RNGs used for sampling below (random / numpy) -- see utils.set_random_seed.
set_random_seed(2024)

# few-shot for annotation
def few_shot_train_data(train_data, target_value, args):
    """Build the few-shot demonstration string used for annotating a value.

    Args:
        train_data: pandas DataFrame with columns
            (scenario, action, value, value_details, label).
        target_value: only rows whose ``value`` equals this are sampled.
        args: namespace providing ``prompt_method``, ``few_shot_num`` and
            ``dataset``.

    Returns:
        The demonstrations as one string (not a DataFrame), numbered
        "[Example 1]", "[Example 2]", ... in sampling order. At most
        ``few_shot_num`` rows are used (fewer if fewer exist). An unrecognised
        ``dataset`` yields the empty string.

    Raises:
        ValueError: if ``args.prompt_method`` is not a supported sampling
            strategy (only ``"few_shot_random"`` is implemented).
    """
    examples = train_data[train_data["value"] == target_value]
    if args.prompt_method == "few_shot_random":
        # Cap the sample size: pandas raises when asked for more rows than exist.
        selected = examples.sample(n=min(args.few_shot_num, len(examples)))
    else:
        # Previously this fell through and crashed with UnboundLocalError.
        raise ValueError(
            f"Unsupported prompt_method for few-shot sampling: {args.prompt_method!r}"
        )

    # (prompt field name, DataFrame column) pairs rendered for each dataset.
    fields_by_dataset = {
        "moral_choice": [("Scenario", "scenario"), ("Action", "action"), ("Decision", "label")],
        "denevil": [("Scenario", "scenario"), ("Decision", "label")],
        "beavertails": [("Question", "scenario"), ("Response", "action"), ("Decision", "label")],
        "value_fulcra": [("Scenario", "scenario"), ("Action", "action"), ("Decision", "label")],
    }
    fields = fields_by_dataset.get(args.dataset)
    if fields is None:
        return ""

    blocks = []
    # Number by position in the sample, not by DataFrame index label, so the
    # prompt always reads "[Example 1]", "[Example 2]", ...
    for position, (_, row) in enumerate(selected.iterrows(), start=1):
        lines = [f"[Example {position}]"]
        lines.extend(f"{name}: {row[column]}" for name, column in fields)
        blocks.append("\n".join(lines) + "\n\n")

    # Same whitespace scrub that every prompt in this file receives
    # (it also applies to runs of 4 spaces / tabs inside the data).
    return "".join(blocks).replace("    ", "").replace("\t", "")

# generate evaluation steps for G-Eval method
def geval_generate_cot(args):
    """Ask the model to write G-Eval evaluation steps for ``args.dataset``.

    The dataset's template from ``GEval_prompts`` is sent as a single prompt
    with an empty system prompt, using the ``eval_*`` sampling settings in
    ``args``. The steps are printed (as before) and now also returned, so
    callers can reuse them without scraping stdout.

    Returns:
        The ``"answer"`` text of the first response.
    """
    model = create_model(args.model_name, args)
    template = GEval_prompts[args.dataset]
    responses = model.get_top_p_answer(
        prompt_base=[template],
        prompt_system=[""],
        max_tokens=args.eval_max_tokens,
        temperature=args.eval_temp,
        top_p=args.eval_top_p,
    )
    steps = responses[0]["answer"]
    print("G-Eval steps: ", steps)
    return steps
    print("G-Eval steps: ", responses[0]["answer"])

# iterative in-context learning based on training examples
class AllureCalibration:
    """Few-shot demonstration selection driven by training errors.

    For each target value, ``filter_error_examples`` records the training rows
    the model misclassifies with the vanilla prompt; ``incontext_learning_samples``
    then builds demonstrations from those rows, topped up with random rows.

    Rows are stored as plain lists in DataFrame column order and unpacked as
    (scenario, action, value, value_details, label), so ``train_data`` is
    expected to have exactly those five columns in that order.
    """

    def __init__(self, args, train_data):
        self.args = args
        self.train_data = train_data
        self.example_memory = {}  # target value -> list of stored error rows
        self.model = create_model(args.model_name, args)

        allure_dir = f"{RESULT_PATH}/{args.dataset}/allure"
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(allure_dir, exist_ok=True)

        # Reuse error examples from a previous run when a memory file exists.
        memory_path = f"{allure_dir}/allure_{args.train_split}.json"
        if os.path.exists(memory_path):
            with open(memory_path, "r") as f:
                self.example_memory = json.load(f)

    def filter_error_examples(self, target_value):
        """Collect the training rows of ``target_value`` the model gets wrong.

        Each row (in shuffled order) is classified with the dataset's vanilla
        prompt; rows whose mapped prediction differs from the mapped label are
        stored in ``self.example_memory[target_value]`` (replacing any previous
        entry). Returns a copy of that list.
        """
        template = prompts[self.args.dataset]["vanilla_prompt"]

        print(f"Filtering error examples for value {target_value}...")
        errors = []
        self.example_memory[target_value] = errors
        train_examples = self.train_data[self.train_data["value"] == target_value]
        train_examples = train_examples.sample(frac=1)  # shuffle

        for _, row in train_examples.iterrows():
            prompt = template.format(value=target_value, scenario=row["scenario"], action=row["action"]).replace("    ", "").replace("\t", "")
            responses = self.model.get_top_p_answer(
                prompt_base=[prompt],
                prompt_system=[""],
                max_tokens=self.args.eval_max_tokens,
                temperature=self.args.eval_temp,
                top_p=self.args.eval_top_p,
            )
            prediction = map_answer_to_label([responses[0]["answer"]])
            ground_truth = map_answer_to_label([row["label"]])
            if prediction != ground_truth:
                errors.append(row.tolist())

        print(f"#number of error examples for value {target_value}: {len(errors)}")
        return errors[:]

    def incontext_learning_samples(self, target_value):
        """Return up to ``args.few_shot_num`` demonstrations for ``target_value``.

        Stored error rows are sampled first. If there are too few, distinct
        random training rows of the same value fill the gap; when the value
        simply has fewer distinct rows than requested, fewer demonstrations are
        returned. (The previous re-sampling loop never terminated in that case.)
        A value with no stored errors falls back to random rows only.
        """
        few_shot_num = self.args.few_shot_num
        examples = self.example_memory.get(target_value, [])
        selected = random.sample(examples, min(few_shot_num, len(examples)))

        if len(selected) < few_shot_num:
            pool = self.train_data[self.train_data["value"] == target_value]
            candidates = [row.tolist() for _, row in pool.iterrows()]
            random.shuffle(candidates)
            for candidate in candidates:
                if len(selected) >= few_shot_num:
                    break
                if candidate not in selected:
                    selected.append(candidate)

        return self._format_examples(selected)

    def _format_examples(self, rows):
        """Render stored rows as numbered demonstrations for the current dataset."""
        example_str = ""
        for idx, (scenario, action, _, _, label) in enumerate(rows):
            if self.args.dataset == "moral_choice":
                example_str += (
                    f"[Example {idx+1}]\n"
                    f"Scenario: {scenario}\n"
                    f"Action: {action}\n"
                    f"Decision: {label}\n\n"
                )
            elif self.args.dataset in ("beavertails", "value_fulcra"):
                example_str += (
                    f"[Example {idx+1}]\n"
                    f"Question: {scenario}\n"
                    f"Response: {action}\n"
                    f"Decision: {label}\n\n"
                )
            elif self.args.dataset == "denevil":
                example_str += (
                    f"[Example {idx+1}]\n"
                    f"Scenario: {scenario}\n"
                    f"Decision: {label}\n\n"
                )
        # Same whitespace scrub applied to every prompt in this file.
        return example_str.replace("    ", "").replace("\t", "")

# calibration of prompt based on examples
class PromptCalibration:
    """Calibrate the textual definition of one value against labelled examples.

    Workflow (each step stores its result on ``self``):
      1. ``calibrate_draft``: ask the model to rewrite the value definition
         from random few-shot samples (Monte Carlo over ``few_shot_list``).
      2. ``calibrate_revisit``: score every draft on the training examples,
         collect the examples drafts get wrong, and ask the model to revise
         each draft using them.
      3. ``calibrate_apply``: score the revised definitions and return the best.
    """

    # At most this many hard (misaligned) examples are kept for revisiting.
    _MAX_MISALIGNED = 20

    def __init__(self, args, train_data, target_value):
        self.args = args
        self.train_data = train_data
        self.target_value = target_value
        self.model = create_model(args.model_name, args)
        print("calibrating for value: ", target_value)

        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(f"{RESULT_PATH}/{args.dataset}/calibration", exist_ok=True)

    def _query(self, prompt):
        """Send one prompt (empty system prompt) and return the first response dict."""
        responses = self.model.get_top_p_answer(
            prompt_base=[prompt],
            prompt_system=[""],
            max_tokens=self.args.eval_max_tokens,
            temperature=self.args.eval_temp,
            top_p=self.args.eval_top_p,
        )
        return responses[0]

    def _format_examples(self, rows):
        """Render DataFrame rows as demonstrations numbered 1..k by position.

        Numbering uses the position in ``rows`` rather than the DataFrame index
        label, which for a filtered frame could be arbitrary (e.g. "Example 4731").
        """
        example_str = ""
        for position, (_, row) in enumerate(rows.iterrows(), start=1):
            if self.args.dataset == "moral_choice":
                example_str += (
                    f"[Example {position}]\n"
                    f"Scenario: {row['scenario']}\n"
                    f"Action: {row['action']}\n"
                    f"Decision: {row['label']}\n\n"
                )
            elif self.args.dataset in ("beavertails", "value_fulcra"):
                example_str += (
                    f"[Example {position}]\n"
                    f"Question: {row['scenario']}\n"
                    f"Response: {row['action']}\n"
                    f"Decision: {row['label']}\n\n"
                )
            elif self.args.dataset == "denevil":
                example_str += (
                    f"[Example {position}]\n"
                    f"Scenario: {row['scenario']}\n"
                    f"Decision: {row['label']}\n\n"
                )
        return example_str

    @staticmethod
    def _select_misaligned(train_examples, marks, num_drafts, limit):
        """Return up to ``limit`` rows that at least one draft got wrong.

        ``marks[i]`` is the number of drafts that classified the i-th row of
        ``train_examples`` correctly (positional). Rows are returned hardest
        first (lowest mark) with their mark in a ``mark`` column.
        """
        if not marks:
            return train_examples.iloc[0:0]
        # assign() attaches marks by position. The previous pd.concat(axis=1)
        # aligned a fresh RangeIndex against the filtered (non-contiguous) index,
        # pairing marks with the wrong rows and producing NaN rows.
        scored = train_examples.assign(mark=marks)
        misaligned = scored[scored["mark"] < num_drafts]
        return misaligned.sort_values(by="mark", ascending=True, kind="stable").head(limit)

    def calibrate_draft(self):
        """Generate candidate value definitions into ``self.calibration_draft``.

        For every ``few_shot_num`` in ``args.few_shot_list`` and every Monte Carlo
        trial, random examples of the value are shown to the model; the text after
        the first ":" of its answer is kept as a draft. Duplicates are dropped.

        Raises:
            ValueError: if there are no training examples for the value.
        """
        self.calibration_draft = set()
        train_examples = self.train_data[self.train_data["value"] == self.target_value]
        print(f"#number of examples for calibration: {len(train_examples)}")
        if train_examples.empty:
            raise ValueError(f"No training examples for value {self.target_value!r}")

        for few_shot_num in self.args.few_shot_list:
            # Cap the sample size: pandas raises when asked for more rows than exist.
            sample_size = min(few_shot_num, len(train_examples))
            for _ in range(self.args.monte_carlo_trial):
                selected = train_examples.sample(n=sample_size)
                first = selected.iloc[0]
                self.target_value_details = first["value"] if self.args.dataset == "moral_choice" else first["value_details"]
                current_prompt = prompts[self.args.dataset]["prompt_calibration_draft"].format(
                    value=self.target_value_details, examples=self._format_examples(selected)
                ).replace("    ", "").replace("\t", "")
                refined_prompt = self._query(current_prompt)["answer"].split(":", 1)[-1].strip()
                self.calibration_draft.add(refined_prompt)

        self.calibration_draft = list(self.calibration_draft)
        for idx, draft in enumerate(self.calibration_draft):
            print(f"Draft {idx}: {draft}")

    def compute_calibration_score(self, refined_value):
        """Classify every training example of the value with ``refined_value``.

        Returns a DataFrame (one row per example, merged with the model's
        response dict) and writes it to
        ``{RESULT_PATH}/{dataset}/calibration/{target_value}_calibrate_draft.csv``
        (overwritten on every call).
        """
        train_examples = self.train_data[self.train_data["value"] == self.target_value]
        template = prompts[self.args.dataset]["vanilla_prompt"]
        results = []

        for idx, row in train_examples.iterrows():
            prompt = template.format(value=refined_value, scenario=row["scenario"], action=row["action"]).replace("    ", "").replace("\t", "")
            response = self._query(prompt)
            result_base = {
                "index": idx,
                "context": row["scenario"],
                "action": row["action"],
                "value": refined_value,
                "prompt": prompt,
                "label": row["label"],
                "model_id": self.args.model_name,
                "eval_repeats": 1,
            }
            results.append({**result_base, **response})

        results_df = pd.DataFrame(results)
        results_df.to_csv(f"{RESULT_PATH}/{self.args.dataset}/calibration/{self.target_value}_calibrate_draft.csv")
        return results_df

    def calculate_acc(self, results):
        """Return (per-example 0/1 correctness list, accuracy) for ``results``.

        An empty ``results`` yields ``([], 0.0)`` instead of dividing by zero.
        """
        if len(results) == 0:
            print(f"Accuracy on {self.target_value}: no examples")
            return [], 0.0
        prediction = map_answer_to_label(results["answer"].tolist())
        ground_truth = map_answer_to_label(results["label"].tolist())
        labels = [1 if p == g else 0 for p, g in zip(prediction, ground_truth)]
        accuracy = sum(labels) / len(labels)
        print(f"Accuracy on {self.target_value}: {accuracy}")
        return labels, accuracy

    def calibrate_revisit(self):
        """Revise the drafts using the training examples they get wrong.

        Every draft is scored on the value's training examples. An example's
        mark is the number of drafts that classified it correctly; examples at
        least one draft got wrong are kept (lowest marks first, capped at
        ``_MAX_MISALIGNED``). Then, for each draft (best accuracy first),
        ``monte_carlo_trial`` revision rounds are chained, each showing
        ``few_shot_list[0]`` random hard examples (fewer if not enough exist)
        and feeding the previous round's answer back in. The last revision of
        each draft goes into ``self.calibration_revisit``.

        If no example is misaligned, one random draft is kept unchanged.
        """
        misaligned_marks = []
        self.top_calibration_draft = []
        train_examples = self.train_data[self.train_data["value"] == self.target_value]

        for refined_value in self.calibration_draft:
            results = self.compute_calibration_score(refined_value)
            labels, accuracy = self.calculate_acc(results)
            # Larger mark means less misaligned.
            if not misaligned_marks:
                misaligned_marks = labels
            else:
                misaligned_marks = [m + l for m, l in zip(misaligned_marks, labels)]
            self.top_calibration_draft.append((refined_value, accuracy))
        self.top_calibration_draft.sort(key=lambda item: item[1], reverse=True)

        misaligned_examples = self._select_misaligned(
            train_examples, misaligned_marks, len(self.calibration_draft), self._MAX_MISALIGNED
        )

        if len(misaligned_examples) == 0:
            print("No misaligned examples found")
            self.calibration_revisit = random.sample(self.calibration_draft[:], 1)
            return

        sample_size = min(self.args.few_shot_list[0], len(misaligned_examples))
        self.calibration_revisit = set()
        for refined_value, _ in self.top_calibration_draft:
            for _trial in range(self.args.monte_carlo_trial):
                selected = misaligned_examples.sample(n=sample_size)
                current_prompt = prompts[self.args.dataset]["prompt_calibration_revisit"].format(
                    value=refined_value, examples=self._format_examples(selected)
                ).replace("    ", "").replace("\t", "")
                # Chain: the next round revises this round's answer.
                refined_value = self._query(current_prompt)["answer"]
            self.calibration_revisit.add(refined_value)

        self.calibration_revisit = list(self.calibration_revisit)
        for idx, revisit in enumerate(self.calibration_revisit):
            print(f"Revisit {idx}: {revisit}")

    def calibrate_apply(self):
        """Score each revised definition and return the most accurate one."""
        self.calibration_results = []
        for refined_value in self.calibration_revisit:
            results = self.compute_calibration_score(refined_value)
            _, accuracy = self.calculate_acc(results)
            self.calibration_results.append((refined_value, accuracy))

        self.calibration_results.sort(key=lambda item: item[1], reverse=True)
        for idx, result in enumerate(self.calibration_results):
            print(f"Apply {idx}: {result}")
        return self.calibration_results[0][0]

# extract key concepts for value evaluation based on training examples
class KeyConceptExtraction:
    def __init__(self, args, train_data, test_data, target_value, key_concepts_init=None, key_concepts_extract=None, key_concepts_all=None):
        """Set up extraction for ``target_value``.

        Stores the run configuration, data splits and optional precomputed
        concept collections, creates the model, and picks the dataset's
        demonstration tables (``extract_example`` / ``inferred_example``).
        Datasets other than beavertails, value_fulcra and denevil leave those
        two attributes unset.
        """
        self.args, self.train_data, self.test_data = args, train_data, test_data
        self.target_value = target_value
        self.model = create_model(args.model_name, args)

        # Optional concept collections supplied by the caller.
        self.key_concepts_init = key_concepts_init
        self.key_concepts_extract = key_concepts_extract
        self.key_concepts_all = key_concepts_all
        self.api_key = your_openai_key

        dataset = self.args.dataset
        if dataset == "beavertails":
            extract, infer = beavertails_extract_examples, beavertails_infer_examples
        elif dataset == "value_fulcra":
            extract, infer = value_fulcra_extract_examples, value_fulcra_infer_examples
        elif dataset == "denevil":
            extract, infer = denevil_extract_examples, denevil_infer_examples
        else:
            return
        self.extract_example = extract
        self.inferred_example = infer
    
    def extract_key_concepts_separate(self):
        """Extract key concepts for ``self.target_value`` from its training rows.

        The value's rows are shuffled, grouped by ``self.cluster_train_data``,
        and each cluster is sent to the model in batches of ``args.batch_size``
        (the last batch of a cluster absorbs the remainder). Answer lines that
        start with a digit are kept as concepts; a concept is attributed to
        batch example N when it mentions ``example N`` as a whole number.

        Returns:
            (key_concepts, examples_2_concepts): the concepts with any
            ``(example ...)`` suffix removed, and a dict mapping a formatted
            example string to the concepts attributed to it.
        """
        train_examples = self.train_data[self.train_data["value"] == self.target_value]
        train_examples = train_examples.sample(frac=1)  # shuffle the examples
        value_details = train_examples["value_details"].tolist()[0] if self.args.dataset != "moral_choice" else train_examples["value"].tolist()[0]
        print(f"#number of training samples for value {self.target_value}: {len(train_examples)}")

        self.key_concepts = []
        self.examples_2_concepts = defaultdict(list)
        extract_example = self.extract_example[self.target_value]  # demonstrations in the prompt

        cluster_labels = self.cluster_train_data(train_examples)
        print("#clusters", max(cluster_labels)+1, "cluster labels: ", cluster_labels, "\n\n")

        batch_size = self.args.batch_size
        for cluster in range(max(cluster_labels) + 1):
            cluster_examples = train_examples[cluster_labels == cluster]
            # Split large clusters into batches; the final batch takes the remainder.
            num_batches = max(len(cluster_examples) // batch_size, 1)
            for batch_id in range(num_batches):
                start = batch_id * batch_size
                stop = None if batch_id == num_batches - 1 else start + batch_size
                batch_examples = cluster_examples[start:stop]

                example_str = ""
                for idx, (_, row) in enumerate(batch_examples.iterrows()):
                    if self.args.dataset == "beavertails":
                        example_str += (
                            f"[Example {idx+1}]\n"
                            f"Question: {row['scenario']}\n"
                            f"Response: {row['action']}\n"
                            f"Decision: {row['label']}\n\n"
                        )
                    elif self.args.dataset == "value_fulcra":
                        example_str += (
                            f"[Example {idx+1}]\n"
                            f"Scenario: {row['scenario']}\n"
                            f"Action: {row['action']}\n"
                            f"Decision: {row['label']}\n\n"
                        )
                    elif self.args.dataset == "denevil":
                        example_str += (
                            f"[Example {idx+1}]\n"
                            f"Scenario: {row['scenario']}\n"
                            f"Decision: {row['label']}\n\n"
                        )
                prompt = prompts[self.args.dataset]["key_concepts_extraction_separate"].format(value=value_details, value_simple=self.target_value, examples=example_str, extract_example=extract_example).replace("    ", "").replace("\t", "")

                responses = self.model.get_top_p_answer(
                    prompt_base=[prompt],
                    prompt_system=[""],
                    max_tokens=self.args.eval_max_tokens,
                    temperature=self.args.eval_temp,
                    top_p=self.args.eval_top_p,
                )
                # Only keep lines starting with an index number.
                batch_key_concepts = [line for line in responses[0]["answer"].split("\n") if line and line[0].isdigit()]
                self.key_concepts.extend(batch_key_concepts)

                example_concepts = [[] for _ in range(len(batch_examples))]
                for concept in batch_key_concepts:
                    for idx in range(len(batch_examples)):
                        # \b keeps "example 1" from also matching "example 12".
                        if re.search(rf"example {idx + 1}\b", concept):
                            example_concepts[idx].append(concept.split("(example")[0].strip())
                for idx in range(len(batch_examples)):
                    row = batch_examples.iloc[idx]
                    example = "Scenario: {}\tAction: {}\tValue: {}\tDecision: {}".format(row["scenario"], row["action"], row["value"], row["label"])
                    self.examples_2_concepts[example] = example_concepts[idx][:]

        self.key_concepts = [x.split("(example")[0].strip() for x in self.key_concepts]
        return self.key_concepts, self.examples_2_concepts
    
    # in the above extraction process, some examples corresponds to no key concepts, we need to re-extract key concepts for these examples
    def extract_key_concepts_individual(self, row):
        """Re-run key-concept extraction on a single example ``row``.

        Used for examples that received no concepts during the batched
        extraction. The row is rendered as "[Example 1]" with the dataset's
        fields and sent with the same extraction prompt.

        Returns:
            The answer lines that start with a digit, each with any
            ``(example ...)`` suffix removed.
        """
        value = row["value"]
        dataset = self.args.dataset
        value_details = row["value"] if dataset == "moral_choice" else row["value_details"]
        extract_example = self.extract_example[value]

        if dataset == "beavertails":
            body = f"Question: {row['scenario']}\nResponse: {row['action']}\n"
        elif dataset == "value_fulcra":
            body = f"Scenario: {row['scenario']}\nAction: {row['action']}\n"
        elif dataset == "denevil":
            body = f"Scenario: {row['scenario']}\n"
        else:
            body = None
        example_str = "" if body is None else f"[Example 1]\n{body}Decision: {row['label']}\n\n"

        template = prompts[dataset]["key_concepts_extraction_separate"]
        prompt = template.format(
            value=value_details,
            value_simple=value,
            examples=example_str,
            extract_example=extract_example,
        ).replace("    ", "").replace("\t", "")
        answer = self.model.get_top_p_answer(
            prompt_base=[prompt],
            prompt_system=[""],
            max_tokens=self.args.eval_max_tokens,
            temperature=self.args.eval_temp,
            top_p=self.args.eval_top_p,
        )[0]["answer"]
        return [
            line.split("(example")[0].strip()
            for line in answer.split("\n")
            if line and line[0].isdigit()
        ]

    def encode_key_concepts(self, key_concepts):
        """Embed every key concept with OpenAI ``text-embedding-3-small``.

        Args:
            key_concepts: mapping value -> list of concept strings.

        Returns:
            Mapping value -> list of {"concept", "embedding"} dicts in input
            order. When ``args.with_result == "False"`` the text from the first
            "-->" onward is dropped before embedding, and the shortened text is
            what gets stored as "concept".
        """
        client = OpenAI(api_key=self.api_key)
        embeddings_by_value = {}
        for value in key_concepts:
            print("Encoding key concepts for value: ", value)
            encoded = []
            embeddings_by_value[value] = encoded
            for concept in key_concepts[value]:
                text = concept.split("-->", 1)[0].strip() if self.args.with_result == "False" else concept
                reply = client.embeddings.create(input=[text], model="text-embedding-3-small")
                encoded.append({"concept": text, "embedding": reply.data[0].embedding})
        return embeddings_by_value

    def integrate_key_concepts(self, key_concepts_embeddings):
        """Merge near-duplicate key concepts per value by clustering embeddings.

        For each value, concepts are grouped with complete-linkage agglomerative
        clustering on cosine distance (threshold 0.4 when
        ``args.with_result == "False"``, otherwise 0.3). Each cluster is reduced
        to one representative:
          * one member: that member;
          * two members: the longer concept text (a tie picks the second);
          * three or more: the member with the highest mean cosine similarity
            to the cluster (from ``compute_cosine_similarity``).
        Representatives are then renumbered "1. ...", "2. ...".

        Args:
            key_concepts_embeddings: value -> list of {"concept", "embedding"}
                dicts (as produced by ``encode_key_concepts``).

        Returns:
            (integrated, concept_index): ``integrated`` maps value -> list of
            representative {"concept", "embedding"} dicts; ``concept_index`` maps
            value -> {original concept text with its "N. " prefix removed:
            position of its representative in ``integrated[value]``}.
        """
        integrate_key_concepts_embedding = {}
        concept_integrate_dict = {}
        # With the result appended (text after "-->") concepts are more similar,
        # so a tighter threshold is used.
        distance_threshold = 0.4 if self.args.with_result == "False" else 0.3

        for value in key_concepts_embeddings:
            print("Integrating key concepts for value: ", value)
            concepts = [x["concept"] for x in key_concepts_embeddings[value]]
            embeddings = [x["embedding"] for x in key_concepts_embeddings[value]]
            integrate_key_concepts_embedding[value] = []
            concept_integrate_dict[value] = {}

            if not concepts:
                print("#raw concepts: 0, #clusters: 0\n")
                continue

            X = np.array(embeddings)
            if len(concepts) == 1:
                # AgglomerativeClustering needs at least two samples.
                labels = np.zeros(1, dtype=int)
            else:
                clusterer = AgglomerativeClustering(metric="cosine", linkage="complete", distance_threshold=distance_threshold, n_clusters=None)
                labels = clusterer.fit(X).labels_
            print(f"#raw concepts: {len(concepts)}, #clusters: {max(labels)+1}\n")

            cosine_sim_matrix = compute_cosine_similarity(X, X)  # [sample_num, sample_num]
            for label in range(max(labels) + 1):
                member_mask = labels == label
                concept_index = np.flatnonzero(member_mask)
                cluster_concepts = [concepts[i] for i in concept_index]
                cluster_embeddings = [embeddings[i] for i in concept_index]

                if len(cluster_concepts) == 1:
                    rep = 0
                elif len(cluster_concepts) == 2:
                    # NOTE(review): this keeps the LONGER text, while an earlier
                    # comment said shorter concepts "could be more general".
                    # Behavior kept as-is; confirm which was intended.
                    rep = 0 if len(cluster_concepts[0]) > len(cluster_concepts[1]) else 1
                else:
                    # Member most similar, on average, to the rest of the cluster.
                    cluster_sim_matrix = cosine_sim_matrix[member_mask][:, member_mask]
                    rep = int(np.argmax(np.mean(cluster_sim_matrix, axis=0)))
                    # An LLM-based merge (prompt "key_concepts_integrate") was
                    # sketched here previously; this heuristic is used instead.
                integrate_key_concepts_embedding[value].append({"concept": cluster_concepts[rep], "embedding": cluster_embeddings[rep]})

                representative_idx = len(integrate_key_concepts_embedding[value]) - 1
                for concept in cluster_concepts:
                    concept_integrate_dict[value][concept.split(". ", 1)[-1]] = representative_idx

            # Renumber the representatives as "1. ...", "2. ...".
            for cidx, item in enumerate(integrate_key_concepts_embedding[value], start=1):
                item["concept"] = f"{cidx}. {item['concept'].split('. ', 1)[-1]}"
        return integrate_key_concepts_embedding, concept_integrate_dict
    
    def infer_key_concepts(self, inferred_data, store_path, inferred_concepts=None, train=False):
        """Ask the model which key concepts apply to each row of ``inferred_data``.

        Resumable: the first ``len(inferred_concepts)`` rows (by position) are
        treated as already done and skipped. Unless ``train`` is true, only the
        first ``args.eval_num_samples`` rows are processed. Results are
        checkpointed as JSON to ``store_path`` every 30 rows and once at the end.

        Args:
            inferred_data: DataFrame with scenario, action, value and
                value_details columns.
            store_path: JSON checkpoint path.
            inferred_concepts: previous results to resume from. Defaults to a
                fresh list. The old ``[]`` default was one list shared by every
                call, so a second call silently skipped rows.
            train: process every row when true.

        Returns:
            One list of numbered concept lines per processed row.
        """
        self.inferred_concepts = [] if inferred_concepts is None else inferred_concepts
        rows = tqdm(inferred_data.iterrows(), total=len(inferred_data), desc="infer key concepts")
        # Positions (not index labels) drive resume/limit, matching the
        # positional test_data.iloc lookups in the mapping steps.
        for position, (_, row) in enumerate(rows):
            if position < len(self.inferred_concepts):
                continue
            if position >= self.args.eval_num_samples and not train:
                break
            value = row["value"]
            value_details = row["value_details"]
            inferred_example = self.inferred_example[value]

            prompt = prompts[self.args.dataset]["key_concepts_inference"].format(value=value_details, value_simple=value, infer_example=inferred_example, scenario=row["scenario"], action=row["action"]).replace("    ", "").replace("\t", "")
            responses = self.model.get_top_p_answer(
                prompt_base=[prompt],
                prompt_system=[""],
                max_tokens=self.args.eval_max_tokens,
                temperature=self.args.eval_temp,
                top_p=self.args.eval_top_p,
            )
            one_inferred_concepts = [line for line in responses[0]["answer"].split("\n") if line and line[0].isdigit()]
            self.inferred_concepts.append(one_inferred_concepts)
            if position % 30 == 0:
                with open(store_path, "w") as f:
                    json.dump(self.inferred_concepts, f)

        # Final checkpoint so rows after the last periodic save are not lost.
        with open(store_path, "w") as f:
            json.dump(self.inferred_concepts, f)
        return self.inferred_concepts
    
    def embed_mapping_key_concepts(self, key_concepts_embed, inferred_concepts, store_path, mapped_concepts=None, train=False):
        """Map each inferred concept to its most similar key concepts.

        For entry ``idx`` of ``inferred_concepts``, the value is read from
        ``self.test_data.iloc[idx]``. Each non-empty inferred concept (with the
        text from "-->" onward dropped when ``args.with_result == "False"``) is
        embedded with OpenAI ``text-embedding-3-small``. The
        ``args.top_concept_num`` key concepts with the highest cosine similarity
        are then collected.

        Each processed entry appends ``[inferred_embeds, mapped_concepts]`` to
        ``self.mapped_concepts``. Mapped concepts are de-duplicated in first-seen
        order and renumbered "1. ...". Existing entries are skipped (resume).
        Unless ``train`` is true, at most ``args.eval_num_samples`` entries are
        mapped, consistent with ``infer_key_concepts``. Results are
        checkpointed as JSON to ``store_path`` every 10 entries and at the end.

        ``mapped_concepts`` defaults to a fresh list; the old ``[]`` default was
        one list shared by every call.
        """
        self.mapped_concepts = [] if mapped_concepts is None else mapped_concepts
        client = OpenAI(api_key=self.api_key)

        def renumber():
            # Rewrite every entry's mapped concepts in place as "1. ...", "2. ...".
            for _, concepts in self.mapped_concepts:
                for cidx, concept in enumerate(concepts):
                    if not concept.startswith(f"{cidx + 1}. "):
                        concepts[cidx] = f"{cidx + 1}. {concept.split('. ', 1)[-1]}"

        for idx, one_inferred_concept in enumerate(tqdm(inferred_concepts, desc="embed mapping key concepts")):
            if idx < len(self.mapped_concepts):
                continue
            # Checked before processing (previously after), so the limit is not
            # exceeded by one.
            if idx >= self.args.eval_num_samples and not train:
                break
            value = self.test_data.iloc[idx]["value"]
            value_concepts_embed = [concept["embedding"] for concept in key_concepts_embed[value]]
            one_mapped_concept = []
            one_inferred_embeds = []
            for concept in one_inferred_concept:
                if concept and self.args.with_result == "False":
                    concept = concept.split("-->", 1)[0].strip()
                if not concept:
                    continue
                embed = client.embeddings.create(input=[concept], model="text-embedding-3-small").data[0].embedding
                similarity = compute_cosine_similarity(embed, value_concepts_embed)
                top_indices = similarity.argsort()[-self.args.top_concept_num:][::-1]
                most_relevant_concepts = [key_concepts_embed[value][i]["concept"] for i in top_indices]
                one_inferred_embeds.append({"concept": concept, "embedding": embed, "most_relevant_concepts": most_relevant_concepts})
                one_mapped_concept.extend(most_relevant_concepts)
            # dict.fromkeys de-duplicates while keeping first-seen order; set
            # order varied between runs (string hash randomization).
            one_mapped_concept = list(dict.fromkeys(one_mapped_concept))
            self.mapped_concepts.append([one_inferred_embeds, one_mapped_concept])

            if idx % 10 == 0:
                renumber()
                with open(store_path, "w") as f:
                    json.dump(self.mapped_concepts, f)

        # Renumber and write a final checkpoint.
        renumber()
        with open(store_path, "w") as f:
            json.dump(self.mapped_concepts, f)
        return self.mapped_concepts

    ## when finding similar concepts, map to them, otherwise, keep the inferred concepts
    def mixed_mapping_key_concepts(self, key_concepts_embed, embed_mapped_concepts, store_path, mapped_concepts=None, train=False, similarity_threshold=0.55):
        """Map inferred concepts to similar key concepts, keeping those with no match.

        For each example, every inferred concept is compared by cosine similarity
        with the key concepts of the example's value (``self.test_data`` column
        ``"value"``). Among the ``self.args.top_concept_num`` most similar key
        concepts, those scoring above ``similarity_threshold`` are used. If none
        qualify, the inferred concept text itself is kept.

        Args:
            key_concepts_embed: maps each value to a list of dicts with keys
                ``"concept"`` and ``"embedding"``.
            embed_mapped_concepts: iterable of per-example pairs whose first item
                is a list of dicts with keys ``"concept"`` and ``"embedding"``.
            store_path: JSON checkpoint file. It is rewritten whenever the example
                index is a multiple of 10, and once more at the end.
            mapped_concepts: results of an earlier run to resume from. Examples
                with ``idx < len(mapped_concepts)`` are skipped, and the list is
                extended in place. ``None`` (default) starts from an empty list.
            train: if False, stop after the example at index
                ``self.args.eval_num_samples``.
            similarity_threshold: a key concept counts as a match only if its
                similarity is strictly greater than this value.

        Returns:
            ``self.mapped_concepts``: one ``[use_inferred, concepts]`` entry per
            example. ``use_inferred`` is True when at least one inferred concept
            was kept verbatim. ``concepts`` are unique strings numbered
            "1. ", "2. ", ...
        """
        def _strip_index(text):
            # Drop a leading "N. " list index. Text without one is returned unchanged
            # (splitting blindly on the first ". " would cut ordinary sentences).
            head, sep, tail = text.partition(". ")
            return tail if sep and head.strip().isdigit() else text

        def _numbered(concepts):
            # Number index-free concept texts as "1. a", "2. b", ...
            return [f"{i}. {c}" for i, c in enumerate(concepts, start=1)]

        def _save():
            with open(store_path, "w") as f:
                json.dump(self.mapped_concepts, f)

        # Use a fresh list per call. A shared `[]` default would carry results over
        # to the next call, which would then skip every example.
        self.mapped_concepts = [] if mapped_concepts is None else mapped_concepts
        for entry in self.mapped_concepts:
            entry[1][:] = _numbered(_strip_index(c) for c in entry[1])

        use_inferred_count = 0
        for idx, (one_inferred_embeds, _) in tqdm(enumerate(embed_mapped_concepts), desc="mixed mapping key concepts"):
            if idx < len(self.mapped_concepts):
                continue  # already mapped in a previous run
            value_concepts = key_concepts_embed[self.test_data.iloc[idx]["value"]]
            value_concepts_embed = [concept["embedding"] for concept in value_concepts]

            use_inferred = False
            one_mapped_concept = []
            for concept_embed in one_inferred_embeds:
                similarity = compute_cosine_similarity(concept_embed["embedding"], value_concepts_embed)  # [key_concept_num]
                most_similar_idx = np.argsort(similarity)[-self.args.top_concept_num:][::-1]  # [top_concept_num], most similar first
                valid_relevant_concepts = [value_concepts[i]["concept"] for i in most_similar_idx if similarity[i] > similarity_threshold]
                if valid_relevant_concepts:
                    one_mapped_concept.extend(valid_relevant_concepts)
                else:
                    # Nothing is similar enough: fall back to the inferred concept itself.
                    one_mapped_concept.append(concept_embed["concept"])
                    use_inferred = True

            if use_inferred:
                use_inferred_count += 1
            # Deduplicate on the un-numbered text. dict.fromkeys keeps first-seen order,
            # so the output is reproducible (set iteration order of str is not).
            self.mapped_concepts.append([use_inferred, _numbered(dict.fromkeys(_strip_index(c) for c in one_mapped_concept))])

            if idx % 10 == 0:
                _save()

            if idx >= self.args.eval_num_samples and not train:
                break

        # Also persist the examples processed after the last periodic checkpoint.
        _save()
        # print("#use inferred concepts: ", use_inferred_count, ", #total examples: ", len(embed_mapped_concepts))
        return self.mapped_concepts


    def embed_mapping_key_concepts_all(self, key_concepts_embed, embed_mapped_concepts, store_path, mapped_concepts=None, train=False):
        """Map each inferred concept to the most similar key concepts of *all* values.

        The key concepts of every value in ``key_concepts_embed`` are pooled into
        one candidate set. For each inferred concept, the
        ``self.args.top_concept_num`` candidates with the highest cosine
        similarity are taken.

        Args:
            key_concepts_embed: maps each value to a list of dicts with keys
                ``"concept"`` and ``"embedding"``.
            embed_mapped_concepts: iterable of per-example pairs whose first item
                is a list of dicts with keys ``"concept"`` and ``"embedding"``.
            store_path: JSON checkpoint file. It is rewritten whenever the example
                index is a multiple of 10, and once more at the end.
            mapped_concepts: results of an earlier run to resume from. Examples
                with ``idx < len(mapped_concepts)`` are skipped, and the list is
                extended in place. ``None`` (default) starts from an empty list.
            train: if False, stop after the example at index
                ``self.args.eval_num_samples``.

        Returns:
            ``self.mapped_concepts``: one ``[[], concepts]`` entry per example, with
            ``concepts`` unique strings numbered "1. ", "2. ", ...
        """
        def _strip_index(text):
            # Drop a leading "N. " list index. Text without one is returned unchanged.
            head, sep, tail = text.partition(". ")
            return tail if sep and head.strip().isdigit() else text

        def _numbered(concepts):
            # Number index-free concept texts as "1. a", "2. b", ...
            return [f"{i}. {c}" for i, c in enumerate(concepts, start=1)]

        def _save():
            with open(store_path, "w") as f:
                json.dump(self.mapped_concepts, f)

        # Fresh list per call (a shared mutable default would leak across calls).
        self.mapped_concepts = [] if mapped_concepts is None else mapped_concepts
        for entry in self.mapped_concepts:
            entry[1][:] = _numbered(_strip_index(c) for c in entry[1])

        # Pool the key concepts of every value into parallel embedding/text lists.
        all_concepts_embed, all_concepts = [], []
        for value_concepts in key_concepts_embed.values():
            for concept in value_concepts:
                all_concepts_embed.append(concept["embedding"])
                all_concepts.append(concept["concept"])

        for idx, (one_inferred_embeds, _) in tqdm(enumerate(embed_mapped_concepts), desc="embed mapping to all key concepts"):
            if idx < len(self.mapped_concepts):
                continue  # already mapped in a previous run
            one_mapped_concept = []
            for concept_embed in one_inferred_embeds:
                similarity = compute_cosine_similarity(concept_embed["embedding"], all_concepts_embed)
                top_idx = similarity.argsort()[-self.args.top_concept_num:][::-1]  # most similar first
                one_mapped_concept.extend(all_concepts[i] for i in top_idx)

            # Deduplicate on the un-numbered text, keeping a reproducible first-seen order.
            self.mapped_concepts.append([[], _numbered(dict.fromkeys(_strip_index(c) for c in one_mapped_concept))])

            if idx % 10 == 0:
                _save()

            if idx >= self.args.eval_num_samples and not train:
                break

        # Also persist the examples processed after the last periodic checkpoint.
        _save()
        return self.mapped_concepts

    def llm_mapping_key_concepts(self, key_concepts_embed, embed_mapped_concepts, store_path, mapped_concepts=None, train=False):
        """Ask the LLM which of the closest key concepts relate to each inferred concept.

        For each inferred concept, the 10 key concepts of the example's value with
        the highest cosine similarity are numbered 1..10, most similar first. They
        are filled into the ``prompts[dataset]["key_concepts_mapping_llm"]``
        template together with the value details and the inferred concept. Every
        answer line containing ``"[Related]"`` contributes the text after that
        marker.

        Args:
            key_concepts_embed: maps each value to a list of dicts with keys
                ``"concept"`` and ``"embedding"``.
            embed_mapped_concepts: iterable of per-example pairs whose first item
                is a list of dicts with keys ``"concept"`` and ``"embedding"``.
            store_path: JSON checkpoint file. It is rewritten whenever the example
                index is a multiple of 10, and once more at the end.
            mapped_concepts: results of an earlier run to resume from. Examples
                with ``idx < len(mapped_concepts)`` are skipped, and the list is
                extended in place. ``None`` (default) starts from an empty list.
            train: if False, stop after the example at index
                ``self.args.eval_num_samples``.

        Returns:
            ``self.mapped_concepts``: one ``[[], concepts]`` entry per example, with
            ``concepts`` unique strings numbered "1. ", "2. ", ...
        """
        def _strip_index(text):
            # Drop a leading "N. " list index. Text without one is returned unchanged.
            head, sep, tail = text.partition(". ")
            return tail if sep and head.strip().isdigit() else text

        def _numbered(concepts):
            # Number index-free concept texts as "1. a", "2. b", ...
            return [f"{i}. {c}" for i, c in enumerate(concepts, start=1)]

        def _save():
            with open(store_path, "w") as f:
                json.dump(self.mapped_concepts, f)

        # Fresh list per call (a shared mutable default would leak across calls).
        self.mapped_concepts = [] if mapped_concepts is None else mapped_concepts
        for entry in self.mapped_concepts:
            entry[1][:] = _numbered(_strip_index(c) for c in entry[1])

        for idx, (one_inferred_embeds, _) in tqdm(enumerate(embed_mapped_concepts), desc="llm mapping key concepts"):
            if idx < len(self.mapped_concepts):
                continue  # already mapped in a previous run
            value = self.test_data.iloc[idx]["value"]
            value_details = self.test_data.iloc[idx]["value_details"]
            value_concepts = key_concepts_embed[value]
            value_concepts_embed = [concept["embedding"] for concept in value_concepts]
            one_mapped_concept = []
            for concept_embed in one_inferred_embeds:
                similarity = compute_cosine_similarity(concept_embed["embedding"], value_concepts_embed)
                candidates = [value_concepts[i]["concept"] for i in similarity.argsort()[-10:][::-1]]
                # Renumber the candidates 1..k for the prompt.
                candidate_lines = _numbered(_strip_index(c) for c in candidates)

                prompt = prompts[self.args.dataset]["key_concepts_mapping_llm"].format(value=value_details, extracted_feature=concept_embed["concept"], key_principles="\n".join(candidate_lines)).replace("    ", "").replace("\t", "")
                responses = self.model.get_top_p_answer(
                    prompt_base=[prompt],
                    prompt_system=[""],
                    max_tokens=self.args.eval_max_tokens,
                    temperature=self.args.eval_temp,
                    top_p=self.args.eval_top_p,
                )
                # Collect the text after "[Related]" without an index. Numbering happens
                # once per example, after deduplication, so the same concept found at
                # different positions is not kept twice. Empty matches are skipped.
                for answer_line in responses[0]["answer"].split("\n"):
                    if "[Related]" in answer_line:
                        related = _strip_index(answer_line.split("[Related]", 1)[1].strip())
                        if related:
                            one_mapped_concept.append(related)

            self.mapped_concepts.append([[], _numbered(dict.fromkeys(one_mapped_concept))])

            if idx % 10 == 0:
                _save()

            if idx >= self.args.eval_num_samples and not train:
                break

        # Also persist the examples processed after the last periodic checkpoint.
        _save()
        # Return the results like the sibling mapping methods (this used to return None).
        return self.mapped_concepts
    
    def llm_mapping_key_concepts_all(self, key_concepts_embed, embed_mapped_concepts, store_path, mapped_concepts=None, train=False):
        """Ask the LLM which of the closest key concepts (over *all* values) relate.

        The key concepts of every value are pooled. For each inferred concept, the
        10 pooled candidates with the highest cosine similarity are numbered 1..10,
        most similar first. They are filled into the
        ``prompts[dataset]["key_concepts_mapping_llm_all"]`` template together with
        the example's value details and the inferred concept. Every answer line
        containing ``"[Related]"`` contributes the text after that marker.

        Args:
            key_concepts_embed: maps each value to a list of dicts with keys
                ``"concept"`` and ``"embedding"``.
            embed_mapped_concepts: iterable of per-example pairs whose first item
                is a list of dicts with keys ``"concept"`` and ``"embedding"``.
            store_path: JSON checkpoint file. It is rewritten whenever the example
                index is a multiple of 10, and once more at the end.
            mapped_concepts: results of an earlier run to resume from. Examples
                with ``idx < len(mapped_concepts)`` are skipped, and the list is
                extended in place. ``None`` (default) starts from an empty list.
            train: if False, stop after the example at index
                ``self.args.eval_num_samples``.

        Returns:
            ``self.mapped_concepts``: one ``[[], concepts]`` entry per example, with
            ``concepts`` unique strings numbered "1. ", "2. ", ...
        """
        def _strip_index(text):
            # Drop a leading "N. " list index. Text without one is returned unchanged.
            head, sep, tail = text.partition(". ")
            return tail if sep and head.strip().isdigit() else text

        def _numbered(concepts):
            # Number index-free concept texts as "1. a", "2. b", ...
            return [f"{i}. {c}" for i, c in enumerate(concepts, start=1)]

        def _save():
            with open(store_path, "w") as f:
                json.dump(self.mapped_concepts, f)

        # Fresh list per call (a shared mutable default would leak across calls).
        self.mapped_concepts = [] if mapped_concepts is None else mapped_concepts
        for entry in self.mapped_concepts:
            entry[1][:] = _numbered(_strip_index(c) for c in entry[1])

        # Pool the key concepts of every value into parallel embedding/text lists.
        all_concepts_embed, all_concepts = [], []
        for value_concepts in key_concepts_embed.values():
            for concept in value_concepts:
                all_concepts_embed.append(concept["embedding"])
                all_concepts.append(concept["concept"])

        for idx, (one_inferred_embeds, _) in tqdm(enumerate(embed_mapped_concepts), desc="llm mapping key concepts to all concepts"):
            if idx < len(self.mapped_concepts):
                continue  # already mapped in a previous run
            value_details = self.test_data.iloc[idx]["value_details"]

            one_mapped_concept = []
            for concept_embed in one_inferred_embeds:
                similarity = compute_cosine_similarity(concept_embed["embedding"], all_concepts_embed)
                candidates = [all_concepts[i] for i in similarity.argsort()[-10:][::-1]]
                # Renumber the candidates 1..k for the prompt.
                candidate_lines = _numbered(_strip_index(c) for c in candidates)

                prompt = prompts[self.args.dataset]["key_concepts_mapping_llm_all"].format(value=value_details, extracted_feature=concept_embed["concept"], key_principles="\n".join(candidate_lines)).replace("    ", "").replace("\t", "")
                responses = self.model.get_top_p_answer(
                    prompt_base=[prompt],
                    prompt_system=[""],
                    max_tokens=self.args.eval_max_tokens,
                    temperature=self.args.eval_temp,
                    top_p=self.args.eval_top_p,
                )
                # Collect the text after "[Related]" without an index, so that duplicates
                # across inferred concepts collapse before numbering. Empty matches are skipped.
                for answer_line in responses[0]["answer"].split("\n"):
                    if "[Related]" in answer_line:
                        related = _strip_index(answer_line.split("[Related]", 1)[1].strip())
                        if related:
                            one_mapped_concept.append(related)

            self.mapped_concepts.append([[], _numbered(dict.fromkeys(one_mapped_concept))])

            if idx % 10 == 0:
                _save()

            if idx >= self.args.eval_num_samples and not train:
                break

        # Also persist the examples processed after the last periodic checkpoint.
        _save()
        return self.mapped_concepts

    def cluster_train_data(self, train_data):
        """Assign each training row a batch label, optionally via k-means on text embeddings.

        Embeddings are read from ``{ROOT_PATH}/{dataset}/train_embed.jsonl``. The
        file is generated with ``get_text_embeddings`` when it is missing. Each
        row of ``train_data`` is matched to its embedding through the lower-cased
        text ``scenario + " " + action``.

        Args:
            train_data: DataFrame with ``scenario`` and ``action`` columns.

        Returns:
            np.ndarray with one integer label per row of ``train_data``.
            - When ``self.args.batch_size <= 2`` or
              ``self.args.cluster_train_data == "False"``, rows are chunked in
              order (label ``i // batch_size``).
            - Otherwise the labels are k-means cluster ids, using
              ``len(train_data) // batch_size`` clusters (at least one).

        Raises:
            ValueError: if ``batch_size`` is below 1, or the dataset is not one of
                beavertails, value_fulcra, denevil (no lookup-text rule exists
                for other datasets).
            KeyError: if a row's text has no entry in the embedding file.
        """
        dataset = self.args.dataset
        batch_size = self.args.batch_size
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if dataset not in ("beavertails", "value_fulcra", "denevil"):
            # Without a rule the lookup text would never be built (previously an
            # UnboundLocalError on the first embedding line).
            raise ValueError(f"cluster_train_data does not support dataset {dataset!r}")

        embed_path = f"{ROOT_PATH}/{dataset}/train_embed.jsonl"
        if not os.path.exists(embed_path):
            print(f"compute text embedding for training data {dataset}")
            get_text_embeddings(dataset)

        def _lookup_text(data):
            # Rebuild the text that train_data rows are matched against.
            if dataset == "beavertails":
                return data["prompt"].lower() + " " + data["response"].lower()
            if dataset == "value_fulcra":
                dialogue = data["dialogue"]
                question = dialogue.split("Bob:")[0].split("Human:")[1].strip()
                response = dialogue.split("Bob:")[1].strip()
                return question.lower() + " " + response.lower()
            # denevil: scene only, so it matches rows whose action lower-cases to "".
            return data["scene"].lower() + " "

        train_data_embed = {}
        with open(embed_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank / trailing lines
                data = json.loads(line)
                train_data_embed[_lookup_text(data)] = data["embed"]

        selected_embed = [
            train_data_embed[row["scenario"].lower() + " " + row["action"].lower()]
            for _, row in train_data.iterrows()
        ]
        num_rows = len(selected_embed)

        # When the batch size is small or clustering is disabled, do not cluster:
        # consecutive rows form chunks of batch_size.
        if batch_size <= 2 or self.args.cluster_train_data == "False":
            return np.arange(num_rows) // batch_size
        if num_rows == 0:
            return np.array([], dtype=int)

        # Cluster the selected embeddings using k-means.
        print("Clustering the training data using k-means...")
        X = np.array(selected_embed)
        # With fewer rows than batch_size the integer division gives 0 clusters,
        # which KMeans rejects; use at least one cluster.
        n_clusters = max(1, num_rows // batch_size)
        kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init="auto").fit(X)
        return kmeans.labels_