# -*- coding:utf-8 -*- 
import argparse
import glob
import logging
import os
import random
import copy
import math
import json
import numpy as np
import torch
from torch.nn import CrossEntropyLoss, KLDivLoss, NLLLoss, BCEWithLogitsLoss
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
import torch.nn.functional as F
from tqdm import tqdm, trange
import sys
import pickle as pkl

from transformers import (
    WEIGHTS_NAME,
    AdamW,
    RobertaConfig,
    RobertaForTokenClassification,
    RobertaTokenizer,
    get_linear_schedule_with_warmup,
)

from models.modeling_roberta_debias_bin import RobertaForTokenClassification_Modified
from utils.data_utils import load_and_cache_examples, get_labels, tag_to_id
from utils.model_utils import mask_tokens, mask_bitokens, soft_frequency, opt_grad, get_hard_label, _update_mean_model_variables
from utils.eval import evaluate, evaluate_ori
from utils.config import config
from utils.loss_utils import NegEntropy, GCELoss, WorstCaseEstimationLoss

# Module-level logger; main() attaches a file handler to it.
logger = logging.getLogger(__name__)

# Display name per role key. validation() looks the role key up here, so every
# role passed as `tors` must have an entry.
MODEL_NAMES = {
    "student1":"Roberta",
    "student2":"DistilRoberta",
    "teacher1":"Roberta",
    "teacher2":"DistilRoberta"
}
# (config class, model class, tokenizer class) per role key. Only the student
# roles are registered; initialize() reuses the student entries when it builds
# the teacher models.
MODEL_CLASSES = {
    "student1": (RobertaConfig, RobertaForTokenClassification_Modified, RobertaTokenizer),
    "student2": (RobertaConfig, RobertaForTokenClassification_Modified, RobertaTokenizer),
}

# Print tensors in full instead of the truncated summary form.
torch.set_printoptions(profile="full")
# Import-time default device. main() computes its own device (honouring
# --no_cuda / local_rank) and stores it in args.device, which may differ.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def set_seed(args):
    """Seed the Python, NumPy and PyTorch random generators from ``args.seed``.

    The CUDA generators are additionally seeded when ``args.n_gpu`` is positive.
    """
    seed_value = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed_value)

def _load_roberta_tagger(args, role, config_name, model_path, num_labels):
    """Build the config and model registered under ``role`` and move the model to ``args.device``."""
    config_class, model_class, _ = MODEL_CLASSES[role]
    cache_dir = args.cache_dir if args.cache_dir else None
    tagger_config = config_class.from_pretrained(
        config_name if config_name else model_path,
        num_labels=num_labels,
        cache_dir=cache_dir,
    )
    tagger = model_class.from_pretrained(
        model_path,
        from_tf=bool(".ckpt" in model_path),
        config=tagger_config,
        cache_dir=cache_dir,
    )
    tagger.to(args.device)
    return tagger


def _make_optimizer_and_scheduler(args, model, t_total):
    """Create AdamW (no weight decay on bias / LayerNorm weights) and its linear warmup schedule."""
    no_decay = ["bias", "LayerNorm.weight"]
    with_decay, without_decay = [], []
    for name, param in model.named_parameters():
        bucket = without_decay if any(nd in name for nd in no_decay) else with_decay
        bucket.append(param)
    optimizer = AdamW(
        [
            {"params": with_decay, "weight_decay": args.weight_decay},
            {"params": without_decay, "weight_decay": 0.0},
        ],
        lr=args.learning_rate,
        eps=args.adam_epsilon,
        betas=(args.adam_beta1, args.adam_beta2),
    )
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )
    return optimizer, scheduler


