
from pynvml import *
from functools import partial
# from predict import learn_and_pred
from robust_loss_pytorch import lossfun
# from calc_weight import calc_weight
import logging
import os
import random
import sys
import lm_eval
import lm_eval.models
import lm_eval.tasks
import lm_eval.evaluator

from evalplus.data import get_mbpp_plus, write_jsonl


import warnings
import pickle
from dataclasses import dataclass, field
from typing import Optional
from sentence_transformers import SentenceTransformer
import matplotlib
import matplotlib.pyplot as plt
import datasets
import evaluate
import numpy as np
from datasets import load_dataset, Dataset, load_from_disk, concatenate_datasets
from datasets import Features, Sequence, Value, ClassLabel
from datasets import disable_caching

import torch
from torcheval.metrics.text import Perplexity
from torch import nn

import transformers
from transformers import LlamaForCausalLM, LlamaTokenizer, BitsAndBytesConfig, GenerationConfig, GemmaTokenizer, GemmaForCausalLM,AutoTokenizer,AutoModelForCausalLM
from transformers.pipelines.pt_utils import KeyDataset
from peft import prepare_model_for_kbit_training, PromptEmbedding, PromptTuningConfig


from transformers import (
    pipeline,
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    PretrainedConfig,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
)
from transformers import DataCollatorForSeq2Seq

from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
from tqdm.auto import tqdm
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

from lorahub.algorithm import lorahub_learning, lorahub_inference
from lorahub.constant import LORA_MODULE_NAMES
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq, T5ForConditionalGeneration, T5Tokenizer
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType
from evaluate import load
from peft import get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType


from distiller import Distiller


# Module-level torcheval Perplexity metric instance, created once at import time.
perplexity_metric = Perplexity()
# The minimum-Transformers-version check below is currently disabled (commented out).
# If re-enabled, it will error when the minimal version of Transformers is not installed.
#check_min_version("4.34.0.dev0")

# Require datasets>=1.8.0; the second argument is the hint attached to a failed requirement.
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

# Module-level logger named after this module; its level is set inside main().
logger = logging.getLogger(__name__)

