import json


def get_sci_data(source_path, sample_path, corpus_path,  result_path):
    """Join sampled claims with the abstracts of their evidence documents.

    For every sample, the claim record whose ``id`` equals the sample's
    ``uid`` is looked up, the corpus documents named by the keys of that
    record's ``evidence`` mapping are collected, and each document's
    ``abstract`` (a list of strings) is joined with single spaces.  Samples
    with no matching claim, or whose evidence resolves to no corpus
    document, are left out of the output.

    Args:
        source_path: JSON-lines file of claim records; each record is read
            for its ``id`` and ``evidence`` keys.
        sample_path: JSON file holding a list of samples; each sample is
            read for its ``uid``, ``claim`` and ``label`` keys.
        corpus_path: JSON-lines file of corpus documents; each document is
            read for its ``doc_id`` and ``abstract`` keys.
        result_path: Destination file.  A JSON list of
            ``{'id', 'claim', 'label', 'evidence'}`` records is written
            with an indent of 4.

    Raises:
        OSError: If any of the files cannot be opened.
        json.JSONDecodeError: If a non-blank input line is not valid JSON.
        KeyError: If a record lacks one of the keys listed above.
    """
    with open(sample_path, encoding="utf-8") as sample_file:
        sample_datas = json.load(sample_file)

    # Parse every claim once instead of once per sample.  ``setdefault``
    # keeps the first record for a duplicated id, matching a first-hit scan.
    claims_by_id = {}
    with open(source_path, encoding="utf-8") as source_file:
        for line in source_file:
            if not line.strip():
                continue  # tolerate blank lines (e.g. a trailing newline)
            data = json.loads(line)
            claims_by_id.setdefault(data['id'], data)

    # Index corpus positions by stringified doc_id: evidence keys come from
    # JSON object keys and are therefore always strings.
    corpus = []
    positions_by_doc_id = {}
    with open(corpus_path, encoding="utf-8") as corpus_file:
        for line in corpus_file:
            if not line.strip():
                continue
            doc = json.loads(line)
            positions_by_doc_id.setdefault(str(doc['doc_id']), []).append(len(corpus))
            corpus.append(doc)

    new_datas = []
    for sample in sample_datas:
        sample_id = sample['uid']
        claim = sample['claim']
        label = sample['label']

        data = claims_by_id.get(sample_id)
        if data is None:
            # No claim for this sample: skip it rather than reuse evidence
            # left over from a previous sample.
            continue

        # Sorting the positions yields the evidence in corpus-file order.
        positions = sorted(
            position
            for doc_id in data['evidence']
            for position in positions_by_doc_id.get(doc_id, ())
        )
        corpus_evidence = [' '.join(corpus[position]['abstract']) for position in positions]
        if corpus_evidence:
            new_datas.append({'id': sample_id, 'claim': claim, 'label': label, 'evidence': corpus_evidence})

    with open(result_path, 'w', encoding="utf-8") as result_file:
        json.dump(new_datas, result_file, indent=4)


if __name__ == "__main__":
    # Script entry point: build the dev split from the raw claims, the
    # sampled subset, and the abstract corpus.
    get_sci_data(
        source_path='./raw/claims.jsonl',
        sample_path='./processed/df_new.json',
        corpus_path='./source/corpus.jsonl',
        result_path='./datasets/SCIFACT/source/dev.json',
    )