import json
import argparse
import re

def _predict_label(gpt_response, paradigm):
    """Map a lower-cased model response to 'supports', 'refutes' or 'error'.

    For the 'qa' paradigm the whole response is searched for 'true' (checked
    first) and then 'false'. For any other paradigm only the text inside the
    first <response>...</response> span is searched for 'support' (checked
    first) and then 'refutes'. Anything that cannot be parsed maps to 'error'.
    """
    if paradigm == 'qa':
        if 'true' in gpt_response:
            return 'supports'
        if 'false' in gpt_response:
            return 'refutes'
        return 'error'

    # Non-greedy + DOTALL: take the first tagged answer, even if it spans lines.
    match = re.search(r"<response>(.*?)</response>", gpt_response, re.DOTALL)
    if match is None:
        # No tagged answer: report it as unparseable instead of raising IndexError.
        return 'error'
    answer = match.group(1).strip()
    if 'support' in answer:
        return 'supports'
    if 'refutes' in answer:
        return 'refutes'
    return 'error'


def find_wrong_data(input_path, paradigm):
    """Print every record whose parsed prediction disagrees with its gold label.

    Args:
        input_path: path to a JSON file containing a list of records, each with
            string fields 'label' (gold label) and 'prediction' (model output).
        paradigm: 'qa' to parse true/false answers, or 'fc' to parse the label
            from a <response>...</response> tag.

    Returns:
        A list of (index, gold_label, prediction, predicted_label) tuples, one
        per mismatched record, in file order. These are the same rows that are
        printed; the prediction is the stripped, lower-cased response.
    """
    with open(input_path, encoding='utf-8') as f:
        datas = json.load(f)

    mismatches = []
    for index, data in enumerate(datas):
        gold_label = data['label'].strip()
        gpt_response = data['prediction'].strip().lower()
        pre_label = _predict_label(gpt_response, paradigm)
        # Predicted labels are lower-case, so compare the gold label the same way.
        # The check applies to every paradigm, not only 'fc'.
        if pre_label != gold_label.lower():
            print(index, gold_label, gpt_response, pre_label)
            mismatches.append((index, gold_label, gpt_response, pre_label))
    return mismatches


def parse_args(argv=None):
    """Parse the command-line options for this script.

    Args:
        argv: optional list of argument strings. When None, argparse reads
            sys.argv[1:] (the original behavior).

    Returns:
        An argparse.Namespace with ``data_path`` and ``paradigm`` attributes.
    """
    parser = argparse.ArgumentParser(
        description='Print records whose parsed prediction disagrees with the gold label.')
    # dataset args -- both are mandatory: without a path the script cannot open
    # anything, and without a paradigm it would silently use the 'fc' parser.
    parser.add_argument('--data_path', type=str, required=True,
                        help="JSON file holding a list of records with 'label' and 'prediction' fields")
    parser.add_argument('--paradigm', type=str, choices=['qa', 'fc'], required=True,
                        help="'qa': parse true/false answers; 'fc': parse a <response>...</response> tag")
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: read the CLI options, then report the mismatched records.
    cli_args = parse_args()
    find_wrong_data(cli_args.data_path, cli_args.paradigm)
    