import os

import fire
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint

from data_utils.dataloader import DSTDataLoader
from model.bert import BERTAlignModel

# Seed RNGs through Lightning so repeated runs are reproducible.
seed_everything(2022)

# ---- Batching / schedule -------------------------------------------------
BATCH_SIZE = 32
ACCUMULATE_GRAD_BATCH = 1   # earlier runs used 4
NUM_EPOCH = 3
MAX_STEPS = 415000          # NOTE(review): not currently passed to the Trainer
NUM_WORKERS = 8
VAL_CHECK_INTERVAL = 0.25   # i.e. 1/4; NOTE(review): not currently passed to the Trainer

# ---- Optimiser -------------------------------------------------------------
LR = 1e-5                   # earlier runs used 1e-4
WEIGHT_DECAY = 0.1          # earlier runs used 0.01
ADAM_EPSILON = 1e-6
WARMUP_PROPORTION = 0.06

# ---- Hardware ----------------------------------------------------------------
DEVICES = [5, 6]            # GPU indices handed to the Trainer

# ---- Model / checkpoint naming ----------------------------------------------
MODEL_NAME = "roberta-base"
MODEL_SAVE_NAME_COMMENT = ""  # optional prefix for checkpoint file names
NEED_MLM = False
USING_PRETRAINED = True
CKPT_SAVE_PATH = f"checkpoints/ablation_ALIGN_EVAL/{MODEL_NAME.replace('/', '-')}/"

### Ablation experiments (4 runs in total): (1) grow the data from 10k to 50k examples to show the effect of data quantity; (2) at 50k, compare tasks with labels against tasks without labels to show the effect of labels.
# DATA_SIZE = 50000
# TRAINING_DATASETS = {
#     'cnndm': {'task_type': 'summarization', 'data_path': 'data/cnndm.json', 'size': DATA_SIZE},
#     'mnli': {'task_type': 'nli', 'data_path': 'data/mnli.json', 'size': DATA_SIZE},
#     'squad': {'task_type': 'qa', 'data_path': 'data/squad.json', 'size': DATA_SIZE},
#     'paws': {'task_type': 'paraphrase', 'data_path': 'data/paws.json', 'size':DATA_SIZE},
#     'vitaminc': {'task_type': 'fact_checking', 'data_path': 'data/vitaminc.json', 'size':DATA_SIZE},
#     'xsum': {'task_type': 'summarization', 'data_path': 'data/xsum.json', 'size':DATA_SIZE},
#     'race': {'task_type': 'multiple_choice_qa', 'data_path': 'data/race.json', 'size': DATA_SIZE},
#     'anli_r1': {'task_type': 'nli', 'data_path': 'data/anli_r1.json', 'size': DATA_SIZE},
#     'anli_r2': {'task_type': 'nli', 'data_path': 'data/anli_r2.json', 'size': DATA_SIZE},
#     'anli_r3': {'task_type': 'nli', 'data_path': 'data/anli_r3.json', 'size': DATA_SIZE},
#     'snli': {'task_type': 'nli', 'data_path': 'data/snli.json', 'size': DATA_SIZE},
#     'wikihow': {'task_type': 'summarization', 'data_path': 'data/wikihow.json', 'size': DATA_SIZE},
# }
def checkpoint_run_name(model_name, using_pretrained, need_mlm, dataset_names,
                        data_size, batch_size, num_devices, accumulate_grad_batch,
                        comment=""):
    """Return the run identifier shared by periodic and final checkpoint files.

    Layout:
    ``{comment}{model}_[scratch_][no_mlm_]{datasets}_{data_size}_{batch}x{devices}x{accum}``

    ``/`` in ``model_name`` is replaced by ``-`` so the name is path-safe, and
    ``dataset_names`` (any iterable of names; a dict contributes its keys) is
    joined with ``_``.
    """
    return (
        f"{comment}{model_name.replace('/', '-')}_"
        f"{'scratch_' if not using_pretrained else ''}"
        f"{'no_mlm_' if not need_mlm else ''}"
        f"{'_'.join(dataset_names)}_{data_size}_"
        f"{batch_size}x{num_devices}x{accumulate_grad_batch}"
    )


