import os
import re
import json
import numpy as np
import math

class DatasetDSL:
    """Load a dataset's training text and build masked test queries.

    Files are read from ``Dataset/<dataset>/`` relative to the current
    working directory. Construction immediately populates:

    * ``train_dataset`` -- every training procedure joined into one string
      (sentences separated by a newline, procedures by a blank line).
    * ``avg_trainexample_len`` -- average character length of a training
      sentence, rounded up (0 when there are no training sentences).
    * ``test_examples`` -- ``{"query": [...], "answer": [...]}`` with one
      entry per ``<MASK>`` found in the test set.
    """

    # Placeholder marking an answer slot inside a test sentence.
    MASK_TOKEN = "<MASK>"

    def __init__(self, dataset, algorithm, context_len, query_scope):
        """Store the configuration and eagerly load train and test data.

        Args:
            dataset: Name of the sub-directory of ``Dataset/`` to read from.
            algorithm: Algorithm name; together with ``dataset`` it selects
                which training file is loaded.
            context_len: Stored as-is; not used by the methods of this class.
            query_scope: Number of neighbouring sentences on each side that
                are included as context in every test query.
        """
        self.dataset = dataset
        self.algorithm = algorithm
        self.context_len = context_len
        self.query_scope = query_scope

        self.prepare_train_dataset()
        self.prepare_test_examples()

    def prepare_train_dataset(self):
        """Load the training procedures and derive summary attributes.

        The ``trainset_DSL.json`` file is used only for the "Synthesis"
        dataset combined with the "Ours" algorithm; every other combination
        reads ``trainset_NL.json``. The file must contain a list of
        procedures, each a list of strings.
        """
        use_dsl = self.dataset == "Synthesis" and self.algorithm == "Ours"
        filename = "trainset_DSL.json" if use_dsl else "trainset_NL.json"
        with open(f"Dataset/{self.dataset}/{filename}", "r", encoding="utf-8") as f:
            data = json.load(f)

        total_length = sum(len(s) for procedure in data for s in procedure)
        total_count = sum(len(procedure) for procedure in data)

        # Guard against an empty training set instead of dividing by zero.
        self.avg_trainexample_len = (
            math.ceil(total_length / total_count) if total_count else 0
        )
        self.train_dataset = "\n\n".join("\n".join(procedure) for procedure in data)

    def prepare_test_examples(self):
        """Build one query/answer pair for every mask in the test set.

        ``testset_DSL.json`` must contain a list of procedures, each a list
        of ``{"sentence": str, "answer": [str, ...]}`` items where
        ``answer`` holds one replacement per ``<MASK>`` in ``sentence``.

        For each mask, the query is the sentence surrounded by up to
        ``query_scope`` neighbouring sentences on each side, with every
        other mask replaced by its answer and only the target mask left in
        place.

        Raises:
            ValueError: If a sentence has fewer masks than answers, or the
                masks and answers of the context do not line up.
        """
        with open(f"Dataset/{self.dataset}/testset_DSL.json", "r", encoding="utf-8") as f:
            data = json.load(f)

        mask = self.MASK_TOKEN
        self.test_examples = {"query": [], "answer": []}
        for procedure in data:
            for index, step in enumerate(procedure):
                first = max(index - self.query_scope, 0)
                last = min(index + self.query_scope, len(procedure) - 1)
                sentence = step["sentence"]
                answer = step["answer"]

                # The neighbouring context is the same for every mask of
                # this sentence, so build it once per sentence.
                before = procedure[first:index]
                after = procedure[index + 1:last + 1]
                before_text = "".join(item["sentence"] + " " for item in before)
                before_answers = [a for item in before for a in item["answer"]]
                after_text = "".join(item["sentence"] + " " for item in after)
                after_answers = [a for item in after for a in item["answer"]]

                for ans_index, ans in enumerate(answer):
                    mask_pos = self.__find_nth_mask(sentence, ans_index + 1)

                    left_sent = before_text + sentence[:mask_pos]
                    left_ans = before_answers + answer[:ans_index]
                    right_sent = sentence[mask_pos + len(mask):] + " " + after_text
                    right_ans = answer[ans_index + 1:] + after_answers
                    query = (
                        self.__fill_masks(left_sent, left_ans)
                        + mask
                        + self.__fill_masks(right_sent, right_ans)
                    )

                    self.test_examples["query"].append(query)
                    self.test_examples["answer"].append(ans)

    def __find_nth_mask(self, string, n):
        """Return the index of the ``n``-th (1-based) mask in ``string``.

        Returns -1 when ``n`` is not positive, because no search is made.

        Raises:
            ValueError: If ``string`` contains fewer than ``n`` masks.
        """
        pos = -1
        for _ in range(n):
            pos = string.find(self.MASK_TOKEN, pos + 1)
            if pos == -1:
                raise ValueError(
                    f"wrong testdata! Expected at least {n} {self.MASK_TOKEN} "
                    f"in {string!r}."
                )
        return pos

    def __fill_masks(self, sentence, replacements):
        """Replace each mask in ``sentence`` with the matching replacement.

        Replacements are consumed in order, one per mask.

        Raises:
            ValueError: If the number of masks differs from the number of
                replacements.
        """
        parts = sentence.split(self.MASK_TOKEN)
        if len(parts) - 1 != len(replacements):
            raise ValueError(
                "wrong! The number of <MASK>s and replacements do not match. "
                f"sentence={sentence!r}, replacements={replacements!r}"
            )

        pieces = [parts[0]]
        for part, replacement in zip(parts[1:], replacements):
            pieces.append(replacement)
            pieces.append(part)
        return "".join(pieces)