import json
import random
# Seed the module-level RNG so the random subsample drawn by
# get_strategyqa_gold_and_list (via random.sample) is the same on every run.
random.seed(45)



def is_substring_refined(a, b):
    """
    Tell whether a occurs in b at a position where it is not continued by more word characters.

    An occurrence counts when it reaches the very end of b, or when the character right
    after it is neither alphanumeric nor an apostrophe nor a hyphen. Occurrences that fail
    this test are skipped and the scan resumes one position further on. The character
    *before* an occurrence is never inspected.

    :param a: The string to search for.
    :param b: The string to search within.
    :return: True as soon as one acceptable occurrence is found, False otherwise.
    """
    needle_len = len(a)
    haystack_len = len(b)

    start = 0
    while start < haystack_len:
        hit = b.find(a, start)
        if hit == -1:
            # Nothing left to examine.
            return False

        end = hit + needle_len
        if end == haystack_len:
            # The occurrence is a suffix of b.
            return True

        follower = b[end]
        if not (follower.isalnum() or follower in ("'", "-")):
            # The occurrence is followed by a separator such as a space or punctuation.
            return True

        # The occurrence runs into a longer token; retry from the next position.
        start = hit + 1

    return False

def get_2wikihoppotqa_gold(source_path, sample_path, sample_original_path, sample_target_path):
    """
    Extract the sampled questions from a multi-hop QA dump and attach their gold evidence as text.

    For every source record whose ``_id`` is listed in the sample file, the titles named in
    ``supporting_facts`` are looked up in ``context``. A title is matched exactly first; only
    when no exact match exists is the looser ``is_substring_refined`` match tried. The matched
    paragraphs are joined with single spaces into one evidence string.

    Processing stops at the first sampled record whose number of gold paragraphs differs from
    its number of distinct gold titles: that record is printed for inspection, and neither it
    nor any later record is converted (it is still written to the raw sample dump).

    All files are read and written as UTF-8 and are closed before the function returns.

    :param source_path: JSON file holding a list of records with the keys ``_id``,
        ``question``, ``answer``, ``type``, ``supporting_facts`` and ``context``.
    :param sample_path: JSON-lines file; every non-blank line is an object with a ``question_id``.
    :param sample_original_path: Output JSON file receiving the selected source records unchanged.
    :param sample_target_path: Output JSON file receiving the converted records
        (``id``, ``claim``, ``label``, ``challenge``, ``evidence``).
    :return: None. Progress and diagnostics are printed to stdout.
    """
    with open(source_path, encoding='utf-8') as source_file:
        source_records = json.load(source_file)
    print("total: " + str(len(source_records)))

    with open(sample_path, encoding='utf-8') as sample_file:
        # A set makes the per-record membership test O(1); blank lines are ignored.
        sample_ids = {json.loads(line)['question_id'] for line in sample_file if line.strip()}

    sampled_records = []
    converted_records = []
    for record in source_records:
        record_id = record['_id']
        if record_id not in sample_ids:
            continue

        sampled_records.append(record)
        answer = record['answer']
        question = record['question']
        supporting_facts = record['supporting_facts']
        context = record['context']
        level = record['type']

        # De-duplicate titles while keeping supporting-fact order, so that the evidence
        # order is reproducible (iterating a set of strings is not).
        gold_titles = list(dict.fromkeys(fact[0] for fact in supporting_facts))

        gold_paragraphs = []
        for gold_title in gold_titles:
            matches = [' '.join(sentences) for title, sentences in context if title == gold_title]
            if not matches:
                # No exact title: fall back to titles that merely contain the gold title.
                matches = [
                    ' '.join(sentences)
                    for title, sentences in context
                    if gold_title != title and is_substring_refined(gold_title, title)
                ]
            gold_paragraphs.extend(matches)

        print(len(gold_paragraphs))
        if len(gold_paragraphs) != len(gold_titles):
            # Unexpected match count: dump the record and stop converting.
            print(set(gold_titles))
            print(record)
            print(gold_paragraphs)
            break

        converted_records.append({
            'id': record_id,
            'claim': question,
            'label': answer,
            'challenge': level,
            'evidence': ' '.join(gold_paragraphs),
        })

    with open(sample_original_path, 'w', encoding='utf-8') as original_file:
        json.dump(sampled_records, original_file, indent=4)
    with open(sample_target_path, 'w', encoding='utf-8') as target_file:
        json.dump(converted_records, target_file, indent=4)

    # Read the result back as a sanity check on what was written.
    with open(sample_target_path, encoding='utf-8') as target_file:
        print(len(json.load(target_file)))



