from functools import partial
import fire
import glob
import random
import os
import json
import tqdm
import multiprocessing

from datasets import load_dataset, get_dataset_split_names
from promptsource.templates import DatasetTemplates


def get_dataset_keys(dataset_name):
    """Turn a flattened dataset identifier into a list of dataset keys.

    The returned list is unpacked as the positional arguments of
    ``load_dataset`` / ``get_dataset_split_names`` elsewhere in this file.

    Rules, in order:
      * every ``anli*`` identifier maps to the single key ``'anli'``;
      * ``'paws_x-en'`` maps to ``['paws-x', 'en']``;
      * anything else is cut at each ``'-'`` (an identifier without a dash
        therefore yields a one-element list).
    """
    if dataset_name.startswith('anli'):
        return ['anli']

    if dataset_name == "paws_x-en":
        return ['paws-x', 'en']

    # str.split returns [dataset_name] unchanged when no '-' is present, so
    # the dashed and dash-free cases share a single code path.
    return dataset_name.split('-')


def get_promptsource_key(dataset_name):
    """Map a flattened dataset identifier to the key given to DatasetTemplates.

    ``anli*`` identifiers collapse to ``'anli'``, ``'paws_x-en'`` becomes
    ``'paws-x/en'``, and every other identifier has each ``'-'`` replaced by
    ``'/'``.
    """
    if dataset_name.startswith('anli'):
        return 'anli'

    is_paws_x_en = dataset_name == 'paws_x-en'
    return 'paws-x/en' if is_paws_x_en else dataset_name.replace('-', '/')


def get_split_name(dataset_name, split):
    """Resolve the concrete split name to load for ``dataset_name``.

    Args:
        dataset_name: Flattened dataset identifier (e.g. ``'super_glue-cb'``).
        split: ``'train'`` to request the training split; any other value
            requests the evaluation split.

    Returns:
        ``'train'`` for training requests. For evaluation requests:
        ``'dev_<suffix>'`` for ``anli*`` identifiers (suffix is the text after
        the last ``'-'``), ``'test'`` for ``crows_pairs`` and
        ``super_glue-axg``, otherwise the first split reported for the dataset
        whose name starts with ``'dev'`` or ``'val'``.

    Raises:
        ValueError: if a training split is requested but the dataset does not
            report a split named ``'train'``, or if no dev/validation split
            can be found for an evaluation request.
    """
    dataset_keys = get_dataset_keys(dataset_name)

    if split == 'train':
        # Explicit check rather than `assert`, so the validation still runs
        # when Python is started with -O.
        if 'train' not in get_dataset_split_names(*dataset_keys):
            raise ValueError(
                f"Dataset {dataset_name!r} has no 'train' split.")
        return 'train'

    if dataset_name.startswith('anli'):
        return '_'.join(['dev', dataset_name.split('-')[-1]])

    # These two datasets are evaluated on their 'test' split.
    if dataset_name in ['crows_pairs', 'super_glue-axg']:
        return 'test'

    for split_name in get_dataset_split_names(*dataset_keys):
        if split_name.startswith(('dev', 'val')):
            return split_name

    raise ValueError(
        f'No dev/validation split found for dataset {dataset_name!r}.')


def process_example(
        example_idx, dataset, template, dataset_name, template_name):
    """Render a single dataset example with a prompt template.

    Args:
        example_idx: Index of the example inside ``dataset``.
        dataset: Indexable collection of examples.
        template: Object exposing ``apply(example)``, which must yield an
            ``(instruction, answer)`` pair, and
            ``get_answer_choices_list(example)``.
        dataset_name: Stored verbatim in the output record.
        template_name: Stored verbatim in the output record.

    Returns:
        A dict with keys ``instruction``, ``references``, ``responses``,
        ``answer_choices``, ``dataset_name``, ``template_name`` and
        ``example_idx``; or ``None`` when the template raises while being
        applied (including when ``apply`` does not yield exactly two items).
        ``responses`` holds ``[text, label]`` pairs: one per answer choice
        (label 1 iff the choice equals the answer), or just ``[answer, 1]``
        when the template has no answer choices.
    """
    example = dataset[example_idx]

    try:
        instruction, answer = template.apply(example)
        answer_choices = template.get_answer_choices_list(example)
    except Exception:
        # Best effort: examples a template cannot render are skipped.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate so workers can still be stopped.
        return None

    if answer_choices is not None:
        responses = [
            [answer_choice, int(answer_choice == answer)]
            for answer_choice in answer_choices]
    else:
        responses = [[answer, 1]]

    return {
        'instruction': instruction,
        'references': [answer],
        'responses': responses,
        'answer_choices': answer_choices,
        'dataset_name': dataset_name,
        'template_name': template_name,
        'example_idx': example_idx
    }