def main(training_datasets=None, data_size=None):
    """Train a ``BERTAlignModel`` and write periodic plus final checkpoints.

    Args:
        training_datasets: mapping of dataset name -> config dict, forwarded as
            ``DSTDataLoader(dataset_config=...)``. Defaults to the module-level
            ``TRAINING_DATASETS`` (assigned by the ``__main__`` block), so a
            bare ``main()`` call behaves exactly as before.
        data_size: value embedded in checkpoint file names. Defaults to the
            module-level ``DATA_SIZE``. It is used only for naming here; the
            dataset configs carry their own ``'size'`` entries.

    Raises:
        NameError: if an argument is omitted and the matching module-level
            global has not been defined.
    """
    if training_datasets is None:
        training_datasets = TRAINING_DATASETS
    if data_size is None:
        data_size = DATA_SIZE

    dm = DSTDataLoader(dataset_config=training_datasets, model_name=MODEL_NAME, sample_mode='seq',
                       train_batch_size=BATCH_SIZE, eval_batch_size=16, num_workers=NUM_WORKERS,
                       train_eval_split=0.95, need_mlm=NEED_MLM)
    dm.setup()

    model = BERTAlignModel(model=MODEL_NAME, using_pretrained=USING_PRETRAINED,
                           adam_epsilon=ADAM_EPSILON,
                           learning_rate=LR,
                           weight_decay=WEIGHT_DECAY,
                           warmup_steps_portion=WARMUP_PROPORTION)
    model.need_mlm = NEED_MLM

    # One name for both the periodic and the final checkpoint so they cannot drift apart.
    # Standard setup noted by the authors: 16x2x8 (batch x devices x accumulation).
    run_name = checkpoint_run_name(MODEL_NAME, USING_PRETRAINED, NEED_MLM, training_datasets,
                                   data_size, BATCH_SIZE, len(DEVICES), ACCUMULATE_GRAD_BATCH,
                                   comment=MODEL_SAVE_NAME_COMMENT)
    checkpoint_callback = ModelCheckpoint(
        dirpath=CKPT_SAVE_PATH,
        # Plain (non-f) string on purpose: Lightning fills in {epoch} and {step}.
        filename=run_name + '_{epoch:02d}_{step}',
        every_n_train_steps=10000,
        # No `monitor` is set, so Lightning keeps only the most recent checkpoint.
        save_top_k=1,
    )
    trainer = Trainer(accelerator='gpu',
                      max_epochs=NUM_EPOCH,
                      # max_steps=MAX_STEPS,
                      devices=DEVICES,
                      strategy="dp",
                      precision=32,
                      callbacks=[checkpoint_callback],
                      accumulate_grad_batches=ACCUMULATE_GRAD_BATCH)

    trainer.fit(model, datamodule=dm)
    trainer.save_checkpoint(os.path.join(CKPT_SAVE_PATH, f"{run_name}_final.ckpt"))

    print("Training is finished.")