def get_2wikihoppotqa_gold_list(source_path, sample_path, sample_target_path):
    """
    Extract the sampled questions from a multi-hop QA dump and attach their gold evidence as a list.

    For every source record whose ``_id`` is listed in the sample file, the titles named in
    ``supporting_facts`` are looked up in ``context``. A title is matched exactly first; only
    when no exact match exists is the looser ``is_substring_refined`` match tried. Each matched
    paragraph becomes one string in the ``evidence`` list.

    Processing stops at the first sampled record whose number of gold paragraphs differs from
    its number of distinct gold titles: that record is printed for inspection, and neither it
    nor any later record is converted.

    All files are read and written as UTF-8 and are closed before the function returns.

    :param source_path: JSON file holding a list of records with the keys ``_id``,
        ``question``, ``answer``, ``type``, ``supporting_facts`` and ``context``.
    :param sample_path: JSON-lines file; every non-blank line is an object with a ``question_id``.
    :param sample_target_path: Output JSON file receiving the converted records
        (``id``, ``claim``, ``label``, ``challenge``, ``evidence``).
    :return: None. Progress and diagnostics are printed to stdout.
    """
    with open(source_path, encoding='utf-8') as source_file:
        source_records = json.load(source_file)
    print("total: " + str(len(source_records)))

    with open(sample_path, encoding='utf-8') as sample_file:
        # A set makes the per-record membership test O(1); blank lines are ignored.
        sample_ids = {json.loads(line)['question_id'] for line in sample_file if line.strip()}

    converted_records = []
    for record in source_records:
        record_id = record['_id']
        if record_id not in sample_ids:
            continue

        answer = record['answer']
        question = record['question']
        supporting_facts = record['supporting_facts']
        context = record['context']
        level = record['type']

        # De-duplicate titles while keeping supporting-fact order, so that the evidence
        # order is reproducible (iterating a set of strings is not).
        gold_titles = list(dict.fromkeys(fact[0] for fact in supporting_facts))

        gold_paragraphs = []
        for gold_title in gold_titles:
            matches = [' '.join(sentences) for title, sentences in context if title == gold_title]
            if not matches:
                # No exact title: fall back to titles that merely contain the gold title.
                matches = [
                    ' '.join(sentences)
                    for title, sentences in context
                    if gold_title != title and is_substring_refined(gold_title, title)
                ]
            gold_paragraphs.extend(matches)

        print(len(gold_paragraphs))
        if len(gold_paragraphs) != len(gold_titles):
            # Unexpected match count: dump the record and stop converting.
            print(set(gold_titles))
            print(record)
            print(gold_paragraphs)
            break

        converted_records.append({
            'id': record_id,
            'claim': question,
            'label': answer,
            'challenge': level,
            'evidence': gold_paragraphs,
        })

    with open(sample_target_path, 'w', encoding='utf-8') as target_file:
        json.dump(converted_records, target_file, indent=4)

    # Read the result back as a sanity check on what was written.
    with open(sample_target_path, encoding='utf-8') as target_file:
        print(len(json.load(target_file)))