def initialize(args, t_total, num_labels, epoch):
    """Create the two student and two teacher models plus the student optimizers.

    Each teacher is loaded from the same pretrained path and config as the
    student with the matching index. Only the students get an optimizer and a
    scheduler; the teachers' parameters are detached before returning.

    Args:
        args: run configuration (model paths, device, optimizer settings,
            fp16 / multi-GPU / distributed flags).
        t_total: total number of optimizer steps, used by the LR schedulers.
        num_labels: number of token labels for the classification heads.
        epoch: accepted for interface compatibility; not used here.

    Returns:
        ``(model_s1, model_s2, model_t1, model_t2, optimizer_s1, scheduler_s1,
        optimizer_s2, scheduler_s2)``.

    Raises:
        ImportError: if ``args.fp16`` is set and apex is not installed.
    """
    student_specs = (
        ("student1", args.student1_config_name, args.student1_model_name_or_path),
        ("student2", args.student2_config_name, args.student2_model_name_or_path),
    )
    # Load order is student1, student2, teacher1, teacher2; the teachers reuse
    # the student specs.
    model_s1, model_s2, model_t1, model_t2 = [
        _load_roberta_tagger(args, role, config_name, model_path, num_labels)
        for role, config_name, model_path in student_specs * 2
    ]

    optimizer_s1, scheduler_s1 = _make_optimizer_and_scheduler(args, model_s1, t_total)
    optimizer_s2, scheduler_s2 = _make_optimizer_and_scheduler(args, model_s2, t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        [model_s1, model_s2, model_t1, model_t2], [optimizer_s1, optimizer_s2] = amp.initialize(
            [model_s1, model_s2, model_t1, model_t2], [optimizer_s1, optimizer_s2], opt_level=args.fp16_opt_level)

    # Multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model_s1, model_s2, model_t1, model_t2 = [
            torch.nn.DataParallel(network)
            for network in (model_s1, model_s2, model_t1, model_t2)
        ]

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model_s1, model_s2, model_t1, model_t2 = [
            torch.nn.parallel.DistributedDataParallel(
                network, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
            )
            for network in (model_s1, model_s2, model_t1, model_t2)
        ]

    for network in (model_s1, model_s2, model_t1, model_t2):
        network.zero_grad()

    # Teachers are not optimized directly, so their parameters are detached.
    for teacher in (model_t1, model_t2):
        for param in teacher.parameters():
            param.detach_()

    return model_s1, model_s2, model_t1, model_t2, optimizer_s1, scheduler_s1, optimizer_s2, scheduler_s2

def save_model(args, epoch, tors, model):
    """Save ``model`` under ``<args.output_dir + tors>/<tors>_ep<epoch>``.

    The directory is created when missing. If ``model`` exposes a ``module``
    attribute (parallel wrappers), the wrapped module is what gets saved.
    """
    checkpoint_name = "".join((tors, "_ep", str(epoch)))
    checkpoint_dir = os.path.join(args.output_dir + tors, checkpoint_name)
    logger.info("Saving model checkpoint to %s", checkpoint_dir)
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    unwrapped = getattr(model, "module", model)
    unwrapped.save_pretrained(checkpoint_dir)

def load_model(args, epoch, tors):
    """Load the checkpoint ``<args.output_dir + tors>/<tors>_ep<epoch>`` onto ``args.device``.

    Returns the restored ``RobertaForTokenClassification_Modified`` instance.
    """
    checkpoint_name = "".join((tors, "_ep", str(epoch)))
    checkpoint_dir = os.path.join(args.output_dir + tors, checkpoint_name)
    restored = RobertaForTokenClassification_Modified.from_pretrained(checkpoint_dir)
    restored.to(args.device)
    return restored

