#! This script initializes a small LLM and finetune to remember some facts
import os
import hydra
import torch
from hydra.core.hydra_config import HydraConfig
from omegaconf import OmegaConf
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM

from src.utils import init_script
from src.lightningutil.datamodule import create_datamod
from src.lightningutil.util import create_log_dir
from src.lightningutil.modelmodule import init_full_model, init_assisted_model, init_peft_model, init_small_huggingface_llm
from src.language_models import init_offset_model
from src.forget_losses import create_loss_func, REQUIRES_ORACLE, KLLossFunc
from src.hfutil.hf_trainers import ForgetTrainer, WMDP_ForgetTrainer, wmdp_compute_metrics
from src.hfutil.hf_callbacks import SimpleProfileCallback, NoDeepspeedCallback
os.environ['TOKENIZERS_PARALLELISM'] = 'False'

@hydra.main(version_base=None, config_path="../configs", config_name="lightning_tune_config")
def main(configs):
    """Fine-tune a causal LM with a forget/retain objective, driven by a Hydra config.

    Hydra composes ``configs`` from ``../configs/lightning_tune_config`` plus any
    command-line overrides. The run:

    1. names the run and creates its log and checkpoint directories;
    2. loads the tokenizer and the data module;
    3. derives the optimizer-step budget from the train dataloader length;
    4. builds the trainable model (assisted / LoRA / offset / full), the loss
       function and, for loss types in ``REQUIRES_ORACLE``, a frozen copy of
       the base checkpoint passed to the trainer as ``oracle_model``;
    5. trains with ``ForgetTrainer`` (``WMDP_ForgetTrainer`` for WMDP data);
    6. on local rank 0, symlinks the Hydra run directory into the checkpoint
       directory as ``trainlogdir``.

    Args:
        configs: Composed OmegaConf ``DictConfig``. It is mutated in place
            (``base_logdir`` and ``name`` are set here).
    """
    num_devices = int(os.environ.get('WORLD_SIZE', 1))
    # LOCAL_RANK is only exported by distributed launchers (e.g. torchrun).
    # Default to 0 so a plain single-process run takes the rank-0 path at the
    # end instead of referencing an undefined name after training.
    local_rank = int(os.environ.get('LOCAL_RANK', '0'))

    BASELOGDIR = configs.BASELOGDIR
    output_dir = HydraConfig.get().runtime.output_dir
    configs.base_logdir = os.path.join(output_dir, "logs")
    LOGGER = init_script(configs)
    LOGGER.info("Config", configs=configs)
    LOGGER.info(f"num_devices: {num_devices}")

    # Struct mode rejects keys that are not already in the config; lift it
    # briefly so the derived `name` key can be added.
    OmegaConf.set_struct(configs, False)
    # NOTE(review): there is no "|" between the loss type and "data=", so names
    # look like "...|loss=<type>data=<mode>|...". Left unchanged because the
    # name may feed create_log_dir() / wandb, and changing it would move the
    # outputs of existing experiments.
    configs.name = (
        "split=" + configs.data.split
        + "|loss=" + configs.model_train.loss_type
        + f"data={configs.data_mode.mode}"
        + f"|lr={configs.lr}"
        + "|lora=" + str(configs.model_train.Lora.r)
        + "|assist=" + str(configs.model_train.get('is_assist', False))
    )
    print("configs.name", configs.name)
    OmegaConf.set_struct(configs, True)  # restore struct mode

    now, nowname, logdir, ckptdir, cfgdir = create_log_dir(configs)
    os.makedirs(logdir, exist_ok=True)

    #! setup dataset
    tokenizer = AutoTokenizer.from_pretrained(configs.model_train.model_path)
    tokenizer.padding_side = "right"
    if "mistral" in configs.model_train.model_path.lower():
        # NOTE(review): found to be required for Mistral; reason not documented.
        tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        # Fall back to EOS for tokenizers that ship without a pad token.
        tokenizer.pad_token = tokenizer.eos_token
    data_module = create_datamod(
        configs.data,
        tokenizer=tokenizer,
        data_mode_config=configs.data_mode,
        # The question-end marker is passed only for non-"remember" loss types.
        question_end_token=configs.data.conv_template.question_end_token if "remember" not in configs.model_train.loss_type else None,
    )
    data_module.prepare_data()
    data_module.setup('fit')

    lightning_config = configs.get("lightning", OmegaConf.create())
    trainer_config = lightning_config.get("trainer", OmegaConf.create())
    batch_size = configs.data.batch_size
    # Optimizer steps per epoch = dataloader batches / (devices * grad-accum),
    # floored and clamped to at least one step.
    train_data_size = len(data_module.train_dataloader()) * batch_size
    num_update_steps_per_epoch = train_data_size // (num_devices * batch_size * trainer_config.accumulate_grad_batches)
    num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
    num_training_steps = num_update_steps_per_epoch * trainer_config.max_epochs
    print("num_training_steps", num_training_steps)

    checkpoint_config = lightning_config['callbacks']['checkpoint_callback']['params']
    # Re-root the part of ckptdir below BASELOGDIR (dropping its first and last
    # "/"-separated pieces) under OUTPUTMODELDIR, replacing "," and "=".
    tmpckptdir = ckptdir.split(BASELOGDIR)[-1]
    checkpoint_config['dirpath'] = os.path.join(
        configs.OUTPUTMODELDIR, "/".join(tmpckptdir.split("/")[1:-1]).replace(",", "|").replace("=", "_")
    )
    print(checkpoint_config['dirpath'])
    os.makedirs(checkpoint_config['dirpath'], exist_ok=True)

    model_config = OmegaConf.create(configs.get('model_train', None))
    is_offset = model_config.get('offset', False)

    #! setup trainer
    os.environ["WANDB_PROJECT"] = configs.project
    os.environ["WANDB_DIR"] = logdir
    is_deepspeed = 'deepspeed' in trainer_config.get('strategy', "")
    if is_deepspeed:
        deepspeed_configfile = "configs/ds_config.json" if not is_offset else "configs/ds_config2.json"
    else:
        deepspeed_configfile = None
    is_wmdp = 'wmdp' in configs.data.class_name.lower()
    training_args = transformers.TrainingArguments(
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=trainer_config.accumulate_grad_batches,
        warmup_steps=int(num_training_steps * model_config.warmup_ratio),
        learning_rate=model_config.learning_rate,
        max_steps=num_training_steps,
        bf16=True,
        bf16_full_eval=True,
        logging_steps=1,
        logging_dir=logdir,
        output_dir=checkpoint_config['dirpath'],
        optim="paged_adamw_32bit",
        save_only_model=True,
        ddp_find_unused_parameters=False,
        deepspeed=deepspeed_configfile,
        weight_decay=model_config.weight_decay,
        # Save and evaluate once per epoch's worth of optimizer steps.
        save_steps=num_update_steps_per_epoch,
        eval_steps=num_update_steps_per_epoch,
        # NOTE(review): newer transformers releases rename this argument to
        # `eval_strategy`; keep in sync with the pinned version.
        evaluation_strategy="steps",
        seed=configs.get('seed', 42),
        report_to='wandb',
        run_name=configs.name,
        remove_unused_columns=False,
        # WMDP runs pass wmdp_compute_metrics below, which needs predictions;
        # other datasets only report the eval loss.
        prediction_loss_only=not is_wmdp,
    )

    simpleprofilercallback = SimpleProfileCallback(
        logdir, "simpleprofile.txt"
    )

    #! Logging training mode
    batch = next(iter(data_module.train_dataloader()))
    sampledatas = {
        "train_sample_keys": batch.keys(),
        "train_sample": tokenizer.batch_decode(batch['input_ids'][:2], skip_special_tokens=True),
    }
    if 'prefer_input_ids' in batch:
        sampledatas['prefer_sample'] = tokenizer.batch_decode(batch['prefer_input_ids'][:2], skip_special_tokens=True)
    if 'retainlabel' in batch:
        sampledatas['retainlabel'] = batch['retainlabel'].tolist()
    LOGGER.info("Sample data", **sampledatas, shape=batch['input_ids'].shape)

    train_set = data_module.train_set()
    val_set = data_module.val_set()

    baseoutdir = checkpoint_config['dirpath']
    loraconf = model_config.get('Lora', None)
    # Pop the loader-specific keys so the remaining model_config entries can be
    # forwarded as **kwargs to the model and loss factories.
    model_path = model_config.pop('model_path')
    num_layer = model_config.pop('num_layer')
    data_type = model_config.pop('data_type')

    if model_config.get('is_assist', False):
        model = init_assisted_model(model_path, num_layer, data_type, **model_config)
    elif loraconf is not None and loraconf.r != 0:
        model = init_peft_model(model_path, loraconf, baseoutdir, num_layer, data_type)
    elif is_offset:
        model = init_offset_model(model_path, data_type, **model_config)
    else:
        model = init_full_model(model_path, num_layer, data_type, **model_config)
    model = model.train()

    if model_config.loss_type == 'rmu':
        # The RMU loss additionally receives the model's config.
        loss_function = create_loss_func(retain_weight=model_config.remember_weight, **model_config, model_config=model.config)
    else:
        loss_function = create_loss_func(retain_weight=model_config.remember_weight, **model_config)

    if model_config.loss_type in REQUIRES_ORACLE:
        # Frozen copy of the base checkpoint, handed to the trainer.
        # NOTE(review): `use_flash_attention_2` is deprecated in newer
        # transformers in favour of attn_implementation="flash_attention_2".
        oracle_model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.bfloat16,
            use_flash_attention_2=True, trust_remote_code=True,
        )
        oracle_model.eval()
        oracle_model.requires_grad_(False)
    else:
        oracle_model = None

    if model_config.loss_type == 'rmu':
        # RMU trains only the layers listed in `layerids`; freeze everything else.
        model.requires_grad_(False)
        for layer in model_config.layerids:
            model.model.layers[layer].requires_grad_(True)

    requires_equal_sampler = (loss_function.retain_loss_func is not None)
    LOGGER.info("Training with equal sampler: ", requires_equal_sampler=requires_equal_sampler)

    if is_wmdp:
        trainer_class = WMDP_ForgetTrainer
        metric_func = wmdp_compute_metrics
    else:
        trainer_class = ForgetTrainer
        metric_func = None

    custom_callbacks = [simpleprofilercallback]
    trainer = trainer_class(
        model=model,
        train_loss_function=loss_function,
        oracle_model=oracle_model,
        equal_sampler=requires_equal_sampler,
        is_deepspeed=is_deepspeed,
        train_dataset=train_set,
        eval_dataset=val_set,
        seed=configs.get('seed', 42),
        compute_metrics=metric_func,
        callbacks=custom_callbacks,
        args=training_args,
        is_offset=is_offset,
    )
    model.config.use_cache = False
    if is_offset:
        # The offset model reads its base LLM from the trainer's oracle model.
        model.basellm = trainer.oracle_model
    trainer.train()

    if local_rank == 0:
        trainlog_link = os.path.join(checkpoint_config['dirpath'], "trainlogdir")
        # Re-running the same config can land in an existing checkpoint dir.
        # Repoint a stale link at this run. Never crash here with
        # FileExistsError after training has already finished.
        if os.path.islink(trainlog_link):
            os.unlink(trainlog_link)
        if os.path.exists(trainlog_link):
            LOGGER.info("Not creating trainlogdir link; path already exists", path=trainlog_link)
        else:
            os.symlink(output_dir, trainlog_link)

if __name__ == "__main__":
    # Script entry point. The @hydra.main decorator on main() composes the
    # config ("lightning_tune_config" under ../configs) plus CLI overrides and
    # passes it in, so main() is called here without arguments.
    main()