from typing import Optional
import os

import torch
import wandb

from src.utils.misc import set_seed, set_hyperparams
from src.utils.saving_loading import get_most_recent_checkpoint_filepath, get_dirs_dict, get_logger_id, save_logger_id


def get_system_and_trainer(cfg, exp_name, cfg_path, tmp_dir, load_checkpoint_type: Optional[str] = "recent",
                           ckpt_path: Optional[str] = None, **kwargs):
    """Build the system and trainer for an experiment, resuming the logger run if a checkpoint exists.

    Args:
        cfg: Experiment configuration. This function reads ``seed``, ``use_deterministic_algorithms``,
            ``wandb_dryrun``, ``resume_if_possible`` and ``callbacks``, and calls the factories
            ``logger``, ``checkpoint_callback``, ``system`` and ``trainer``.
        exp_name: Name given to the logger run. Replaced by the name returned from ``get_logger_id``
            when a checkpoint is found and that name is not None.
        cfg_path: Passed to ``get_dirs_dict`` as ``cfg_dir_rel`` and to ``set_hyperparams`` as
            ``config_path``.
        tmp_dir: Passed to ``get_dirs_dict`` as ``tmp_dir``.
        load_checkpoint_type: Forwarded to ``get_most_recent_checkpoint_filepath``. ``None`` disables
            the checkpoint search. Must be ``None`` whenever ``ckpt_path`` is given.
        ckpt_path: Explicit checkpoint path. It is only validated against ``load_checkpoint_type``
            here; this function does not load from it or return it.
        **kwargs: Accepted but ignored.

    Returns:
        A tuple ``(system, trainer, ckpt_found, load_ckpt_filepath)``, where ``ckpt_found`` is True
        iff the checkpoint search produced a path and ``load_ckpt_filepath`` is that path (or None).

    Raises:
        AssertionError: If ``ckpt_path`` is given while ``load_checkpoint_type`` is not None.

    Note:
        Mutates ``cfg.callbacks`` (the checkpoint callback is appended) and may set the
        ``WANDB_MODE`` environment variable.
    """
    # ____ Argument checks. ____
    # Raised explicitly (rather than via an `assert` statement) so that the check is not stripped
    # when Python runs with -O; the exception type is kept for existing callers.
    if ckpt_path is not None and load_checkpoint_type is not None:
        raise AssertionError("If a checkpoint is provided, don't seek checkpoints to load.")

    # ____ Generic configurations. ____
    # Set seed.
    set_seed(cfg.seed)

    # Use deterministic pytorch ops wherever possible.
    if cfg.use_deterministic_algorithms:
        torch.use_deterministic_algorithms(True)

    # Prevent wandb from syncing with the cloud.
    if cfg.wandb_dryrun:
        os.environ['WANDB_MODE'] = 'dryrun'

    # ____ Get the directories dict. ____
    dirs_dict = get_dirs_dict(cfg_dir_rel=cfg_path, tmp_dir=tmp_dir)

    # ____ Deal with checkpoint loading.  ____
    # Search for a checkpoint only when a checkpoint type was requested and resuming is enabled.
    # The "ckpt_found" flag is part of the return value.
    load_ckpt_filepath = None
    if load_checkpoint_type is not None and cfg.resume_if_possible:
        load_ckpt_filepath = get_most_recent_checkpoint_filepath(dirs_dict, load_checkpoint_type)
    ckpt_found = load_ckpt_filepath is not None

    # ____ Initialize the logger. ____
    # When resuming from a found checkpoint, reuse the logger id stored with it; otherwise start a
    # fresh run id. The experiment name also gets overridden here if a stored name was found.
    if cfg.resume_if_possible and ckpt_found:
        logger_id, found_exp_name = get_logger_id(load_ckpt_filepath=load_ckpt_filepath)
    else:
        logger_id, found_exp_name = wandb.util.generate_id(), None
    if found_exp_name is not None:
        exp_name = found_exp_name
    logger = cfg.logger(name=exp_name, save_dir=dirs_dict.cfg_dir_rel, id=logger_id)
    save_logger_id(logger=logger, tmp_ckpt_dir_abs=dirs_dict.tmp_ckpt_dir_abs, ckpt_dir_rel=dirs_dict.ckpt_dir_rel)

    # ____ Deal with checkpoint saving. ____
    # Build the checkpoint saving callback and register it with the other callbacks.
    checkpoint_callback = cfg.checkpoint_callback(tmp_ckpt_dir=dirs_dict.tmp_ckpt_dir_abs,
                                                  final_ckpt_dir=dirs_dict.ckpt_dir_rel,
                                                  dirs_dict=dirs_dict)
    if cfg.callbacks is None:
        cfg.callbacks = list()
    cfg.callbacks.append(checkpoint_callback)

    # ____ Deal with hyperparameter logging. ____
    set_hyperparams(config_path=cfg_path, logger=logger)

    # ____ Get the system, logger, checkpoint callback and trainer. ____
    # I'm using my own custom checkpoint callback that doesn't work well as a legitimate "checkpoint callback"
    # within Pytorch lightning. Hence, enable_checkpointing=False below.
    system = cfg.system(dirs_dict=dirs_dict)
    num_gpus = 1 if torch.cuda.is_available() else 0
    trainer = cfg.trainer(logger=logger, gpus=num_gpus,
                          enable_checkpointing=False,
                          callbacks=cfg.callbacks)

    return system, trainer, ckpt_found, load_ckpt_filepath