def validation(args, model, tokenizer, labels, pad_token_label_id, best_dev, best_test,
               global_step, t_total, epoch, tors):
    """Evaluate ``model`` on the dev and test splits and checkpoint on improvement.

    ``evaluate_ori`` is run for "dev" and then "test". On rank -1/0, an
    improved dev score saves model and tokenizer to ``checkpoint-best-1`` and
    an improved test score saves them to ``checkpoint-best-2``, both under
    ``args.output_dir + tors``.

    Returns:
        ``(best_dev, best_test, dev_improved)`` where the first two are the
        updated best scores returned by ``evaluate_ori`` and the last is the
        dev split's "was updated" flag.
    """
    # Result is unused; the lookup is kept so an unknown role key still raises KeyError.
    MODEL_NAMES[tors].lower()

    dev_prefix = 'dev [Step {}/{} | Epoch {}/{}]'.format(global_step, t_total, epoch, args.num_train_epochs)
    _, _, best_dev, dev_improved = evaluate_ori(
        args, model, tokenizer, labels, pad_token_label_id, best_dev,
        mode="dev", logger=logger, prefix=dev_prefix, verbose=False,
    )
    test_prefix = 'test [Step {}/{} | Epoch {}/{}]'.format(global_step, t_total, epoch, args.num_train_epochs)
    _, _, best_test, test_improved = evaluate_ori(
        args, model, tokenizer, labels, pad_token_label_id, best_test,
        mode="test", logger=logger, prefix=test_prefix, verbose=False,
    )

    if args.local_rank in [-1, 0]:
        slots = ((dev_improved, "checkpoint-best-1"), (test_improved, "checkpoint-best-2"))
        for improved, slot_name in slots:
            if not improved:
                continue
            target = os.path.join(args.output_dir+tors, slot_name)
            logger.info("Saving model checkpoint to %s", target)
            if not os.path.exists(target):
                os.makedirs(target)
            # Take care of distributed/parallel training
            getattr(model, "module", model).save_pretrained(target)
            tokenizer.save_pretrained(target)
    return best_dev, best_test, dev_improved

def get_teacher(args, model_t1, model_t2, t_model1, t_model2, dev_is_updated1, dev_is_updated2, batch=True):
    """Return the pair of teacher snapshots to use for the next steps.

    For the "conll03" and "wikigold" datasets with ``batch`` truthy, a snapshot
    is replaced by a deep copy of its mean-teacher model only when the
    corresponding ``dev_is_updated`` flag is set; otherwise the snapshot passed
    in is returned unchanged. In every other case both snapshots are replaced.
    """
    selective = args.dataset in ("conll03", "wikigold") and batch
    if not selective or dev_is_updated1:
        t_model1 = copy.deepcopy(model_t1)
    if not selective or dev_is_updated2:
        t_model2 = copy.deepcopy(model_t2)
    return t_model1, t_model2

def _build_finetune_optimizer(args, model, t_total):
    """Create the AdamW optimizer and linear warmup schedule used to fine-tune ``model``.

    Bias and LayerNorm weights are excluded from weight decay, and the learning
    rate is one tenth of ``args.learning_rate``.
    """
    no_decay = ["bias", "LayerNorm.weight"]
    grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(grouped_parameters, lr=args.learning_rate/10,
                      eps=args.adam_epsilon, betas=(args.adam_beta1, args.adam_beta2))
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )
    return optimizer, scheduler


def _selected_type_loss(type_loss_fct, type_wce_loss, type_logits, type_logits_adv, teacher_types, selected):
    """Cross-entropy on the selected tokens plus the worst-case term when both partitions are non-empty."""
    unselected = ~selected
    loss = type_loss_fct(type_logits[selected], teacher_types[selected])
    if selected.sum().item() and unselected.sum().item():
        loss = loss + type_wce_loss(type_logits[selected], type_logits_adv[selected],
                                    type_logits[unselected], type_logits_adv[unselected])
    return loss


