import re
import numpy as np
import logging

from utils.utils import match_and_remove_first_occurrence, check_annotated_format

logger = logging.getLogger("global_logger")

class Corpora_Feature:
    """Annotate raw text with entity labels by prompting an LLM.

    The entity-extraction prompt template is read once from
    ``data/entity_extraction.txt``. Two placeholders in it are substituted
    before each request: ``%%%%%%`` becomes the comma-separated entity types
    and ``*-*-*-`` becomes the text to annotate (a placeholder that is absent
    from the template is simply left unreplaced).
    """

    def __init__(self, datatype, llm):
        """Store the configuration and load the prompt template.

        Args:
            datatype: Collection of entity-type names (str). It is joined with
                ", " into the prompt and used for repeated membership tests,
                so it must be re-iterable (e.g. a list or set, not a generator).
            llm: Client exposing ``sample_completions(prompt)``, which must
                return a sequence whose first element has a ``response_text``
                attribute.

        Raises:
            OSError: If the prompt template file cannot be opened.
        """
        self.examples = []
        self.datatype = datatype
        self.llm = llm

        # Decode explicitly as UTF-8 so the template reads identically
        # regardless of the platform's default locale encoding.
        with open("data/entity_extraction.txt", "r", encoding="utf-8") as file:
            self.entity_extraction_prompt = file.read()

    def data_annotate(self, corpora, model="G", *, max_attempts=4):
        """Ask the LLM to annotate ``corpora``, retrying on malformed output.

        The filled-in prompt is sent to the LLM up to ``max_attempts`` times.
        A response is accepted when ``check_annotated_format`` returns a
        truthy first element and every label it returns (third element) is
        one of ``self.datatype``.

        Args:
            corpora: Text to annotate (str; it is substituted into the prompt
                with ``str.replace``).
            model: Unused; kept so existing callers keep working.
            max_attempts: Maximum number of LLM calls before giving up.
                Defaults to 4, the previously hard-coded retry count.

        Returns:
            The first accepted annotated text. If no attempt is accepted, or
            ``max_attempts`` is less than 1, returns ``corpora`` unchanged.
        """
        # The prompt is identical on every attempt, so build it once.
        prompt = (
            self.entity_extraction_prompt
            .replace("%%%%%%", ", ".join(self.datatype))
            .replace("*-*-*-", corpora)
        )
        for _ in range(max_attempts):
            response = self.llm.sample_completions(prompt)[0]
            annotated_corpora = response.response_text
            is_well_formed, _origin, labels = check_annotated_format(
                annotated_corpora
            )
            # Labels are only inspected when the format check passed.
            if is_well_formed and all(label in self.datatype for label in labels):
                return annotated_corpora
            # Lazy %-formatting: no eager concatenation, and no TypeError if
            # the response text is not a str.
            logger.info("format worng: %s", annotated_corpora)
        return corpora