def process_dataset(dataset_name, split, output_dir, max_examples=500000):
    """Render every prompt template of a dataset into per-template JSONL files.

    For each template of ``dataset_name`` a file
    ``{output_dir}/{dataset_name}/{template_name}.jsonl`` is written (spaces
    and slashes in the template name become underscores), one JSON record per
    successfully rendered example. Templates whose output file already exists
    are skipped.

    Args:
        dataset_name: Flattened dataset identifier (e.g. ``'super_glue-cb'``).
        split: ``'train'`` or an evaluation split request; resolved through
            ``get_split_name``.
        output_dir: Root directory for the generated files.
        max_examples: Upper bound on how many examples (taken from the start
            of the dataset) are rendered per template.
    """
    os.makedirs(f'{output_dir}/{dataset_name}', exist_ok=True)

    if dataset_name == 'story_cloze-2016':
        # story_cloze is loaded from the local 'story_cloze_2016' data_dir.
        dataset = load_dataset(
            'story_cloze',
            '2016',
            data_dir='story_cloze_2016',
            split=get_split_name(dataset_name, split))
    else:
        dataset = load_dataset(
            *get_dataset_keys(dataset_name),
            split=get_split_name(dataset_name, split),
            cache_dir=f'./cache/{dataset_name}')

    # Loop-invariant: the same index range is used for every template.
    n_examples = min(len(dataset), max_examples)

    templates = DatasetTemplates(get_promptsource_key(dataset_name))
    for template_name in tqdm.tqdm(templates.all_template_names,
                                   desc=f'{dataset_name} ({len(dataset)})'):
        template = templates[template_name]

        template_name = template_name.replace(' ', '_').replace('/', '_')
        output_path = f'{output_dir}/{dataset_name}/{template_name}.jsonl'

        if os.path.exists(output_path):
            continue

        template_process_fn = partial(
            process_example,
            dataset=dataset,
            template=template,
            dataset_name=dataset_name,
            template_name=template_name)

        # Write to a temporary file and publish it only once complete.
        # Otherwise an interrupted run would leave a truncated .jsonl behind
        # that the existence check above would skip forever on later runs.
        tmp_path = f'{output_path}.tmp'
        with open(tmp_path, 'w') as output_file, \
                multiprocessing.Pool(processes=os.cpu_count()) as pool:
            # imap_unordered: records are written in completion order, not
            # dataset order; each record carries its own 'example_idx'.
            for processed_example in pool.imap_unordered(
                    template_process_fn, range(n_examples)):
                if processed_example is not None:
                    output_file.write(json.dumps(processed_example) + '\n')

        os.replace(tmp_path, output_path)


def main(output_dir='./processed_t0_data', gen_data_dir='./t0_gen_data'):
    """Collect generative examples into chunked JSONL files.

    Reads the dataset names listed under ``'T0'`` in
    ``data_utils/split.json``. For every ``*.jsonl`` file under
    ``{output_dir}/{dataset_name}/`` it takes the leading run of records whose
    ``answer_choices`` is ``None`` (reading stops at the first record that has
    answer choices), drops their ``responses`` field and appends them to
    ``{gen_data_dir}/chunk{N}.jsonl``. A new chunk file is started after every
    1,000,000 written lines.

    Args:
        output_dir: Directory holding the per-dataset processed JSONL files.
        gen_data_dir: Directory that receives the chunk files (created if
            missing).
    """
    # --- Disabled stage: render every dataset with process_dataset. ---
    # for dataset_name in json.load(open('data_utils/split.json'))['T0']:
    #         process_dataset(dataset_name, split='train', output_dir=output_dir)
    #
    # for dataset_name in json.load(open('data_utils/split.json'))['evaluation']:
    #     process_dataset(dataset_name, split='validation', output_dir=output_dir)

    # --- Disabled stage: add a random negative response to examples ---
    # --- that have no answer choices.                               ---
    # for dataset_name in json.load(open('data_utils/split.json'))['T0']:
    #     for filename in glob.glob(f'{output_dir}/{dataset_name}/*.jsonl'):
    #         examples = []
    #         for line in open(filename):
    #             example = json.loads(line)
    #             if example['answer_choices'] is None:
    #                 examples.append(example)
    #             else:
    #                 break
    #
    #         if len(examples) > 0:
    #             with open(filename, 'w') as output_file:
    #                 for example in examples:
    #                     assert len(example['responses']) == 1
    #                     while True:
    #                         random_idx = random.randint(0, len(examples) - 1)
    #                         hypo_neg = random.choice(
    #                             examples[random_idx]['references'])
    #
    #                         if hypo_neg not in example['references']:
    #                             example['responses'].append([hypo_neg, 0])
    #                             break
    #
    #                     output_file.write(json.dumps(example) + '\n')
    #
    #             print(f'{filename} updated with negative responses.')

    lines_per_chunk = 1000000

    os.makedirs(gen_data_dir, exist_ok=True)
    gen_t0_file = open(f'{gen_data_dir}/chunk0.jsonl', 'w')
    n_lines, cur_idx = 0, 0
    try:
        with open('data_utils/split.json') as split_file:
            t0_dataset_names = json.load(split_file)['T0']

        for dataset_name in t0_dataset_names:
            for filename in glob.glob(f'{output_dir}/{dataset_name}/*.jsonl'):
                examples = []
                with open(filename) as input_file:
                    for line in input_file:
                        example = json.loads(line)
                        if example['answer_choices'] is not None:
                            # Only the leading run of choice-free records
                            # is collected from each file.
                            break
                        examples.append(example)

                if not examples:
                    continue

                for example in examples:
                    example.pop('responses')
                    gen_t0_file.write(json.dumps(example) + '\n')
                    n_lines += 1
                    if n_lines == lines_per_chunk:
                        # Rotate: close the full chunk before opening the
                        # next one so its handle is not leaked.
                        cur_idx += 1
                        gen_t0_file.close()
                        gen_t0_file = \
                            open(f'{gen_data_dir}/chunk{cur_idx}.jsonl', 'w')
                        n_lines = 0

                print(f'{filename}: {len(examples)} generative examples.')
    finally:
        # Always flush and close the chunk currently being written.
        gen_t0_file.close()


# Script entry point: hand `main` to python-fire, which builds the
# command-line interface from the function it is given.
if __name__ == '__main__':
    fire.Fire(main)