@dataclass
class DataTrainingArguments:
    """
    Job-level options: data source, evaluation mode, PEFT/distillation switches and
    tokenization limits.

    An instance is produced by `HfArgumentParser` in `main()`, so every field below is
    also a command-line flag (or a key of the JSON config file). Field order, types
    and defaults define that interface and must be kept stable.
    """

    # ---- job / finetuning method ------------------------------------------------
    jname: Optional[str] = field(default='default_job', metadata={"help": ('Job name')})
    dora: Optional[bool] = field(
        default=False, metadata={"help": ('Whether to use dora instead of lora for finetuning')}
    )
    prompt_tuning: Optional[bool] = field(default=False, metadata={"help": ('Whether to use prompt tuning')})

    # ---- evaluation ---------------------------------------------------------------
    lmeval: Optional[bool] = field(default=False, metadata={"help": ('Use lmeval package for evaluation')})
    num_fewshot: Optional[int] = field(
        default=5, metadata={"help": ('Number of few shot for lmeval; usually set to 0')}
    )
    train_meta_ood: Optional[bool] = field(
        default=False, metadata={"help": ('Enable discriminator training by setting this to True')}
    )
    # NOTE(review): no help text is provided for this flag; its use is not visible here.
    train_meta: Optional[bool] = field(default=False)
    do_eval_on_train: Optional[bool] = field(default=False, metadata={"help": ('Do evaluation on training set')})
    do_save: Optional[bool] = field(default=True, metadata={"help": ('Save finetuned model')})
    maxoutputlen: Optional[int] = field(default=100, metadata={"help": ('Maximum output length')})
    eos: Optional[str] = field(default=None, metadata={"help": ('End of string token to use')})
    evalmode: Optional[str] = field(
        default='strcomp', metadata={"help": ('Evaluation mode when not using packages for evaluation')}
    )

    # ---- prompting / training regime ----------------------------------------------
    prompting: Optional[str] = field(
        default=None,
        metadata={"help": ('Prompt formatting for finetuning; use "inst" for replication of our results')},
    )
    completion_only: Optional[bool] = field(
        default=True, metadata={"help": ('Whether to finetune using completion only')}
    )
    full_finetune: Optional[bool] = field(default=False, metadata={"help": ('Do full finetuning rather than PEFT')})
    synth: Optional[str] = field(default=None, metadata={"help": ('synthetic data to use')})

    # ---- student / teacher models ---------------------------------------------------
    basemodel: Optional[str] = field(
        default='t5xl', metadata={"help": ('Base model (target model) for building PEFT')}
    )
    teacher_bm: Optional[str] = field(default='t5l', metadata={"help": ("Teacher model (source model)")})
    # NOTE(review): no help text is provided for this flag; its use is not visible here.
    stream: Optional[bool] = field(default=False)
    lorasource: Optional[str] = field(default=None, metadata={"help": ('Source of PEFT teacher model')})
    pretrained: Optional[str] = field(default=None, metadata={"help": ('Enable to use pretrained base model')})
    lh: Optional[bool] = field(default=False, metadata={"help": ('Use lorahub weights')})

    # ---- task / distillation ----------------------------------------------------------
    task_name: Optional[str] = field(default=None, metadata={"help": ('Name of task')})
    distill: bool = field(default=False, metadata={"help": ('Enable distillation')})
    selfdist: bool = field(default=False, metadata={"help": ('Enable self-distillation')})
    lst: Optional[str] = field(
        default=None, metadata={"help": ('Type of in-place filtering to use; deprecated')}
    )
    data: Optional[str] = field(default=None, metadata={"help": ('Data source')})
    # NOTE(review): the next two options carry no help text; semantics not visible here.
    diff: Optional[str] = field(default=None)
    lossfunc: Optional[str] = field(default=None)

    # ---- tokenization / dataset size ----------------------------------------------------
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": (
                "Whether to pad all samples to `max_seq_length`. "
                "If False, will pad the samples dynamically when batching to the maximum length in the batch."
            )
        },
    )
    max_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of prediction examples to this "
                "value if set."
            )
        },
    )

    # ---- explicit data files ------------------------------------------------------------
    train_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the training data."}
    )
    validation_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the validation data."}
    )
    test_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the test data."}
    )


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    Parsed by `HfArgumentParser` in `main()`; each field is exposed as a command-line
    flag (or a key of the JSON config file), and its `help` metadata is the text shown
    for that flag.
    """

    model_name_or_path: str = field(
        default='flan-t5-xl', metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    # NOTE(review): annotated `str` but defaults to None; the annotation is left unchanged
    # so that argument parsing behaves exactly as before.
    token: str = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    # Deprecated alias of `token`; main() copies it into `token` and raises if both are set.
    # NOTE(review): annotated `bool` but defaults to None; left unchanged for the same reason.
    use_auth_token: bool = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            # Adjacent literals are concatenated verbatim, so each fragment needs its own
            # trailing space; without it the help read "optionshould" / "willexecute".
            "help": (
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )
    ignore_mismatched_sizes: bool = field(
        default=False,
        metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
    )


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if model_args.use_auth_token is not None:
        warnings.warn("The `use_auth_token` argument is deprecated and will be removed in v4.34.", FutureWarning)
        if model_args.token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        model_args.token = model_args.use_auth_token

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    hftoken = 'YOUR HF TOKEN'
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    set_seed(training_args.seed)

    # -------------------------------------------------------------------- #
    # DATA LOADING #
    # -------------------------------------------------------------------- #    
    if data_args.data == 'gsm8k':
        train_dataset = load_dataset("gsm8k", 'main', split='train')
        eval_dataset = load_dataset("gsm8k", 'main', split='test')
        train_dataset=train_dataset.rename_column('question', 'inputs')
        train_dataset=train_dataset.rename_column('answer', 'targets')
        eval_dataset=eval_dataset.rename_column('question', 'inputs')
        eval_dataset=eval_dataset.rename_column('answer', 'targets')

        def prepgsm(examples):
            """Replace ``targets`` with the text that follows the first '####' marker.

            The segment between the first and (if any) second '####' is taken and its
            first character is dropped. Raises IndexError when the marker is absent.
            Mutates and returns the given example.
            """
            segments = examples['targets'].split('####')
            after_marker = segments[1]
            # Drop the single leading character that follows the marker.
            examples['targets'] = after_marker[1:]
            return examples
        train_dataset=train_dataset.map(prepgsm, batched=False)
        eval_dataset=eval_dataset.map(prepgsm, batched=False)
        train_dataset= train_dataset.select(range(250))
        if data_args.train_meta_ood:
            dstlst = []
            for i in range(1):
                try:
                    datasetw = load_from_disk('SYNTH DATASET PATH'+data_args.synth+'_'+data_args.task_name+str(i))
                    dstlst.append(datasetw)
                except:
                    continue
            synth_dataset = concatenate_datasets(dstlst)
            def change_label_yes(example):
                example['targets'] = 'Yes'
                return example
            def change_label_no(example):
                example['targets'] = 'No'
                return example
            train_dataset = train_dataset.map(change_label_yes, batched=False)
            synth_dataset = synth_dataset.map(change_label_no, batched=False).select(range(len(train_dataset)))
            train_dataset = concatenate_datasets([train_dataset, synth_dataset])
            train_dataset=train_dataset.shuffle(seed=42)
            dataset_ = train_dataset.train_test_split(test_size=0.1,seed=42)
            train_dataset = dataset_['train']
            eval_dataset = dataset_['test']
        elif data_args.synth is not None:
            dstlst = []
            for i in range(50):
                try:
                    datasetw = load_from_disk('SYNTH DATASET PATH'+data_args.synth+'_'+data_args.task_name+str(i))
                    dstlst.append(datasetw)
                except:
                    continue
            train_dataset = concatenate_datasets(dstlst)
        
    elif data_args.data == 'mbpp':
        mbppdic = {'inputs':[], 'targets':[]}
        mbppdic2 = {'inputs':[], 'targets':[]}
        for task_id, problem in get_mbpp_plus().items():
            if int(task_id[5:]) not in range(11,511):
                mbppdic['inputs'].append(problem['prompt'])
                mbppdic['targets'].append(problem['canonical_solution'])
            else:
                mbppdic2['inputs'].append(problem['prompt'])
                mbppdic2['targets'].append(problem['canonical_solution'])
        train_dataset = Dataset.from_dict(mbppdic)
        eval_dataset = Dataset.from_dict(mbppdic2)
        datalen = len(train_dataset)
        if data_args.train_meta_ood:
            dstlst = []
            for i in range(1):
                try:
                    datasetw = load_from_disk('SYNTH DATASET PATH'+data_args.synth+'_'+data_args.task_name+str(i))
                    dstlst.append(datasetw)
                except:
                    continue
            synth_dataset = concatenate_datasets(dstlst)
            def change_label_yes(example):
                example['targets'] = 'Yes'
                return example
            def change_label_no(example):
                example['targets'] = 'No'
                return example
            train_dataset = train_dataset.map(change_label_yes, batched=False)
            synth_dataset = synth_dataset.map(change_label_no, batched=False).select(range(len(train_dataset)))
            train_dataset = concatenate_datasets([train_dataset, synth_dataset])
            train_dataset=train_dataset.shuffle(seed=42)
            dataset_ = train_dataset.train_test_split(test_size=0.1,seed=42)
            train_dataset = dataset_['train']
            eval_dataset = dataset_['test']
        elif data_args.synth is not None:
            dstlst = []
            for i in range(50):
                try:
                    datasetw = load_from_disk('SYNTH DATASET PATH'+data_args.synth+'_'+data_args.task_name+str(i))
                    dstlst.append(datasetw)
                except:
                    continue
            train_dataset = concatenate_datasets(dstlst)
    elif data_args.data == 'bbh':
        dataset = load_dataset("lukaemon/bbh",data_args.task_name)
        dataset = dataset['test']
        l = dataset.shape[0]
        dataset = dataset.rename_column('input', 'inputs')
        dataset = dataset.rename_column('target', 'targets')
        dataset = dataset.train_test_split(test_size = 0.1, seed = 42)
        train_dataset=dataset['train']
        eval_dataset=dataset['test']
        eval_dataset2 = train_dataset.select(range(len(train_dataset)))
        if data_args.train_meta_ood:
            dstlst = []
            for i in range(1):
                try:
                    datasetw = load_from_disk('SYNTH DATASET PATH'+data_args.synth+'_'+data_args.task_name+str(i))
                    dstlst.append(datasetw)
                except:
                    continue
            synth_dataset = concatenate_datasets(dstlst)
            def change_label_yes(example):
                example['targets'] = 'Yes'
                return example
            def change_label_no(example):
                example['targets'] = 'No'
                return example
            train_dataset = train_dataset.map(change_label_yes, batched=False)
            synth_dataset = synth_dataset.map(change_label_no, batched=False).select(range(len(train_dataset)))
            train_dataset = concatenate_datasets([train_dataset, synth_dataset])
            train_dataset=train_dataset.shuffle(seed=42)
            dataset_ = train_dataset.train_test_split(test_size=0.1,seed=42)
            train_dataset = dataset_['train']
            eval_dataset = dataset_['test']
        elif data_args.synth is not None:
            dstlst = []
            for i in range(50):
                try:
                    datasetw = load_from_disk('SYNTH DATASET PATH'+data_args.synth+'_'+data_args.task_name+str(i))
                    dstlst.append(datasetw)
                except:
                    continue
            train_dataset = concatenate_datasets(dstlst)

    elif data_args.data == 'mmlu':
        tasklst = ['abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge', 'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics', 'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic', 'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics', 'high_school_physics', 'high_school_psychology', 'high_school_statistics', 'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios', 'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law', 'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy', 'virology', 'world_religions']
        ansdic = {0:'A',1:'B',2:'C',3:'D'}
        def format_mmlu(example):
            """Build the prompt/label pair for one MMLU row.

            ``inputs`` becomes the question followed by the first four entries of
            ``choices``, each on its own line and tagged (A)-(D). ``targets`` becomes
            the letter that ``ansdic`` maps the row's ``answer`` index to. Mutates
            and returns the given example.
            """
            option_prefixes = ('\n(A) ', '\n(B) ', '\n(C) ', '\n(D) ')
            prompt = example['question']
            for position, prefix in enumerate(option_prefixes):
                # Indexing (rather than zip) keeps the IndexError on rows with < 4 choices.
                prompt = prompt + prefix + example['choices'][position]
            example['inputs'] = prompt
            example['targets'] = ansdic[example['answer']]
            return example
        train_dataset = load_dataset("cais/mmlu", data_args.task_name,split='test')
        eval_dataset = load_dataset("cais/mmlu", data_args.task_name,split='validation')
        train_dataset = train_dataset.map(format_mmlu, batched=False)
        eval_dataset = eval_dataset.map(format_mmlu, batched=False)
        eval_dataset2 = train_dataset.select(range(len(train_dataset)))
        if data_args.train_meta_ood:
            dstlst = []
            for i in range(1):
                try:
                    datasetw = load_from_disk('SYNTH DATASET PATH'+data_args.synth+'_'+data_args.task_name+str(i))
                    dstlst.append(datasetw)
                except:
                    continue
            synth_dataset = concatenate_datasets(dstlst)
            def change_label_yes(example):
                example['targets'] = 'Yes'
                return example
            def change_label_no(example):
                example['targets'] = 'No'
                return example
            train_dataset = train_dataset.map(change_label_yes, batched=False)
            synth_dataset = synth_dataset.map(change_label_no, batched=False).select(range(len(train_dataset)))
            train_dataset = concatenate_datasets([train_dataset, synth_dataset])
            train_dataset=train_dataset.shuffle(seed=42)
            dataset_ = train_dataset.train_test_split(test_size=0.1,seed=42)
            train_dataset = dataset_['train']
            eval_dataset = dataset_['test']
        elif data_args.synth is not None:
            dstlst = []
            for i in range(50):
                try:
                    datasetw = load_from_disk('SYNTH DATASET PATH'+data_args.synth+'_'+data_args.task_name+str(i))
                    dstlst.append(datasetw)
                except:
                    continue
            train_dataset = concatenate_datasets(dstlst)
    
    train_dataset = train_dataset.select(range(min(len(train_dataset), data_args.max_samples)))
    # -------------------------------------------------------------------- #
    # MODEL LOADING #
    # -------------------------------------------------------------------- #
    
    class CustomForwardModel(PeftModel):
        """PeftModel variant whose forward pass prints its argument before delegating.

        NOTE(review): not referenced in the visible part of ``main``; presumably a
        debugging aid — confirm before relying on it.
        """

        def forward(self, input):
            # Echo the raw argument to stdout, then pass it unchanged to the parent forward.
            print(input)
            outputs = super().forward(input)
            return outputs
    
    bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
    bnb_config_gemma = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
    if data_args.lh:
        model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")
        model = PeftModel.from_pretrained(model, "lorahub/flan_t5_large-"+data_args.task_name.replace('/','_'))
        tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large", padding_side = 'right')
    else:
        if data_args.basemodel == 't5xl':
            basemodel = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-xl')
            tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
        elif data_args.basemodel == 't5l':
            basemodel = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-large')
            tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
        elif data_args.basemodel == 'llama7b':
            basemodel = LlamaForCausalLM.from_pretrained('meta-llama/Llama-2-7b-chat-hf', quantization_config=bnb_config, token=hftoken)
            # tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf', token=hftoken, pad_token = '[PAD]')
            tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf', token=hftoken)
            tokenizer.pad_token_id = tokenizer.eos_token_id
        elif data_args.basemodel == 'llama13b':
            basemodel = LlamaForCausalLM.from_pretrained('meta-llama/Llama-2-13b-chat-hf', quantization_config=bnb_config, token=hftoken)
            # tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-13b-chat-hf', token=hftoken, pad_token = '[PAD]')
            tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-13b-chat-hf', token=hftoken)
            tokenizer.pad_token_id = 0
        elif data_args.basemodel == 'gemma2b':
            tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", token=hftoken)
            basemodel = AutoModelForCausalLM.from_pretrained("google/gemma-2b", quantization_config=bnb_config_gemma, token=hftoken, torch_dtype=torch.float32)
        elif data_args.basemodel == 'gemma7b':
            tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b", token=hftoken)
            basemodel = AutoModelForCausalLM.from_pretrained("google/gemma-7b", quantization_config=bnb_config_gemma, token=hftoken, torch_dtype=torch.float32)
        elif data_args.basemodel == 'phi2':
            basemodel = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", quantization_config=bnb_config_gemma, token=hftoken)
            tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
            tokenizer.pad_token = tokenizer.eos_token
        elif data_args.basemodel == 'phi3':
            basemodel = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct", trust_remote_code=True, quantization_config=bnb_config_gemma, token=hftoken)
            tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct", trust_remote_code=True)
            tokenizer.pad_token = tokenizer.eos_token
        if data_args.prompt_tuning:
            peft_config = PromptTuningConfig(
                task_type=TaskType.CAUSAL_LM,
                prompt_tuning_init=PromptTuningInit.TEXT,
                num_virtual_tokens=8,
                prompt_tuning_init_text="Answer the following question correctly:",
                tokenizer_name_or_path="google/gemma-7b",
            )
        elif data_args.basemodel.startswith('t5'):
            peft_config = LoraConfig(
                task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=16
            )
        elif data_args.basemodel.startswith('llama'):
            peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM, inference_mode=False, r=16, use_dora=data_args.dora 
            )
        elif data_args.basemodel.startswith('gemma') or data_args.basemodel.startswith('phi'):
            target_modules=["q_proj", "v_proj"]
            # target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"]
            peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM, inference_mode=False, r=16, target_modules = target_modules, use_dora=data_args.dora
            )
        
        model = get_peft_model(basemodel, peft_config)
        model.train()

    device = torch.device('cuda' if torch.cuda.is_available else 'cpu')
    if data_args.teacher_bm == 'llama7b':
        if data_args.selfdist:
            hubmodel = model
            hubtokenizer = tokenizer
        else:
            llama = LlamaForCausalLM.from_pretrained('meta-llama/Llama-2-7b-chat-hf', quantization_config=bnb_config, token=hftoken)
            if data_args.lorasource is None:
                hubmodel = llama
            if data_args.lorasource == 'mmlu':
                hubmodel = PeftModel.from_pretrained(llama, 'MMLU MODEL DIRECTORY')
                hubmodel = hubmodel.to(device)
            elif data_args.lorasource == 'bbh':
                hubmodel = PeftModel.from_pretrained(llama, 'BBH MODEL DIRECTORY')
                hubmodel = hubmodel.to(device)
            elif data_args.lorasource == 'mbpp':
                hubmodel = PeftModel.from_pretrained(llama, 'MBPP MODEL DIRECTORY')
                hubmodel = hubmodel.to(device)
            elif data_args.lorasource == 'gsm8k':
                hubmodel = PeftModel.from_pretrained(llama, 'GSM8K MODEL DIRECTORY')
                hubmodel = hubmodel.to(device)
            elif data_args.lorasource == 'dorabbh':
                hubmodel = PeftModel.from_pretrained(llama, 'BBH DORA MODEL DIRECTORY')
                hubmodel = hubmodel.to(device)
            elif data_args.lorasource == 'ptbbh':
                hubmodel = PeftModel.from_pretrained(llama, 'BBH PROMPT TUNING MODEL DIRECTORY')
                hubmodel = hubmodel.to(device)
            else:
                hubmodel = PeftModel.from_pretrained(llama, 'MODEL DIRECTORY')
            
            hubmodel.eval()
            hubtokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf', token=hftoken)
            hubtokenizer.pad_token_id = 0
    elif data_args.teacher_bm == 'llama13b':
        # Load the 13B chat base model once; it is either wrapped by the BBH
        # adapter or used directly as the teacher.
        base13b = LlamaForCausalLM.from_pretrained('meta-llama/Llama-2-13b-chat-hf', quantization_config=bnb_config, token=hftoken)
        if data_args.lorasource == 'bbh':
            hubmodel = PeftModel.from_pretrained(base13b, 'BBH MODEL DIRECTORY')
            hubmodel = hubmodel.to(device)
        else:
            # Reuse the already-loaded base model rather than loading the 13B
            # checkpoint a second time.
            hubmodel = base13b
        hubtokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-13b-chat-hf', token=hftoken)
        hubtokenizer.pad_token_id = 0
    elif data_args.teacher_bm == 'gemma2b':
        # Gemma-2B teacher: tokenizer first, then the quantized base model.
        hubtokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", token=hftoken)
        gemmaa = AutoModelForCausalLM.from_pretrained("google/gemma-2b", quantization_config=bnb_config_gemma, token=hftoken, torch_dtype=torch.float32)
        # Look up the adapter directory for the requested source; names that are
        # not listed (including None) use the generic directory and are not
        # moved to `device`.
        gemma_adapter_dir = {
            'mmlu': 'MMLU MODEL DIRECTORY',
            'bbh': 'BBH MODEL DIRECTORY',
            'mbpp': 'MBPP MODEL DIRECTORY',
            'gsm8k': 'GSM8K MODEL DIRECTORY',
            'dorabbh': 'BBH DORA MODEL DIRECTORY',
            'ptbbh': 'BBH PROMPT TUNING MODEL DIRECTORY',
        }.get(data_args.lorasource)
        if gemma_adapter_dir is None:
            hubmodel = PeftModel.from_pretrained(gemmaa, 'MODEL DIRECTORY')
        else:
            hubmodel = PeftModel.from_pretrained(gemmaa, gemma_adapter_dir)
            hubmodel = hubmodel.to(device)

    # Pad every example to max_length when requested; otherwise disable
    # tokenizer-side padding.
    padding = "max_length" if data_args.pad_to_max_length else False
    if data_args.max_seq_length > tokenizer.model_max_length:
        # Trailing space after "for the" keeps the two message halves from
        # running together ("for themodel").
        logger.warning(
            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
        )
    # Never exceed what the tokenizer reports as its maximum length.
    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
    print(f'max seq length: {max_seq_length}', file=sys.stderr)


    # -------------------------------------------------------------------- #
    # TRAINING PREPARATION #
    # -------------------------------------------------------------------- #


    print("train dataset length: " + str(train_dataset.shape[0]), file=sys.stderr)
    print("eval dataset length: " + str(eval_dataset.shape[0]), file=sys.stderr)
    sm = nn.Softmax(dim = 2)
    def synth_proc(examples):
        """Append the target text to the input text.

        The two strings are joined with a space-newline-space delimiter and
        stored back under ``'inputs'``; ``'targets'`` is left untouched.
        Mutates and returns ``examples``.
        """
        question, answer = examples['inputs'], examples['targets']
        examples['inputs'] = ' \n '.join((question, answer))
        return examples
     
    def preprocess_function(examples):
        """Tokenize ``inputs`` with ``targets`` passed as ``text_target``.

        Uses the ``padding`` and ``max_seq_length`` values computed by the
        enclosing scope and always truncates. Returns the tokenizer's result.
        """
        tokenizer_options = dict(padding=padding, max_length=max_seq_length, truncation=True)
        return tokenizer(examples['inputs'], text_target=examples['targets'], **tokenizer_options)
    idxcount = 0
    def hub_pred(examples):
        """Replace ``examples['targets']`` with text generated by the teacher model.

        Works on a single (un-batched) example. The prompt layout and the
        generation arguments depend on ``data_args.data``, ``data_args.synth``
        and ``data_args.prompting``. After decoding, the echoed prompt is
        stripped by character offset where the branch requires it, the text is
        cut at ``data_args.eos`` when one is set, and finally truncated to
        ``data_args.maxoutputlen`` characters.

        Returns the same ``examples`` mapping, mutated in place.

        Raises:
            ValueError: with ``prompting == 'inst'`` when ``teacher_bm`` is
                not a llama / gemma / phi model (no prompt template exists).
        """
        teacher = data_args.teacher_bm

        def _teacher_generate(text, **gen_kwargs):
            # Tokenize `text`, run the teacher's generate() and decode the first sequence.
            input_ids = hubtokenizer(text, return_tensors="pt").input_ids.to(device)
            generated = hubmodel.generate(inputs=input_ids, **gen_kwargs)
            return hubtokenizer.decode(generated[0])

        def _strip_prompt(decoded, prompt_text):
            # Drop the echoed prompt plus a fixed number of extra leading characters.
            # The text was decoded by the teacher tokenizer, so the offset is
            # chosen from the teacher family (not the student base model).
            # presumably the extra characters cover the BOS marker / separator — TODO confirm
            if teacher.startswith('gemma'):
                extra = 5
            elif teacher.startswith('phi'):
                extra = 1
            else:
                extra = 6
            return decoded[len(prompt_text) + extra:]

        if data_args.data == 'mbpp':
            output = _teacher_generate(examples['inputs'], max_new_tokens=2048)
            examples['targets'] = _strip_prompt(output, examples['inputs'])
        elif data_args.synth == 'randword':
            # The full decoded text (prompt included) is kept as the target.
            examples['targets'] = _teacher_generate('# ' + examples['inputs'])
        elif data_args.prompting in ('3+1', '3+1_noidx', '1'):
            # Question/answer layout, optionally preceded by a few-shot prefix.
            if data_args.prompting == '3+1':
                few_shot = three_sample_prompt + " \n"
            elif data_args.prompting == '3+1_noidx':
                few_shot = three_sample_prompt_noidx + " \n"
            else:
                few_shot = ""
            prompt = few_shot + "### Question: " + examples['inputs'] + " \n### Answer: "
            output = _teacher_generate(prompt)
            examples['targets'] = output[len(prompt) + 5:]
        elif data_args.prompting == 'inst':
            question = examples['inputs']
            if teacher.startswith('llama'):
                prompt = f"[INST] <<SYS>>\nAnswer in as few words as possible.\n<</SYS>>\n\n{question} [/INST]"
            elif teacher.startswith('gemma'):
                prompt = f"<start_of_turn>user\nAnswer in as few words as possible.\n\n{question}<end_of_turn>\n<start_of_turn>model\n"
            elif teacher.startswith('phi'):
                prompt = f"Answer in as few words as possible.\n{question} \nAnswer: "
            else:
                # Fail with a clear message instead of an unbound-variable error.
                raise ValueError(
                    f"No instruction prompt template for teacher model {teacher!r}"
                )
            output = _teacher_generate(prompt, do_sample=False, max_length=1024)
            examples['targets'] = _strip_prompt(output, prompt)
        else:
            # No prompt template: feed the raw input and keep the full decoded text.
            examples['targets'] = _teacher_generate(examples['inputs'])

        if data_args.eos is not None:
            # Keep only the text before the first end-of-answer marker.
            examples['targets'] = examples['targets'].split(data_args.eos)[0]
        examples['targets'] = examples['targets'][:data_args.maxoutputlen]
        return examples
    def hub_pred_batched(examples):
        """Generate teacher outputs for a whole batch and store them as ``targets``.

        The batch of ``inputs`` is tokenized with truncation and the enclosing
        scope's ``padding`` setting, passed to the teacher's ``generate`` and
        decoded with ``batch_decode``. Mutates and returns ``examples``.
        """
        encoded = hubtokenizer(examples['inputs'], truncation=True, padding=padding, return_tensors="pt")
        generated = hubmodel.generate(inputs=encoded.input_ids.to(device))
        examples['targets'] = hubtokenizer.batch_decode(generated)
        return examples
    def meta_prep(examples):
        """Relabel an example by whether the teacher's answer matches its target.

        The question is wrapped in an instruction template, answered greedily
        by the teacher, and the echoed prompt is stripped by character offset.
        With ``data_args.evalmode == 'strcomp'`` the target becomes ``'Yes'``
        when the answer starts with the original target and ``'No'``
        otherwise; for any other mode the example is returned unchanged.
        """
        # NOTE(review): the template is chosen from `basemodel` while the strip
        # offset is chosen from `teacher_bm` — confirm both always agree here.
        if data_args.basemodel.startswith('gemma'):
            prompt = f"<start_of_turn>user\nAnswer in as few words as possible.\n\n{examples['inputs']}<end_of_turn>\n<start_of_turn>model\n"
        else:
            prompt = f"[INST] <<SYS>>\nAnswer in as few words as possible.\n<</SYS>>\n\n{examples['inputs']} [/INST]"

        prompt_ids = hubtokenizer(prompt, return_tensors="pt").input_ids.to(device)
        generated = hubmodel.generate(inputs=prompt_ids, do_sample=False)
        output = hubtokenizer.decode(generated[0])

        # Extra leading characters to drop after the prompt, per teacher family.
        for family, extra in (('llama', 6), ('gemma', 5), ('phi', 1)):
            if data_args.teacher_bm.startswith(family):
                output = output[len(prompt) + extra:]
                break

        if data_args.evalmode == 'strcomp':
            reference = examples['targets']
            examples['targets'] = 'Yes' if output[:len(reference)] == reference else 'No'
        return examples
    with training_args.main_process_first(desc="dataset map pre-processing"):
        if data_args.distill:
            # Replace the training targets with teacher generations.
            if data_args.teacher_bm.startswith('t5'):
                train_dataset = train_dataset.map(
                    hub_pred_batched,
                    batched=True,
                    batch_size=256,
                    desc="Running hubtokenizer on dataset",
                )
            elif data_args.teacher_bm.startswith(('llama', 'gemma')):
                # hub_pred handles one example at a time, so it is always
                # mapped un-batched (the former batched branch was unreachable).
                train_dataset = train_dataset.map(
                    hub_pred,
                    batched=False,
                    desc="Running hubtokenizer on dataset",
                )
        if data_args.train_meta:
            # Meta-training relabels both splits via meta_prep.
            train_dataset = train_dataset.map(
                meta_prep,
                batched=False,
            )
            eval_dataset = eval_dataset.map(
                meta_prep,
                batched=False,
            )
        print(train_dataset.features)
    metric = evaluate.load("accuracy")
    # metric = evaluate.load("exact_match")
    
    def compute_metrics(p: EvalPrediction):
        """Score flattened argmax predictions against the flattened label ids.

        Predictions are the argmax over axis 2 of the (first) prediction
        array. Two base-model specific masks are applied before scoring:

        * t5: in each row, everything after the first predicted token id 1 is
          set to 0 (the prediction scores there are overwritten to match).
        * llama: in each row, from the first position ``j > 0`` whose label id
          is 1 onward, predictions are set to -100.

        The result dict from ``metric.compute`` is printed to stderr and
        returned; a ``combined_score`` mean is added when it has several entries.
        """
        scores = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        token_ids = np.argmax(scores, axis=2)

        if data_args.basemodel.startswith('t5'):
            for row, row_ids in enumerate(token_ids):
                # presumably id 1 is the end-of-sequence token — TODO confirm
                hits = np.flatnonzero(row_ids == 1)
                if hits.size:
                    after = hits[0] + 1
                    token_ids[row, after:] = 0
                    scores[row, after:, :] = 0
                    scores[row, after:, 0] = 1
        elif data_args.basemodel.startswith('llama'):
            width = token_ids.shape[1]
            for row in range(len(token_ids)):
                # Position 0 is skipped, so search from index 1 and shift back.
                hits = np.flatnonzero(p.label_ids[row, 1:width] == 1)
                if hits.size:
                    token_ids[row, hits[0] + 1:] = -100

        result = metric.compute(predictions=token_ids.flatten(), references=p.label_ids.flatten())
        if len(result) > 1:
            result["combined_score"] = np.mean(list(result.values())).item()
        print(result, file=sys.stderr)
        return result
                        
    class CustomTrainer(Trainer):
        """Trainer with a selectable loss and optional teacher-derived labels.

        The loss is chosen by ``data_args.lossfunc``; how the labels are built
        is chosen by ``data_args.distill`` and ``data_args.diff``.
        """

        def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
            """Compute the training loss for one batch.

            Args:
                model: the model being trained.
                inputs: keyword arguments for the model's forward pass.
                return_outputs: when true, return ``(loss, outputs)``.
                num_items_in_batch: accepted for compatibility with Trainer
                    versions that pass it to ``compute_loss``; not used here.

            Loss selection (``data_args.lossfunc``): ``'l2'`` -> MSE,
            ``'l1'`` -> L1, ``'robust'`` -> ``lossfun``, ``'cn'`` ->
            cross-entropy (ignoring the teacher pad id when
            ``diff == 'greedy'``). Any other value returns the loss the model
            computed itself.

            Labels: with ``data_args.distill`` they come from ``hubmodel``
            (combined with ``t5l`` and ``basemodel`` for ``diff`` ``'sm'`` /
            ``'lg'``); otherwise they are the one-hot encoding of
            ``inputs['labels']``.
            """
            # Call the module itself rather than .forward so registered hooks run.
            outputs = model(**inputs)

            loss_fct = None
            if data_args.lossfunc == 'l2':
                loss_fct = nn.MSELoss()
            elif data_args.lossfunc == 'l1':
                loss_fct = nn.L1Loss()
            elif data_args.lossfunc == 'robust':
                loss_fct = lossfun
            elif data_args.lossfunc == 'cn':
                if data_args.diff == 'greedy':
                    loss_fct = nn.CrossEntropyLoss(ignore_index=hubtokenizer.pad_token_id)
                else:
                    loss_fct = nn.CrossEntropyLoss()
            if loss_fct is None:
                # No custom loss requested: use the loss the model returned.
                loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
                return (loss, outputs) if return_outputs else loss

            sm = nn.Softmax(dim=2)
            # 'lg' compares raw logits; every other mode compares softmax outputs.
            if data_args.diff == 'lg':
                output = outputs.get('logits')
            else:
                output = sm(outputs.get('logits'))

            if data_args.distill:
                if data_args.diff == 'sm':
                    label = sm(hubmodel(**inputs).get('logits')) - sm(t5l(**inputs).get('logits')) \
                            + sm(basemodel(**inputs).get('logits'))
                elif data_args.diff == 'lg':
                    label = hubmodel(**inputs).get('logits') - t5l(**inputs).get('logits') \
                            + basemodel(**inputs).get('logits')
                elif data_args.diff == 'greedy':
                    # NOTE(review): argmax is taken over dim=1 while the softmax
                    # above uses dim=2 — confirm this is the intended axis.
                    label = torch.argmax(hubmodel(**inputs).get('logits'), dim=1)
                else:
                    label = hubmodel(**inputs).get('logits')
                    label = sm(label)
            else:
                label = inputs.get('labels')
                label = torch.nn.functional.one_hot(label, num_classes=outputs.get('logits').shape[2]).float()
            loss = loss_fct(output, label)

            return (loss, outputs) if return_outputs else loss
    if data_args.basemodel.startswith('t'):
        data_collator = default_data_collator
    else:
        data_collator = transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
    if data_args.data == 'mbpp':
        data_args.task_name='mbpp'
    if data_args.completion_only:
        if data_args.train_meta_ood:
            def formatting_prompts_func(example):
                """Build full training texts asking whether each question is from the task dataset.

                ``example`` is a batch: one text is produced per
                (``inputs``, ``targets``) pair, using the gemma or phi layout
                when the base model name starts with that prefix and the
                ``[INST]`` layout otherwise. ``data_args.eos`` is appended
                when set. Returns the list of texts.
                """
                dataset_name = data_args.task_name.replace('_', ' ')
                suffix = '' if data_args.eos is None else data_args.eos
                use_gemma = data_args.basemodel.startswith('gemma')
                use_phi = data_args.basemodel.startswith('phi')

                texts = []
                for question, answer in zip(example['inputs'], example['targets']):
                    if use_gemma:
                        body = f"<start_of_turn>user\nAnswer in as few words as possible.\n\n{question}\nIs the above question from the {dataset_name} dataset?<end_of_turn>\n<start_of_turn>model\n{answer}<eos>"
                    elif use_phi:
                        body = f"Answer in as few words as possible.\n{question} \nIs the above question from the {dataset_name} dataset? \nAnswer: {answer}"
                    else:
                        body = f"[INST] <<SYS>>\nAnswer in as few words as possible.\n<</SYS>>\n\n{question}\nIs the above question from the {dataset_name} dataset? [/INST]  {answer}"
                    texts.append(body + suffix)
                return texts
            def formatting_prompts_map(example):
                """Rewrite one example's ``inputs`` into the dataset-membership prompt.

                The layout is picked by the base model family (llama, gemma
                or phi); for any other base model the example is returned
                unchanged. Mutates and returns ``example``.
                """
                prompt_templates = (
                    ('llama', "[INST] <<SYS>>\nAnswer in as few words as possible.\n<</SYS>>\n\n{question}\nIs the above question from the {dataset} dataset? [/INST]"),
                    ('gemma', "<start_of_turn>user\nAnswer in as few words as possible.\n\n{question}\nIs the above question from the {dataset} dataset?<end_of_turn>\n<start_of_turn>model\n"),
                    ('phi', "Answer in as few words as possible.\n{question} \nIs the above question from the {dataset} dataset? \nAnswer: "),
                )
                for family, template in prompt_templates:
                    if data_args.basemodel.startswith(family):
                        example['inputs'] = template.format(
                            question=example['inputs'],
                            dataset=data_args.task_name.replace('_', ' '),
                        )
                        break
                return example
            response_template = '[/INST]'
            if data_args.basemodel.startswith('gemma'):
                response_template = '<start_of_turn>model\n'
            if data_args.basemodel.startswith('phi'):
                response_template = 'Answer:'
        elif data_args.data == 'mbpp':
            def formatting_prompts_func(example):
                """Join each MBPP prompt and solution with a ``###`` separator.

                gemma base models use a bare ``###``; llama base models use
                the separator padded with one space on each side. Any other
                base model yields an empty list.
                """
                if data_args.basemodel.startswith('gemma'):
                    separator = "###"
                elif data_args.basemodel.startswith('llama'):
                    separator = " ### "
                else:
                    return []
                return [
                    prompt + separator + solution
                    for prompt, solution in zip(example['inputs'], example['targets'])
                ]
            def formatting_prompts_map(example):
                """Append the model-specific ``###`` separator to one example's ``inputs``.

                Only gemma and llama base models are handled; other examples
                are returned unchanged. Mutates and returns ``example``.
                """
                for family, separator in (('gemma', "###"), ('llama', " ### ")):
                    if data_args.basemodel.startswith(family):
                        example['inputs'] = example['inputs'] + separator
                        break
                return example
            response_template = '###'
        elif data_args.prompting == 'inst':
            if data_args.basemodel.startswith('gemma'):
                def formatting_prompts_func(example):
                    """Render a batch as gemma turn-formatted training texts.

                    Each text is the user turn with the question, the model
                    turn with the target, then ``<eos>``; ``data_args.eos`` is
                    appended after that when set. Returns the list of texts.
                    """
                    suffix = '' if data_args.eos is None else data_args.eos
                    return [
                        f"<start_of_turn>user\nAnswer in as few words as possible.\n\n{question}<end_of_turn>\n<start_of_turn>model\n{answer}<eos>" + suffix
                        for question, answer in zip(example['inputs'], example['targets'])
                    ]
                def formatting_prompts_map(example):
                    """Wrap one example's question in the gemma user/model turn template.

                    The result ends right after the model-turn header, so it
                    contains no answer. Mutates and returns ``example``.
                    """
                    question = example['inputs']
                    example['inputs'] = (
                        "<start_of_turn>user\nAnswer in as few words as possible.\n\n"
                        f"{question}"
                        "<end_of_turn>\n<start_of_turn>model\n"
                    )
                    return example
                response_template = '<start_of_turn>model\n'
                # if data_args.data == 'bbh':
                #     response_template = '\nAnswer:'
                # response_template = 'Answer:'
            if data_args.basemodel.startswith('llama'):
                def formatting_prompts_func(example):
                    """Render a batch as ``[INST]``-formatted training texts.

                    Each text is the system/instruction block with the
                    question, then two spaces and the target;
                    ``data_args.eos`` is appended when set. Returns the list
                    of texts.
                    """
                    suffix = '' if data_args.eos is None else data_args.eos
                    return [
                        f"[INST] <<SYS>>\nAnswer in as few words as possible.\n<</SYS>>\n\n{question} [/INST]  {answer}" + suffix
                        for question, answer in zip(example['inputs'], example['targets'])
                    ]
                def formatting_prompts_map(example):
                    """Wrap one example's question in the ``[INST]`` / ``<<SYS>>`` template.

                    Mutates and returns ``example``.
                    """
                    question = example['inputs']
                    example['inputs'] = (
                        "[INST] <<SYS>>\nAnswer in as few words as possible.\n<</SYS>>\n\n"
                        f"{question} [/INST]"
                    )
                    return example
                response_template = '[/INST]'
            if data_args.basemodel.startswith('phi'):
                def formatting_prompts_func(example):
                    """Render a batch as plain "question / Answer:" training texts.

                    Each text is the instruction line, the question, and the
                    target after ``Answer: ``; ``data_args.eos`` is appended
                    when set. Returns the list of texts.
                    """
                    suffix = '' if data_args.eos is None else data_args.eos
                    return [
                        f"Answer in as few words as possible.\n{question} \nAnswer: {answer}" + suffix
                        for question, answer in zip(example['inputs'], example['targets'])
                    ]
                def formatting_prompts_map(example):
                    """Wrap one example's question in the plain "Answer:" prompt.

                    Mutates and returns ``example``.
                    """
                    question = example['inputs']
                    example['inputs'] = "Answer in as few words as possible.\n" + f"{question} \nAnswer: "
                    return example
                response_template = 'Answer:'
        data_collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
        eval_dataset = eval_dataset.map(formatting_prompts_map,
            batched=False,
            desc="Completion template on eval data",)


    # -------------------------------------------------------------------- #
    # TRAIN AND EVAL #
    # -------------------------------------------------------------------- #
    training_args.eval_accumulation_steps = 64
    training_args.logging_steps=1
    # training_args.fp16=False
    # if data_args.basemodel.startswith('gemma'):
    #     training_args.optim = 'paged_adamw_32bit'
        # training_args.fp16=True
    # training_args.evaluation_strategy='steps'
    # training_args.eval_steps=100
    # training_args.adam_beta2=0.95
    # Arguments shared by both trainer flavours; the datasets are only handed
    # over when the corresponding phase is enabled.
    trainer_kwargs = dict(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    if data_args.completion_only:
        # Completion-only training goes through SFTTrainer with the prompt formatter.
        trainer = SFTTrainer(formatting_func=formatting_prompts_func, **trainer_kwargs)
    else:
        trainer = CustomTrainer(**trainer_kwargs)
    print(train_dataset[:10], file=sys.stderr)
    # Training
    print(training_args)
    if training_args.do_train:
        # An explicitly requested checkpoint wins; otherwise fall back to the
        # last detected checkpoint (which may be None).
        explicit_checkpoint = training_args.resume_from_checkpoint
        checkpoint = explicit_checkpoint if explicit_checkpoint is not None else last_checkpoint

        metrics = trainer.train(resume_from_checkpoint=checkpoint).metrics
        metrics["train_samples"] = len(train_dataset)

        for record in (trainer.log_metrics, trainer.save_metrics):
            record("train", metrics)
        trainer.save_state()
    if data_args.do_save:
        trainer.save_model()  # Saves the tokenizer too for easy upload
    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        if data_args.lmeval and (data_args.data == 'mmlu' or data_args.data == 'bbh' or data_args.data == 'gsm8k'):
            if data_args.basemodel.startswith('gemma'):
                lm_obj = lm_eval.models.huggingface.HFLM(pretrained=model,load_in_4bit=True,bnb_4bit_quant_type="nf4",bnb_4bit_compute_dtype=torch.bfloat16)
            else:
                lm_obj = lm_eval.models.huggingface.HFLM(pretrained=model,load_in_4bit=True,bnb_4bit_quant_type="nf4",bnb_4bit_compute_dtype=torch.bfloat16)
            # lm_obj = lm_eval.models.huggingface.HFLM(pretrained='google/gemma-7b',load_in_4bit=True,bnb_4bit_quant_type="nf4",bnb_4bit_compute_dtype=torch.bfloat16)
            task_manager = lm_eval.tasks.TaskManager()
            if data_args.data == 'mmlu':
                results = lm_eval.simple_evaluate( # call simple_evaluate
                    model=lm_obj,
                    # tasks=["mmlu"],
                    tasks=["mmlu_"+data_args.task_name],
                    num_fewshot=data_args.num_fewshot,
                    task_manager=task_manager,
                )
                print(results['results']["mmlu_"+data_args.task_name]['acc,none'])

            if data_args.data == 'bbh':
                results = lm_eval.simple_evaluate( # call simple_evaluate
                    model=lm_obj,
                    tasks=["bbh_fewshot_"+data_args.task_name],
                    num_fewshot=data_args.num_fewshot,
                    task_manager=task_manager,
                )
                print(results['results']["bbh_fewshot_"+data_args.task_name]['exact_match,none'])
            
            if data_args.data == 'gsm8k':
                results = lm_eval.simple_evaluate( # call simple_evaluate
                    model=lm_obj,
                    tasks=["gsm8k"],
                    num_fewshot=data_args.num_fewshot,
                    task_manager=task_manager,
                )
                print(results['results']["gsm8k"]['exact_match,flexible-extract'])
            return
        elif data_args.lmeval:
            def GEN_SOLUTION(prpt):
                out = model.generate(inputs=tokenizer(prpt, return_tensors="pt").input_ids.to(device),do_sample=False,max_length=1024)
                output = tokenizer.decode(out[0])
                return output[5:-5]
            samples = []
            for task_id, problem in get_mbpp_plus().items():
                # print(task_id, problem)
                if int(task_id[5:]) in range(11,511):
                    samples.append(dict(task_id=task_id, solution=GEN_SOLUTION(problem["prompt"]))) 
            write_jsonl("CODE RESULTS"+data_args.jname+".jsonl", samples)
            return
        if data_args.do_eval_on_train:
            # Accuracy of greedy generations on eval_dataset2, printed as a fraction.
            correct = 0
            for i in range(len(eval_dataset2)):
                row = eval_dataset2[i]
                prompt, target = row['inputs'], row['targets']
                out = model.generate(inputs=tokenizer(prompt, return_tensors="pt").input_ids.to(device),do_sample=False,max_length=1024)
                output = tokenizer.decode(out[0])

                # Strip the echoed prompt. The text is produced by the trained
                # (student) model and its tokenizer, so the offset is chosen
                # from `basemodel`, matching the eval loop that follows.
                if data_args.prompting == 'inst' and data_args.basemodel.startswith('llama'):
                    output = output[len(prompt)+6:]
                elif data_args.basemodel.startswith('phi'):
                    output = output[len(prompt)+1:]
                else:
                    output = output[len(prompt)+5:]

                if data_args.evalmode == 'strcomp':
                    # Correct when the generation starts with the target.
                    correct += (output[:len(target)] == target)
                elif data_args.evalmode == 'eol':
                    # Correct when the first generated line equals the target.
                    correct += (output.split('\n')[0] == target)

            print(correct / len(eval_dataset2))
        # Greedy-decode every evaluation example, echo prompt / generation /
        # label, and print the final accuracy as a fraction.
        n_eval = len(eval_dataset)
        correct = 0
        for idx in range(n_eval):
            row = eval_dataset[idx]
            question, label = row['inputs'], row['targets']
            print("PROBLEM STATEMENT:")
            print(question)
            question_ids = tokenizer(question, return_tensors="pt").input_ids.to(device)
            generated = model.generate(inputs=question_ids,do_sample=False,max_length=1024)
            output = tokenizer.decode(generated[0])

            # Extra leading characters to drop after the echoed prompt.
            if data_args.prompting == 'inst' and data_args.basemodel.startswith('llama'):
                extra = 6
            elif data_args.basemodel.startswith('phi'):
                extra = 1
            else:
                extra = 5
            output = output[len(question)+extra:]

            print("OUTPUT OF MODEL:")
            print(output)
            print("LABEL:")
            print(label)
            if data_args.evalmode == 'strcomp':
                correct += (output[:len(label)] == label)
            elif data_args.evalmode == 'eol':
                correct += (output.split('\n')[0] == label)

        print(correct / n_eval)


def _mp_fn(index):
    """Entry point that ignores ``index`` and runs ``main()``."""
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    # disable_caching()
    # Point the Hugging Face libraries at the desired cache directories, then run.
    # NOTE(review): `transformers` and `datasets` are already imported at the
    # top of this module when these assignments run; if those libraries resolve
    # their cache paths at import time, these values arrive too late — confirm,
    # and consider exporting them in the shell or before the imports instead.
    os.environ['TRANSFORMERS_CACHE'] = 'PATH TO TRANSFORMERS CACHE'
    os.environ['HF_DATASETS_CACHE'] = 'PATH TO DATASET CACHE'
    main()