import pytorch_lightning as pl 
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from pl_modules import CIFAR10_Module
from hparams.cifar10_hparams_dict_testing import CIFAR10_HParams_Dict_Testing as config_dict_testing
import wandb
from attr import evolve
from pytorch_lightning.loggers import TensorBoardLogger
from utils import *
import os

from pytorch_lightning.callbacks import ModelCheckpoint

# Path prefix joined in front of the dataset directory when main() builds the
# checkpoint directory (empty string -> paths are relative to the CWD).
ROOT: str = ''
# Module-level checkpoint period. main() does not read this name; it takes the
# period from args.ckpt_period instead.
ckpt_period: int = 1 

def _checkpoint_tag(checkpoint_path):
    """Return a short run-name tag derived from a checkpoint file path.

    The tag starts at ``epoch=`` in the checkpoint's *file name* and stops at
    the first ``.`` (``'epoch=3-step=100.ckpt'`` -> ``'epoch=3-step=100'``).
    Only the file name is searched, so an ``epoch=`` in a parent directory
    cannot leak a path separator into the tag. When the file name contains no
    ``epoch=`` (e.g. ``'last.ckpt'``), the file name up to its first ``.`` is
    used rather than a single stray trailing character.
    """
    file_name = os.path.basename(checkpoint_path)
    start = file_name.find('epoch=')
    if start == -1:
        start = 0
    return file_name[start:].split('.')[0]


def _num_visible_gpus():
    """Return how many GPUs this process may use.

    Counts the comma-separated entries of ``CUDA_VISIBLE_DEVICES`` when it is
    set (so ``'10'`` is one device, not two characters). When it is unset,
    falls back to ``torch.cuda.device_count()`` instead of raising KeyError.
    """
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        return torch.cuda.device_count()
    return sum(1 for device in visible.split(',') if device.strip())


def main(args):
    """Train a CIFAR10_Module for one named testing config and return it.

    Args:
        args: parsed command-line namespace. Attributes read: ``config``,
            ``suffix``, ``checkpoint``, ``resume``, ``seed``, ``job_id``,
            ``task_id``, ``ckpt_period``, ``ckpt_step``, ``ckpt_init``.

    Returns:
        The ``CIFAR10_Module`` instance after ``trainer.fit``.

    Side effects:
        Starts and finishes a wandb run, writes TensorBoard logs under
        ``tb_logs/`` and checkpoints under ``ROOT/cifar-10/<config>/<suffix>``.
    """
    config_name = args.config
    suffix = args.suffix
    if args.checkpoint:
        # Tag the run name with the checkpoint it starts from.
        suffix = suffix + _checkpoint_tag(args.checkpoint)
    log_name = f'{config_name}_{suffix}' if suffix else config_name

    seed_everything(args.seed)
    config = config_dict_testing[config_name]
    # Set job/task ids before wandb.init snapshots config.to_dict(), so the
    # wandb run config carries them just like the TensorBoard hparams below.
    config.jobid = args.job_id
    config.taskid = args.task_id
    run = wandb.init(project=config.wb_project,
                     name=log_name,
                     sync_tensorboard=True,
                     reinit=True,
                     entity=config.wb_entity,
                     save_code=True,
                     config=config.to_dict())
    print(config.to_dict())

    ds_str = 'cifar-10'
    os.makedirs(ds_str, exist_ok=True)

    if args.checkpoint and not args.resume:
        # Start from the given checkpoint's weights. When resuming, the
        # Trainer restores this run's own last.ckpt instead (see below).
        method = CIFAR10_Module.load_from_checkpoint(args.checkpoint, hparams=config)
    else:
        method = CIFAR10_Module(config)
    print(method.hparams)

    logger = TensorBoardLogger("tb_logs", name=log_name)
    logger.log_hyperparams(config.to_dict())

    checkpoint_dir = os.path.join(ROOT, ds_str, config_name, suffix)
    os.makedirs(checkpoint_dir, exist_ok=True)
    print('checkpoints will be stored at:', checkpoint_dir)

    # Checkpoint every `ckpt_period` epochs unless a step interval is given
    # (ckpt_step == -1 means "use the epoch period").
    ckpt_kwargs = dict(dirpath=checkpoint_dir, save_top_k=config.save_top_k, save_last=True)
    if args.ckpt_step == -1:
        ckpt_kwargs['every_n_epochs'] = args.ckpt_period
    else:
        ckpt_kwargs['every_n_train_steps'] = args.ckpt_step
    checkpoint_callback = ModelCheckpoint(**ckpt_kwargs)

    callbacks = [LR_WD_Scheduler(), LR_WD_Logger(epoch_wise=False), Norm_Logger(layer_wise=False)]
    # SC_Test() is intentionally not added: it runs on a single GPU and leads
    # to CUDA out-of-memory.
    callbacks.append(checkpoint_callback)
    if args.ckpt_init:
        callbacks.append(ModelInitialCheckpoint(dirpath=checkpoint_dir))

    # if hasattr(config, 'measure_variance') and config.measure_variance:
    #     freq = 1 # default measure variance once per epoch
    #     if hasattr(config, 'batch_k') and config.batch_k != 1:
    #         freq *= config.batch_k * 2
    #     callbacks.append(Variance_Measurement(freq=freq))
    trainer = pl.Trainer(
        gpus=-1,
        max_epochs=config.max_epochs,
        logger=logger,
        callbacks=callbacks,
        # DataParallel only when more than one device is actually visible.
        accelerator='dp' if _num_visible_gpus() > 1 else None,
        deterministic=True,
        log_every_n_steps=1,
        accumulate_grad_batches=getattr(config, 'grad_accumulate', 1),
        check_val_every_n_epoch=config.check_val_every_n_epoch,
        # Resuming restarts from this run's own last.ckpt, not args.checkpoint.
        resume_from_checkpoint=os.path.join(checkpoint_dir, 'last.ckpt') if args.resume else None,
        progress_bar_refresh_rate=100,
    )
    trainer.fit(method)
    # run is the active run created above; finishing it closes the wandb run
    # (a following global wandb.finish() would be a no-op).
    run.finish()
    return method
    
if __name__ == '__main__':
    # Script entry point: parse the CLI, train, then make sure no wandb run
    # is left open.
    main(parse_args())
    wandb.finish()