def get_hoppotqa_gold(source_path, sample_path, sample_original_path, sample_target_path):
    """
    Extract the sampled questions from a HotpotQA-style dump and attach their gold evidence as text.

    For every source record whose ``_id`` is listed in the sample file, the titles named in
    ``supporting_facts`` are looked up in ``context``. A title is matched exactly first; only
    when no exact match exists is the looser ``is_substring_refined`` match tried. The matched
    paragraphs are joined with single spaces into one evidence string.

    Exactly two gold paragraphs are expected per question. Processing stops at the first
    sampled record that yields a different number: that record is printed for inspection, and
    neither it nor any later record is converted (it is still written to the raw sample dump).

    All files are read and written as UTF-8 and are closed before the function returns.

    :param source_path: JSON file holding a list of records with the keys ``_id``,
        ``question``, ``answer``, ``level``, ``supporting_facts`` and ``context``.
    :param sample_path: JSON-lines file; every non-blank line is an object with a ``question_id``.
    :param sample_original_path: Output JSON file receiving the selected source records unchanged.
    :param sample_target_path: Output JSON file receiving the converted records
        (``id``, ``claim``, ``label``, ``challenge``, ``evidence``).
    :return: None. Progress and diagnostics are printed to stdout.
    """
    with open(source_path, encoding='utf-8') as source_file:
        source_records = json.load(source_file)
    print("total: " + str(len(source_records)))

    with open(sample_path, encoding='utf-8') as sample_file:
        # A set makes the per-record membership test O(1); blank lines are ignored.
        sample_ids = {json.loads(line)['question_id'] for line in sample_file if line.strip()}

    sampled_records = []
    converted_records = []
    for record in source_records:
        record_id = record['_id']
        if record_id not in sample_ids:
            continue

        sampled_records.append(record)
        answer = record['answer']
        question = record['question']
        supporting_facts = record['supporting_facts']
        context = record['context']
        level = record['level']

        # De-duplicate titles while keeping supporting-fact order, so that the evidence
        # order is reproducible (iterating a set of strings is not).
        gold_titles = list(dict.fromkeys(fact[0] for fact in supporting_facts))

        gold_paragraphs = []
        for gold_title in gold_titles:
            matches = [' '.join(sentences) for title, sentences in context if title == gold_title]
            if not matches:
                # No exact title: fall back to titles that merely contain the gold title.
                matches = [
                    ' '.join(sentences)
                    for title, sentences in context
                    if gold_title != title and is_substring_refined(gold_title, title)
                ]
            gold_paragraphs.extend(matches)

        print(len(gold_paragraphs))
        if len(gold_paragraphs) != 2:
            # Unexpected match count: dump the record and stop converting.
            print(set(gold_titles))
            print(record)
            print(gold_paragraphs)
            break

        converted_records.append({
            'id': record_id,
            'claim': question,
            'label': answer,
            'challenge': level,
            'evidence': ' '.join(gold_paragraphs),
        })

    with open(sample_original_path, 'w', encoding='utf-8') as original_file:
        json.dump(sampled_records, original_file, indent=4)
    with open(sample_target_path, 'w', encoding='utf-8') as target_file:
        json.dump(converted_records, target_file, indent=4)

    # Read the result back as a sanity check on what was written.
    with open(sample_target_path, encoding='utf-8') as target_file:
        print(len(json.load(target_file)))