if __name__ == "__main__":
    # Value placed in every dataset's 'size' entry; main() also embeds it in checkpoint names.
    DATA_SIZE = 500000

    # Registry of every available training set, grouped by task family.
    # (paws_unlabeled is included in the paraphrase section.)
    ALL_TRAINING_DATASETS = {
        # --- NLI ---
        'mnli': {'task_type': 'nli', 'data_path': 'data/training/mnli.json', 'size': DATA_SIZE},
        'doc_nli': {'task_type': 'bin_nli', 'data_path': 'data/training/doc_nli.json', 'size': DATA_SIZE},
        'snli': {'task_type': 'nli', 'data_path': 'data/training/snli.json', 'size': DATA_SIZE},
        'anli_r1': {'task_type': 'nli', 'data_path': 'data/training/anli_r1.json', 'size': DATA_SIZE},
        'anli_r2': {'task_type': 'nli', 'data_path': 'data/training/anli_r2.json', 'size': DATA_SIZE},
        'anli_r3': {'task_type': 'nli', 'data_path': 'data/training/anli_r3.json', 'size': DATA_SIZE},

        # --- Fact checking ---
        'nli_fever': {'task_type': 'fact_checking', 'data_path': 'data/training/nli_fever.json', 'size': DATA_SIZE},
        'vitaminc': {'task_type': 'fact_checking', 'data_path': 'data/training/vitaminc.json', 'size': DATA_SIZE},

        # --- Paraphrase ---
        'paws': {'task_type': 'paraphrase', 'data_path': 'data/training/paws.json', 'size': DATA_SIZE},
        'paws_qqp': {'task_type': 'paraphrase', 'data_path': 'data/training/paws_qqp.json', 'size': DATA_SIZE},
        'paws_unlabeled': {'task_type': 'paraphrase', 'data_path': 'data/training/paws_unlabeled.json', 'size': DATA_SIZE},
        'qqp': {'task_type': 'paraphrase', 'data_path': 'data/training/qqp.json', 'size': DATA_SIZE},
        'wiki103': {'task_type': 'paraphrase', 'data_path': 'data/training/wiki103.json', 'size': DATA_SIZE},
        # 'mrpc': {'task_type': 'paraphrase', 'data_path': 'data/mrpc.json', 'size': DATA_SIZE},

        # --- QA ---
        # 'squad': {'task_type': 'qa', 'data_path': 'data/squad.json', 'size': DATA_SIZE},
        'squad_v2': {'task_type': 'qa', 'data_path': 'data/training/squad_v2_new.json', 'size': DATA_SIZE},
        'race': {'task_type': 'qa', 'data_path': 'data/training/race.json', 'size': DATA_SIZE},
        'adversarial_qa': {'task_type': 'qa', 'data_path': 'data/training/adversarial_qa.json', 'size': DATA_SIZE},
        'drop': {'task_type': 'qa', 'data_path': 'data/training/drop.json', 'size': DATA_SIZE},
        # 'duorc_paraphrase': {'task_type': 'qa', 'data_path': 'data/training/duorc_paraphrase.json', 'size': DATA_SIZE},
        # NOTE(review): the disabled 'duorc_self' entry below reuses the duorc_paraphrase
        # path, which looks like a copy-paste slip; confirm before re-enabling.
        # 'duorc_self': {'task_type': 'qa', 'data_path': 'data/training/duorc_paraphrase.json', 'size': DATA_SIZE},
        'hotpot_qa_distractor': {'task_type': 'qa', 'data_path': 'data/training/hotpot_qa_distractor.json', 'size': DATA_SIZE},
        'hotpot_qa_fullwiki': {'task_type': 'qa', 'data_path': 'data/training/hotpot_qa_fullwiki.json', 'size': DATA_SIZE},
        'newsqa': {'task_type': 'qa', 'data_path': 'data/training/newsqa.json', 'size': DATA_SIZE},
        'quoref': {'task_type': 'qa', 'data_path': 'data/training/quoref.json', 'size': DATA_SIZE},
        'ropes': {'task_type': 'qa', 'data_path': 'data/training/ropes.json', 'size': DATA_SIZE},
        'boolq': {'task_type': 'qa', 'data_path': 'data/training/boolq.json', 'size': DATA_SIZE},
        'eraser_multi_rc': {'task_type': 'qa', 'data_path': 'data/training/eraser_multi_rc.json', 'size': DATA_SIZE},
        'quail': {'task_type': 'qa', 'data_path': 'data/training/quail.json', 'size': DATA_SIZE},
        'sciq': {'task_type': 'qa', 'data_path': 'data/training/sciq.json', 'size': DATA_SIZE},
        'strategy_qa': {'task_type': 'qa', 'data_path': 'data/training/strategy_qa.json', 'size': DATA_SIZE},

        # --- Coreference ---
        'gap': {'task_type': 'coreference', 'data_path': 'data/training/gap.json', 'size': DATA_SIZE},

        # --- Summarization ---
        'wikihow': {'task_type': 'summarization', 'data_path': 'data/training/wikihow.json', 'size': DATA_SIZE},
        # 'xsum': {'task_type': 'summarization', 'data_path': 'data/xsum.json', 'size': DATA_SIZE},
        # 'cnndm': {'task_type': 'summarization', 'data_path': 'data/cnndm.json', 'size': DATA_SIZE},

        # --- Information retrieval ---
        'msmarco': {'task_type': 'ir', 'data_path': 'data/training/msmarco.json', 'size': DATA_SIZE},

        # --- STS ---
        'stsb': {'task_type': 'sts', 'data_path': 'data/training/stsb.json', 'size': DATA_SIZE},
        'sick': {'task_type': 'sts', 'data_path': 'data/training/sick.json', 'size': DATA_SIZE},

        # --- CTC ---
        # 'ctc': {'task_type': 'ctc', 'data_path': 'data/ctc.json', 'size': DATA_SIZE},
    }

    # Named task families; every name must be a key of ALL_TRAINING_DATASETS.
    NLI_group = ['mnli', 'doc_nli', 'snli', 'anli_r1', 'anli_r2', 'anli_r3']
    FV_group = ['nli_fever', 'vitaminc']
    Para_group = ['paws', 'paws_qqp', 'paws_unlabeled', 'qqp', 'wiki103']
    QA_group = ['squad_v2', 'race', 'adversarial_qa', 'drop', 'hotpot_qa_distractor',
                'hotpot_qa_fullwiki', 'newsqa', 'quoref', 'ropes', 'boolq',
                'eraser_multi_rc', 'quail', 'sciq', 'strategy_qa']
    Coref_group = ['gap']
    SUM_group = ['wikihow']
    IR_group = ['msmarco']
    STS_group = ['stsb', 'sick']

    # This run trains on NLI + fact verification + paraphrase, in group order.
    # Narrower subsets tried previously (filter the names to re-run them):
    #   ['nli_fever']  (optionally with 'doc_nli')
    #   ['boolq', 'eraser_multi_rc', 'quail', 'sciq', 'gap']
    #   ['msmarco']
    selected_names = NLI_group + FV_group + Para_group
    TRAINING_DATASETS = {name: ALL_TRAINING_DATASETS[name] for name in selected_names}
    main()

    # ---- Earlier experiments, kept for reference --------------------------
    # TRAINING_DATASETS = {
    #     'paws_unlabeled': {'task_type': 'paraphrase', 'data_path': 'data/paws_unlabeled.json', 'size': DATA_SIZE},
    #     'paws_qqp': {'task_type': 'paraphrase', 'data_path': 'data/paws_qqp.json', 'size': DATA_SIZE},
    # }
    # main()
    # TRAINING_DATASETS = {
    #     'paws_unlabeled': {'task_type': 'paraphrase', 'data_path': 'data/paws_unlabeled.json', 'size': DATA_SIZE},
    #     'paws': {'task_type': 'paraphrase', 'data_path': 'data/paws_qqp.json', 'size': DATA_SIZE},
    # }
    # main()
    # TRAINING_DATASETS = {
    #     'anli_r1': {'task_type': 'nli', 'data_path': 'data/anli_r1.json', 'size': DATA_SIZE},
    #     'anli_r2': {'task_type': 'nli', 'data_path': 'data/anli_r2.json', 'size': DATA_SIZE},
    #     'anli_r3': {'task_type': 'nli', 'data_path': 'data/anli_r3.json', 'size': DATA_SIZE},
    # }
    # main()
    # One model per single dataset:
    # for each_dataset in ALL_TRAINING_DATASETS:
    #     if each_dataset not in ['mnli', 'vitaminc', 'race']:
    #     # or: ['snli', 'qqp', 'stsb', 'sick', 'mrpc'] / ['doc_nli']
    #         continue
    #     TRAINING_DATASETS = {each_dataset: ALL_TRAINING_DATASETS[each_dataset]}
    #     main()


