from dataclasses import dataclass, field
from typing import Optional
import transformers
from accelerate import Accelerator
import gc
import os
import wandb
from tqdm.auto import tqdm

from utils.process_verifier_models import save_verifier, build_verifier_from_osv, build_verifier_from_scratch
from utils.states import set_training_states, set_random_seed, get_optimizers
from utils.verifier_datasets import make_training_verifier_data_module, make_training_dataloaders

@dataclass
class ModelParams:
    """Command-line arguments that select the model to train."""

    # Model name or directory. main() checks this directory for a
    # 'verifier.pth' file to decide how the verifier is built.
    # The default is the literal string "none".
    model_name_or_path: Optional[str] = "none"

@dataclass
class DataParams:
    """Command-line arguments describing the training data.

    The instance is handed to ``make_training_verifier_data_module``; most
    fields are interpreted there rather than in this file.
    """

    data_dir: str = field(default='none', metadata={"help": "Path to the training data."})
    data_id: str = field(default='none')
    target_set: str = field(default='train')
    # Optional: no validation split is named unless one is passed explicitly.
    val_target_set: Optional[str] = field(default=None)
    generator_id: str = field(default='none')

    per_problem_sampling_solution: int = field(default=-1)
    loss_level: str = field(default='token')
    # When True, main() also reports the model's 'llm_loss' alongside 'v_loss'.
    loss_on_llm: bool = field(default=False)

    dedup: bool = field(default=False)
    process: bool = field(default=False)

    verifier_id: str = field(default='none')

@dataclass
class TrainParams:
    """Command-line arguments controlling optimisation, logging and checkpointing.

    Several ``*_steps`` / ``*_epoches`` pairs exist; per their help texts the
    step-based value takes precedence when it is specified. The instance is
    passed to the ``utils`` helpers (model building, dataloaders, training
    states, optimizers) in addition to being read by ``main``.
    """

    cache_dir: Optional[str] = field(default=None)
    model_max_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )

    max_steps: int = field(default=-1, metadata={"help": "When it is specified, num_train_epoches is ignored"})
    num_train_epoches: int = field(default=1)
    per_device_train_batch_size: int = field(default=4)
    gradient_accumulation_steps: int = field(default=1)
    gradient_checkpointing: bool = field(default=True)

    eval_steps: int = field(default=-1, metadata={"help": "When it is specified, eval_epoches is ignored"})
    eval_epoches: int = field(default=1)
    # Declared as float so fractional values such as 0.5 can be passed on the
    # command line (an ``int`` annotation makes the parser cast with int()).
    max_grad_norm: float = field(default=1.0)
    per_device_eval_batch_size: int = field(default=4)
    # When True, main() reloads the highest-numbered checkpoint in save_dir.
    resume_from_checkpoint: bool = field(default=False)

    learning_rate: float = field(default=1e-5)
    weight_decay: float = field(default=0.0)
    lr_scheduler_type: str = field(default="linear")
    warmup_steps: int = field(default=-1, metadata={"help": "When it is specified, warmup_ratio is ignored"})
    warmup_ratio: float = field(default=0.0)

    num_lr_epoches_fs: int = field(default=-1)
    num_lr_epoches_scatter: int = field(default=-1)

    logging_steps: int = field(default=-1, metadata={"help": "When it is specified, logging_epoches is ignored"})
    logging_epoches: int = field(default=1)

    save_steps: int = field(default=-1, metadata={"help": "When it is specified, save_epoches is ignored"})
    save_epoches: int = field(default=1)
    save_total_limit: int = field(default=3)
    save_best: bool = field(default=False)
    fp16: bool = field(default=False)
    seed: int = field(default=42)
    # NOTE(review): main() in this file reads resume_from_checkpoint, not this
    # flag; confirm whether `resume` is still used elsewhere.
    resume: bool = field(default=False)