def get_hoppotqa_gold_list(source_path, sample_path, sample_target_path):
    """
    Extract the sampled questions from a HotpotQA-style dump and attach their gold evidence as a list.

    For every source record whose ``_id`` is listed in the sample file, the titles named in
    ``supporting_facts`` are looked up in ``context``. A title is matched exactly first; only
    when no exact match exists is the looser ``is_substring_refined`` match tried. Each matched
    paragraph becomes one string in the ``evidence`` list.

    Exactly two gold paragraphs are expected per question. Processing stops at the first
    sampled record that yields a different number: that record is printed for inspection, and
    neither it nor any later record is converted.

    All files are read and written as UTF-8 and are closed before the function returns.

    :param source_path: JSON file holding a list of records with the keys ``_id``,
        ``question``, ``answer``, ``level``, ``supporting_facts`` and ``context``.
    :param sample_path: JSON-lines file; every non-blank line is an object with a ``question_id``.
    :param sample_target_path: Output JSON file receiving the converted records
        (``id``, ``claim``, ``label``, ``challenge``, ``evidence``).
    :return: None. Progress and diagnostics are printed to stdout.
    """
    with open(source_path, encoding='utf-8') as source_file:
        source_records = json.load(source_file)
    print("total: " + str(len(source_records)))

    with open(sample_path, encoding='utf-8') as sample_file:
        # A set makes the per-record membership test O(1); blank lines are ignored.
        sample_ids = {json.loads(line)['question_id'] for line in sample_file if line.strip()}

    converted_records = []
    for record in source_records:
        record_id = record['_id']
        if record_id not in sample_ids:
            continue

        answer = record['answer']
        question = record['question']
        supporting_facts = record['supporting_facts']
        context = record['context']
        level = record['level']

        # De-duplicate titles while keeping supporting-fact order, so that the evidence
        # order is reproducible (iterating a set of strings is not).
        gold_titles = list(dict.fromkeys(fact[0] for fact in supporting_facts))

        gold_paragraphs = []
        for gold_title in gold_titles:
            matches = [' '.join(sentences) for title, sentences in context if title == gold_title]
            if not matches:
                # No exact title: fall back to titles that merely contain the gold title.
                matches = [
                    ' '.join(sentences)
                    for title, sentences in context
                    if gold_title != title and is_substring_refined(gold_title, title)
                ]
            gold_paragraphs.extend(matches)

        print(len(gold_paragraphs))
        if len(gold_paragraphs) != 2:
            # Unexpected match count: dump the record and stop converting.
            print(set(gold_titles))
            print(record)
            print(gold_paragraphs)
            break

        converted_records.append({
            'id': record_id,
            'claim': question,
            'label': answer,
            'challenge': level,
            'evidence': gold_paragraphs,
        })

    with open(sample_target_path, 'w', encoding='utf-8') as target_file:
        json.dump(converted_records, target_file, indent=4)

    # Read the result back as a sanity check on what was written.
    with open(sample_target_path, encoding='utf-8') as target_file:
        print(len(json.load(target_file)))

def get_musique_gold_and_list(source_path, target_text_path, target_list_path):
    """
    Convert a JSON-lines question file into two gold-evidence JSON files.

    For each record, the paragraphs whose ``idx`` equals the ``paragraph_support_idx`` of a
    decomposition step are collected in decomposition order. The same records are written
    twice: once with the evidence joined into a single space-separated string and once with
    the evidence kept as a list of paragraph texts. The ``challenge`` field is the part of the
    record id that precedes the first underscore.

    All files are read and written as UTF-8 and are closed before the function returns.
    Blank lines in the source file are skipped.

    :param source_path: JSON-lines file; each record has the keys ``id``, ``paragraphs``,
        ``question``, ``answer`` and ``question_decomposition``.
    :param target_text_path: Output JSON file with ``evidence`` as one string.
    :param target_list_path: Output JSON file with ``evidence`` as a list of strings.
    :return: None. Record counts are printed to stdout.
    """
    with open(source_path, encoding='utf-8') as source_file:
        records = [json.loads(line) for line in source_file if line.strip()]
    print('total: ' + str(len(records)))

    text_records = []
    list_records = []
    for record in records:
        record_id = record['id']
        paragraphs = record['paragraphs']
        question = record['question']
        answer = record['answer']
        question_decomposition = record['question_decomposition']
        challenge = record_id.split('_')[0]

        gold_paragraphs = [
            paragraph['paragraph_text']
            for step in question_decomposition
            for paragraph in paragraphs
            if paragraph['idx'] == step['paragraph_support_idx']
        ]

        shared_fields = {'id': record_id, 'claim': question, 'label': answer, 'challenge': challenge}
        text_records.append({**shared_fields, 'evidence': ' '.join(gold_paragraphs)})
        list_records.append({**shared_fields, 'evidence': gold_paragraphs})

    with open(target_text_path, 'w', encoding='utf-8') as text_file:
        json.dump(text_records, text_file, indent=4)
    with open(target_list_path, 'w', encoding='utf-8') as list_file:
        json.dump(list_records, list_file, indent=4)

    # Read both results back as a sanity check on what was written.
    with open(target_text_path, encoding='utf-8') as text_file:
        print(len(json.load(text_file)))
    with open(target_list_path, encoding='utf-8') as list_file:
        print(len(json.load(list_file)))