def finetune(args, ep, train_dataset, tokenizer, labels, pad_token_label_id):
    """Fine-tune the two students saved at epoch ``ep`` on teacher-selected type labels.

    The student and mean-teacher checkpoints tagged ``ep`` are reloaded with
    ``load_model``. At every step, frozen teacher snapshots (``t_model1`` /
    ``t_model2``) predict entity types; tokens that the snapshot marks as
    entities and whose predicted type agrees with the pseudo label are used for
    a cross-entropy loss against the snapshot's type predictions, and a
    worst-case estimation term is added when both selected and unselected
    tokens exist. The mean teachers are updated after each optimizer step with
    ``_update_mean_model_variables``, and the snapshots are refreshed through
    ``get_teacher`` after each evaluation.

    Args:
        args: run configuration (batch size, optimizer settings, logging and
            evaluation flags, device, output directory).
        ep: epoch tag of the checkpoints to load.
        train_dataset: dataset whose batches provide, by position, input ids,
            attention mask, valid-position mask and pseudo labels.
        tokenizer: tokenizer saved alongside the best checkpoints.
        labels: label list forwarded to evaluation.
        pad_token_label_id: label id forwarded to ``mask_tokens`` and
            evaluation.

    Returns:
        Always ``0``.

    Raises:
        ImportError: if ``args.fp16`` is set and apex is not installed.
    """
    if args.fp16:
        # `amp` is used by the fp16 branches below and must be in scope here.
        # NOTE(review): unlike initialize(), the models/optimizers created in
        # this function are not passed through amp.initialize; apex presumably
        # requires that before amp.scale_loss -- confirm before relying on fp16.
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank==-1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps//(len(train_dataloader)//args.gradient_accumulation_steps)+1
    else:
        t_total = len(train_dataloader)//args.gradient_accumulation_steps*args.num_train_epochs

    model_s1 = load_model(args, ep, "student1")
    model_s2 = load_model(args, ep, "student2")
    model_t1 = load_model(args, ep, "teacher1")
    model_t2 = load_model(args, ep, "teacher2")

    optimizer_s1, scheduler_s1 = _build_finetune_optimizer(args, model_s1, t_total)
    optimizer_s2, scheduler_s2 = _build_finetune_optimizer(args, model_s2, t_total)

    logger.info("***** Fine tuning *****")
    global_step = 0
    epochs_trained = 0
    tr_loss = 0.0
    # NOTE(review): the loop always runs 10 epochs; args.num_train_epochs only
    # feeds t_total and the log prefixes -- confirm this is intended.
    train_iterator = trange(
        epochs_trained, 10, desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    set_seed(args)  # Added here for reproducibility
    s1_best_dev, s1_best_test = [0, 0, 0], [0, 0, 0]
    s2_best_dev, s2_best_test = [0, 0, 0], [0, 0, 0]
    t1_best_dev, t1_best_test = [0, 0, 0], [0, 0, 0]
    t2_best_dev, t2_best_test = [0, 0, 0], [0, 0, 0]

    # The initial label-producing snapshots are copies of the students.
    t_model1 = copy.deepcopy(model_s1)
    t_model2 = copy.deepcopy(model_s2)

    for epoch in train_iterator:
        for step, batch in enumerate(train_dataloader):
            model_s1.train()
            model_s2.train()
            model_t1.train()
            model_t2.train()

            batch = tuple(t.to(args.device) for t in batch)
            valid_pos = batch[2]
            pseudo_labels1 = batch[3][valid_pos>0]
            pseudo_labels2 = batch[3][valid_pos>0]

            # Type labels are the pseudo labels shifted down by one; anything
            # that becomes negative is mapped to -100.
            type_pseudo_labels1, type_pseudo_labels2 = pseudo_labels1-1, pseudo_labels2-1
            type_pseudo_labels1[type_pseudo_labels1<0] = -100
            # Mask each tensor with its own condition (the second one used to
            # be masked with the first tensor's condition).
            type_pseudo_labels2[type_pseudo_labels2<0] = -100
            type_pos1 = pseudo_labels1>0
            type_pos2 = pseudo_labels2>0

            inputs = {"input_ids": batch[0], "attention_mask": batch[1], "valid_pos":batch[2]}
            with torch.no_grad():
                t_type_logits1, _, t_type_pred1, t_bin_pred1, t_logits1 = t_model1(**inputs)
                t_type_logits2, _, t_type_pred2, t_bin_pred2, t_logits2 = t_model2(**inputs)
                pred_labels1 = torch.argmax(t_logits1, dim=-1)
                pred_labels2 = torch.argmax(t_logits2, dim=-1)

                entity_pred1 = (t_bin_pred1==1)
                entity_pred2 = (t_bin_pred2==1)

                # Keep tokens the snapshot tags as entities and whose predicted
                # type agrees with the pseudo label.
                type_label_mask1 = entity_pred1&(t_type_pred1==type_pseudo_labels1)
                type_label_mask2 = entity_pred2&(t_type_pred2==type_pseudo_labels2)

            soft_logits1 = soft_frequency(logits=t_type_logits1, power=2)
            soft_logits2 = soft_frequency(logits=t_type_logits2, power=2)

            if args.self_learning_label_mode == "hard":
                # Results are not consumed below; the calls are kept as in the
                # original flow.
                pred_labels1, _ = mask_tokens(args, batch[3], pred_labels1, pad_token_label_id, pred_logits=soft_logits1)
                pred_labels2, _ = mask_tokens(args, batch[3], pred_labels2, pad_token_label_id, pred_logits=soft_logits2)

            student_inputs = {"input_ids": batch[0], "attention_mask": batch[1], "valid_pos": batch[2], "train":True}
            type_logits1, _, type_logits_adv1, _, _, _, _, _, _ = model_s1(**student_inputs)
            type_logits2, _, type_logits_adv2, _, _, _, _, _, _ = model_s2(**student_inputs)

            # Fall back to every pseudo-labelled entity token when the
            # agreement filter selects nothing.
            if type_label_mask1.sum().item()==0:
                type_label_mask1 = type_pos1
            if type_label_mask2.sum().item()==0:
                type_label_mask2 = type_pos2

            type_filtered_sel1 = type_label_mask1.view(-1)
            type_filtered_sel2 = type_label_mask2.view(-1)

            type_loss_fct = CrossEntropyLoss()
            # Use the run's device rather than the import-time module default,
            # which can differ (e.g. with --no_cuda).
            type_wce_loss = WorstCaseEstimationLoss(2).to(args.device)

            loss1 = _selected_type_loss(type_loss_fct, type_wce_loss, type_logits1, type_logits_adv1,
                                        t_type_pred1, type_filtered_sel1)
            loss2 = _selected_type_loss(type_loss_fct, type_wce_loss, type_logits2, type_logits_adv2,
                                        t_type_pred2, type_filtered_sel2)

            if args.n_gpu > 1:
                loss1 = loss1.mean()
                loss2 = loss2.mean()
            if args.gradient_accumulation_steps > 1:
                loss1 = loss1/args.gradient_accumulation_steps
                loss2 = loss2/args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss1, optimizer_s1) as scaled_loss1:
                    scaled_loss1.backward()
                with amp.scale_loss(loss2, optimizer_s2) as scaled_loss2:
                    scaled_loss2.backward()
            else:
                # Sum AFTER the scaling above so gradient accumulation is
                # honoured on this path exactly as on the fp16 path.
                (loss1 + loss2).backward()

            tr_loss += loss1.item()+loss2.item()
            if (step+1)%args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer_s1), args.max_grad_norm)
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer_s2), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model_s1.parameters(), args.max_grad_norm)
                    torch.nn.utils.clip_grad_norm_(model_s2.parameters(), args.max_grad_norm)

                optimizer_s1.step()
                scheduler_s1.step()  # Update learning rate schedule
                optimizer_s2.step()
                scheduler_s2.step()  # Update learning rate schedule
                model_s1.zero_grad()
                model_s2.zero_grad()
                global_step += 1

                _update_mean_model_variables(model_s1, model_t1, args.mean_alpha, global_step)
                _update_mean_model_variables(model_s2, model_t2, args.mean_alpha, global_step)
                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step%args.logging_steps == 0:
                    if args.evaluate_during_training:
                        logger.info("***** Student1 combined Entropy loss : %.4f *****", loss1.item())
                        logger.info("##### Student1 #####")
                        s1_best_dev, s1_best_test, _ = validation(args, model_s1, tokenizer, labels, pad_token_label_id, \
                            s1_best_dev, s1_best_test, global_step, t_total, epoch, "student1")
                        logger.info("##### Teacher1 #####")
                        t1_best_dev, t1_best_test, dev_is_updated1 = validation(args, model_t1, tokenizer, labels, pad_token_label_id, \
                            t1_best_dev, t1_best_test, global_step, t_total, epoch, "teacher1")
                        logger.info("***** Student2 combined Entropy loss : %.4f *****", loss2.item())
                        logger.info("##### Student2 #####")
                        s2_best_dev, s2_best_test, _ = validation(args, model_s2, tokenizer, labels, pad_token_label_id, \
                            s2_best_dev, s2_best_test, global_step, t_total, epoch, "student2")
                        logger.info("##### Teacher2 #####")
                        t2_best_dev, t2_best_test, dev_is_updated2 = validation(args, model_t2, tokenizer, labels, pad_token_label_id, \
                            t2_best_dev, t2_best_test, global_step, t_total, epoch, "teacher2")
                        t_model1, t_model2 = get_teacher(args, model_t1, model_t2, t_model1, t_model2, dev_is_updated1, dev_is_updated2)

            if args.max_steps > 0 and global_step > args.max_steps:
                # The iterable here is a plain DataLoader, which has no
                # close(); leaving the loop is all that is needed.
                break

        logger.info("***** Epoch : %d *****", epoch)
        logger.info("##### Student1 #####")
        s1_best_dev, s1_best_test, _ = validation(args, model_s1, tokenizer, labels, pad_token_label_id, \
            s1_best_dev, s1_best_test, global_step, t_total, epoch, "student1")
        logger.info("##### Teacher1 #####")
        t1_best_dev, t1_best_test, dev_is_updated1 = validation(args, model_t1, tokenizer, labels, pad_token_label_id, \
            t1_best_dev, t1_best_test, global_step, t_total, epoch, "teacher1")
        logger.info("##### Student2 #####")
        s2_best_dev, s2_best_test, _ = validation(args, model_s2, tokenizer, labels, pad_token_label_id, \
            s2_best_dev, s2_best_test, global_step, t_total, epoch, "student2")
        logger.info("##### Teacher2 #####")
        t2_best_dev, t2_best_test, dev_is_updated2 = validation(args, model_t2, tokenizer, labels, pad_token_label_id, \
            t2_best_dev, t2_best_test, global_step, t_total, epoch, "teacher2")
        t_model1, t_model2 = get_teacher(args, model_t1, model_t2, t_model1, t_model2, dev_is_updated1, dev_is_updated2, True)

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break
    return 0