@dataclass
class OutputParams:
    """Command-line arguments naming where run artefacts are written."""

    # Parent directory for wandb files; main() appends the WANDB_PROJECT name.
    logging_dir: str = 'wandb/'
    # Directory holding the numbered checkpoints and the final saved verifier.
    save_dir: str = 'checkpoints/'

# Define the main function

def main():
    """Train a verifier with Hugging Face Accelerate and log metrics to wandb.

    Parses ``ModelParams``, ``DataParams``, ``TrainParams`` and ``OutputParams``
    from the command line, builds (or reloads) the verifier, optionally resumes
    from the highest-numbered checkpoint directory under ``save_dir``, runs the
    training loop, and finally saves the verifier with ``save_verifier``.

    Raises:
        KeyError: on the main process when the ``WANDB_PROJECT`` environment
            variable is not set.
        FileNotFoundError: when ``--resume_from_checkpoint`` is given but
            ``save_dir`` is missing or holds no numeric checkpoint directory.
    """
    # Parse arguments
    parser = transformers.HfArgumentParser((ModelParams, DataParams, TrainParams, OutputParams))
    model_args, data_args, training_args, output_args = parser.parse_args_into_dataclasses()

    # Snapshot of the hyper-parameters recorded as the wandb run config
    # (output paths are intentionally left as in the original: not included).
    config_args_dict = model_args.__dict__.copy()
    config_args_dict.update(data_args.__dict__)
    config_args_dict.update(training_args.__dict__)
    set_random_seed(training_args.seed)

    # Initialize accelerator
    grad_accum_steps = training_args.gradient_accumulation_steps
    accelerator = Accelerator(gradient_accumulation_steps=grad_accum_steps)

    # Build or load model: a directory containing 'verifier.pth' is treated as
    # an existing verifier, anything else is built from scratch.
    if model_args.model_name_or_path and os.path.exists(os.path.join(model_args.model_name_or_path, 'verifier.pth')):
        model, tokenizer = build_verifier_from_osv(model_args, training_args, accelerator)
    else:
        model, tokenizer = build_verifier_from_scratch(model_args, training_args, accelerator)

    # Prepare data
    data_module = make_training_verifier_data_module(tokenizer, data_args)
    train_dataloader = make_training_dataloaders(data_module, training_args)

    # Set training states and get optimizers.
    # NOTE(review): num_training_steps, num_logging_steps and per_save_steps
    # used below are not TrainParams fields; presumably set_training_states
    # attaches them to training_args -- confirm.
    set_training_states(data_module, training_args)
    optimizer, lr_scheduler = get_optimizers(model, training_args)

    # Prepare model, dataloader, optimizer, and lr_scheduler for acceleration
    model, train_dataloader, optimizer, lr_scheduler = accelerator.prepare(
        model, train_dataloader, optimizer, lr_scheduler)

    # Initialize wandb if main process
    if accelerator.is_main_process:
        project_name = os.environ['WANDB_PROJECT']
        logging_dir = os.path.join(output_args.logging_dir, project_name)
        os.makedirs(logging_dir, exist_ok=True)
        wandb_id = output_args.save_dir
        wandb.init(id=wandb_id, dir=logging_dir, config=config_args_dict)

    # -1 means "not resumed": no micro-step is skipped and no save is suppressed.
    loaded_step = -1

    # Load model from checkpoint if specified
    if training_args.resume_from_checkpoint:
        loaded_step, loaded_step_dir_path = _find_latest_checkpoint(output_args.save_dir)
        print(f"Model to be loaded: {loaded_step_dir_path} ")
        accelerator.load_state(loaded_step_dir_path)
        # Checkpoint directories are named global_step // grad_accum_steps,
        # so convert back to the micro-step counter used by the loop.
        loaded_step *= grad_accum_steps

    start_global_step = loaded_step

    # Start training
    cur_epoch = 0
    global_step = 0
    model.train()
    while global_step < training_args.num_training_steps:
        # Only the main process shows a progress bar.
        if accelerator.is_main_process:
            train_dataloader_iterator = tqdm(enumerate(train_dataloader), total=len(train_dataloader),
                                             desc='Training')
        else:
            train_dataloader_iterator = enumerate(train_dataloader)

        for local_step, batch in train_dataloader_iterator:
            # When resuming, fast-forward through batches consumed before the checkpoint.
            if global_step < start_global_step:
                global_step += 1
                continue

            batch_input = {k: v for k, v in batch.items() if k in ('input_ids', 'attention_mask', 'labels', 'v_labels')}
            # NOTE(review): training_args.max_grad_norm is parsed but no gradient
            # clipping is applied in this function -- confirm whether intended.
            with accelerator.autocast(), accelerator.accumulate(model):
                output = model(**batch_input, output_all_losses=True)
                loss = output.loss
                all_losses = output.all_losses
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                accelerator.wait_for_everyone()

            # global_step counts micro-batches; logging and saving are keyed on
            # every grad_accum_steps-th micro-batch, counted in those units.
            on_accum_boundary = global_step % grad_accum_steps == 0
            accum_step = global_step // grad_accum_steps

            if accelerator.is_main_process:
                loss_value = loss.item()
                v_loss_value = all_losses.get('v_loss').item()
                llm_loss_value = all_losses.get('llm_loss').item() if data_args.loss_on_llm else 0

                train_dataloader_iterator.set_postfix(epoch=cur_epoch, step=local_step, loss=loss_value,
                                                      v_loss=v_loss_value, llm_loss=llm_loss_value)

                # Same cadence rule as checkpointing below. The previous test
                # (`global_step % GA and (global_step % GA) % n == 0`) could
                # never be true when gradient_accumulation_steps == 1.
                if on_accum_boundary and accum_step % training_args.num_logging_steps == 0:
                    wandb.log({
                        'loss': loss_value,
                        'v_loss': v_loss_value,
                        'llm_loss': llm_loss_value,
                        'lr': lr_scheduler.get_last_lr()[0],
                    }, step=global_step)

            # Never save at step 0, nor re-save the checkpoint just resumed from.
            if (global_step != 0 and global_step != loaded_step and on_accum_boundary
                    and accum_step % training_args.per_save_steps == 0):
                accelerator.wait_for_everyone()
                resume_dir = os.path.join(output_args.save_dir, str(accum_step))
                print(f"saving model in {resume_dir} ")
                accelerator.save_state(resume_dir)

            global_step += 1

        cur_epoch += 1
        del train_dataloader_iterator
        gc.collect()
        accelerator.wait_for_everyone()

    # Save model
    accelerator.wait_for_everyone()
    save_verifier(accelerator, model, tokenizer, output_args.save_dir)

    # Finish wandb if main process
    if accelerator.is_main_process:
        wandb.finish()


def _find_latest_checkpoint(save_dir):
    """Return ``(step, path)`` for the highest-numbered checkpoint under ``save_dir``.

    Only sub-directories whose name parses as a non-negative integer are
    considered; everything else is ignored.

    Raises:
        FileNotFoundError: if ``save_dir`` is not an existing directory or
            contains no such sub-directory.
    """
    if not os.path.isdir(save_dir):
        raise FileNotFoundError(f"Checkpoint directory does not exist: {save_dir}")

    latest_step = -1
    latest_name = ""
    for name in os.listdir(save_dir):
        if not os.path.isdir(os.path.join(save_dir, name)):
            continue
        try:
            step = int(name)
        except ValueError:
            continue
        if step > latest_step:
            latest_step = step
            latest_name = name

    # The -1 sentinel is truthy, so it must be compared explicitly; otherwise
    # a missing checkpoint would go unnoticed and save_dir itself be loaded.
    if latest_step < 0:
        raise FileNotFoundError(f"No numbered checkpoint directory found in: {save_dir}")
    return latest_step, os.path.join(save_dir, latest_name)


# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