def get_qangaroo_gold_and_list(source_path, target_list_path, max_tokens=1600, max_samples=500):
    """
    Select the first short-context questions from a JSON dump and write them in the common format.

    A record is kept when the total number of space-separated tokens across its ``supports``
    is strictly below ``max_tokens``. Records are taken in file order until ``max_samples``
    have been collected. Every kept record gets the fixed challenge label ``'multi-hop'``.

    All files are read and written as UTF-8 and are closed before the function returns.

    :param source_path: JSON file holding a list of records with the keys ``id``,
        ``supports``, ``query``, ``answer`` and ``candidates``.
    :param target_list_path: Output JSON file receiving the kept records
        (``id``, ``candidates``, ``claim``, ``label``, ``challenge``, ``evidence``).
    :param max_tokens: Exclusive upper bound on the summed token count of ``supports``.
    :param max_samples: Maximum number of records to keep.
    :return: None. Record counts are printed to stdout.
    """
    with open(source_path, encoding='utf-8') as source_file:
        records = json.load(source_file)
    print('total: ' + str(len(records)))

    kept_records = []
    for record in records:
        if len(kept_records) >= max_samples:
            break

        record_id = record['id']
        supports = record['supports']
        query = record['query']
        answer = record['answer']
        candidates = record['candidates']

        # Tokens are counted by splitting on single spaces, not by a real tokenizer.
        token_count = sum(len(support.split(' ')) for support in supports)
        if token_count < max_tokens:
            kept_records.append({
                'id': record_id,
                'candidates': candidates,
                'claim': query,
                'label': answer,
                'challenge': 'multi-hop',
                'evidence': supports,
            })

    with open(target_list_path, 'w', encoding='utf-8') as target_file:
        json.dump(kept_records, target_file, indent=4)

    # Read the result back as a sanity check on what was written.
    with open(target_list_path, encoding='utf-8') as target_file:
        print(len(json.load(target_file)))


def _count_true_false(records, key):
    """Return ``(true_count, other_count)`` where ``true_count`` counts records whose ``key`` is exactly True."""
    true_count = sum(1 for record in records if record[key] is True)
    return true_count, len(records) - true_count


def get_strategyqa_gold_and_list(source_path, paragraphs_path, target_list_path, top_k = 229, max_tokens = 1200):
    """
    Build a random, short-context subsample of a yes/no question set with its evidence paragraphs.

    For each question the paragraph ids referenced in its ``evidence`` annotations are
    collected (the markers ``'no_evidence'`` and ``'operation'`` are skipped), de-duplicated in
    first-seen order and resolved to their ``content`` in the paragraph database. Questions
    whose summed space-separated token count is strictly below ``max_tokens`` are eligible;
    ``top_k`` of them are drawn with ``random.sample`` from the module-level RNG.

    The True / non-True label balance is printed for the eligible questions, for the whole
    source file and, after writing, for the drawn sample.

    All files are read and written as UTF-8 and are closed before the function returns.

    :param source_path: JSON file holding a list of records with the keys ``qid``,
        ``question``, ``answer`` and ``evidence``.
    :param paragraphs_path: JSON file mapping paragraph id to an object with a ``content`` string.
    :param target_list_path: Output JSON file receiving the sampled records (``id``, ``claim``,
        ``label``, ``challenge``, ``evidence``, ``mind_maps``).
    :param top_k: Number of eligible questions to sample.
    :param max_tokens: Exclusive upper bound on the summed token count of a question's evidence.
    :return: None. Counts are printed to stdout.
    :raises ValueError: From ``random.sample`` when fewer than ``top_k`` questions are eligible.
    """
    with open(source_path, encoding='utf-8') as source_file:
        datas = json.load(source_file)
    print('total: ' + str(len(datas)))

    with open(paragraphs_path, encoding='utf-8') as paragraphs_file:
        paragraphs_database = json.load(paragraphs_file)

    eligible_records = []
    for record in datas:
        qid = record['qid']
        query = record['question']
        answer = record['answer']
        annotations = record['evidence']

        # Each annotation is a list of steps; each step holds either a marker string or
        # a list of paragraph ids.
        referenced_ids = []
        for annotation in annotations:
            for step in annotation:
                for item in step:
                    if item != 'no_evidence' and item != 'operation':
                        referenced_ids.extend(item)

        # Ordered de-duplication keeps the evidence order reproducible across runs
        # (iterating a set of strings is not).
        unique_ids = list(dict.fromkeys(referenced_ids))
        support_content = [paragraphs_database[para_id]['content'] for para_id in unique_ids]
        # Tokens are counted by splitting on single spaces, not by a real tokenizer.
        token_count = sum(len(content.split(' ')) for content in support_content)

        if token_count < max_tokens:
            eligible_records.append({
                'id': qid,
                'claim': query,
                'label': answer,
                'challenge': 'multi-hop',
                'evidence': support_content,
                'mind_maps': [],
            })

    # Label balance among the eligible questions, then across the whole source file.
    for count in _count_true_false(eligible_records, 'label'):
        print(count)
    for count in _count_true_false(datas, 'answer'):
        print(count)

    # Sample positions rather than records so the draw depends only on the RNG state
    # and the number of eligible questions.
    sampled_positions = random.sample(range(len(eligible_records)), k=top_k)
    sample_data_list = [eligible_records[position] for position in sampled_positions]

    with open(target_list_path, 'w', encoding='utf-8') as target_file:
        json.dump(sample_data_list, target_file, indent=4)

    # Read the result back as a sanity check on what was written.
    with open(target_list_path, encoding='utf-8') as target_file:
        print(len(json.load(target_file)))

    # Label balance within the drawn sample.
    for count in _count_true_false(sample_data_list, 'label'):
        print(count)