def main():
    """Entry point: parse the configuration, set up devices and logging, then fine-tune.

    Steps, in order:
      1. Refuse to run when training into a non-empty output directory
         without ``--overwrite_output_dir``.
      2. Create the output directory.
      3. Select the device (single process, or one GPU per distributed rank
         with an NCCL process group) and store it in ``args.device``.
      4. Configure console and file logging (``log.txt`` in the output
         directory).
      5. Seed the RNGs, read the label set and load the tokenizer.
      6. When ``args.do_train`` is set, load the training data and run
         ``finetune``.

    Raises:
        ValueError: if the output directory already has content, training was
            requested and overwriting was not allowed.
    """
    args = config()
    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )

    # Create output directory if needed. Every rank does this because every
    # rank opens log.txt inside it below; exist_ok makes concurrent creation
    # by several processes safe.
    os.makedirs(args.output_dir, exist_ok=True)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:
        # Distributed: one process per GPU.
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s -   %(message)s", "%m/%d/%Y %H:%M:%S")
    logging_fh = logging.FileHandler(os.path.join(args.output_dir, 'log.txt'))
    logging_fh.setLevel(logging.DEBUG)
    logging_fh.setFormatter(formatter)
    logger.addHandler(logging_fh)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )

    set_seed(args)
    labels = get_labels(args.data_dir, args.dataset) # get all tag labels
    num_labels = len(labels)
    pad_token_label_id = CrossEntropyLoss().ignore_index

    # Non-zero ranks wait here until rank 0 has finished loading the tokenizer
    # and reaches the matching barrier below.
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()

    tokenizer = RobertaTokenizer.from_pretrained(
        args.tokenizer_name,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )

    if args.local_rank == 0:
        torch.distributed.barrier()

    if args.do_train:
        train_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode="train")
        # Fine-tuning starts from the checkpoints tagged with epoch 50.
        finetune(args, 50, train_dataset, tokenizer, labels, pad_token_label_id)

if __name__ == "__main__":
    main()