if __name__ == "__main__":
    # One conversion entry point exists per dataset. Only the StrategyQA run is active;
    # the remaining invocations are kept as comments so they can be re-enabled by hand.

    # --- 2wikimultihopqa, evidence joined into a single string ---
    # get_2wikihoppotqa_gold(
    #     './raw_data/2wikimultihopqa/2hotpot_dev_fullwiki_v1.json',     # source_path
    #     './processed_data/2wikimultihopqa/test_subsampled.jsonl',      # sample_path
    #     './raw_data/2wikimultihopqa/2hotpot_dev_distractor_v2.json',   # sample_original_path
    #     './raw_data/2wikimultihopqa/dev.json',                         # sample_target_path
    # )

    # --- 2wikimultihopqa, evidence kept as a list ---
    # get_2wikihoppotqa_gold_list(
    #     './raw_data/2wikimultihopqa/2hotpot_dev_fullwiki_v1.json',     # source_path
    #     './processed_data/2wikimultihopqa/test_subsampled.jsonl',      # sample_path
    #     './raw_data/2wikimultihopqa/dev_list.json',                    # sample_target_path
    # )

    # --- hotpotqa, evidence joined into a single string ---
    # get_hoppotqa_gold(
    #     './raw_data/hotpotqa/hotpot_dev_distractor_v1.json',           # source_path
    #     './processed_data/hotpotqa/test_subsampled.jsonl',             # sample_path
    #     './raw_data/hotpotqa/hotpot_dev_distractor_v2.json',           # sample_original_path
    #     './raw_data/hotpotqa/dev.json',                                # sample_target_path
    # )

    # --- hotpotqa, evidence kept as a list ---
    # get_hoppotqa_gold_list(
    #     './raw_data/hotpotqa/hotpot_dev_distractor_v1.json',           # source_path
    #     './processed_data/hotpotqa/test_subsampled.jsonl',             # sample_path
    #     './raw_data/hotpotqa/dev_list.json',                           # sample_target_path
    # )

    # --- musique ---
    # get_musique_gold_and_list(
    #     './raw_data/musique/musique_ans_v1.0_dev.jsonl',               # source_path
    #     './raw_data/musique/dev.json',                                 # target_text_path
    #     './raw_data/musique/dev_list.json',                            # target_list_path
    # )

    # --- QANGAROO ---
    # get_qangaroo_gold_and_list(
    #     './raw_data/qangaroo/dev_all.json',                            # source_path
    #     './raw_data/qangaroo/dev.json',                                # target_list_path
    # )

    # --- strategyqa (active) ---
    get_strategyqa_gold_and_list(
        source_path='./raw_data/strategyqa/strategyqa_train.json',
        paragraphs_path='./raw_data/strategyqa/strategyqa_train_paragraphs.json',
        target_list_path='./raw_data/strategyqa/dev.json',
    )


