import torch
import wandb 
import matplotlib.pyplot as plt 

import sys 
sys.path.append('.') 

# gflownet and general utils 
from experiments.experiments_utils import get_argument_parser, \
    create_gfn, create_env, create_log_reward, eval_step 
from var_red_gfn.utils import train_step 

# Weights & Biases project under which the script entry point registers its run.
WANDB_PROJECT_NAME = 'variance_reduced_gflownets'
    
def main(config):
    """Train a GFlowNet and track its evaluation metric across training steps.

    The model is evaluated once before training and once after each of the
    ``config.num_steps`` training steps; every step runs
    ``config.epochs_per_step`` epochs. The collected metrics and the
    matching cumulative epoch counts are printed and, when a wandb run is
    active, stored in its summary under ``'metric'`` and ``'epochs'``.

    Args:
        config: Parsed experiment options. This function reads ``seed``,
            ``use_scheduler``, ``epochs_per_step``, ``num_steps``,
            ``criterion``, ``off_policy`` and ``use_progress_bar``; the
            project helpers it is forwarded to may read further fields.

    Returns:
        None.
    """
    # Local import keeps this function self-contained.
    from contextlib import nullcontext

    torch.set_default_dtype(torch.float64)
    torch.manual_seed(config.seed)

    # instantiate the gflownet
    gfn = create_gfn(config)
    log_reward = create_log_reward(config, gfn)

    def create_env_func():
        # Factory handed to the train/eval helpers so they can build
        # environments on demand.
        return create_env(config, log_reward)

    # Criteria for which training is wrapped in `gfn.on_policy()` unless
    # `config.off_policy` is set.
    div_criteria = ('kl', 'tsallis', 'renyi', 'rev_kl')
    use_on_policy = config.criterion in div_criteria and not config.off_policy

    # Training hyperparameters: log_z gets a much larger learning rate than
    # the forward policy.
    optimizer = torch.optim.Adam([
        {'params': gfn.pf.parameters(), 'lr': 1e-3},
        {'params': gfn.log_z, 'lr': 1e1},
    ])
    if config.use_scheduler:
        # power=1 makes the polynomial decay linear over the whole run.
        # NOTE(review): assumes train_step steps the scheduler once per
        # epoch -- confirm against var_red_gfn.utils.train_step.
        scheduler = torch.optim.lr_scheduler.PolynomialLR(
            optimizer,
            total_iters=config.epochs_per_step * config.num_steps,
            power=1.,
        )
    else:
        scheduler = None

    # Baseline evaluation of the untrained model. Switch to eval mode first
    # so the epoch-0 metric is measured the same way as the later ones.
    gfn.eval()
    metric = [eval_step(config, gfn, create_env_func, plot=False)]
    epochs = [0]

    last_step = config.num_steps - 1
    for step in range(config.num_steps):
        # Train step
        gfn.train()
        # A fresh context manager is needed on every iteration; nullcontext
        # is a no-op stand-in when on-policy mode is not requested.
        sampling_ctx = gfn.on_policy() if use_on_policy else nullcontext()
        with sampling_ctx:
            train_step(gfn, create_env_func, epochs=config.epochs_per_step,
                       optimizer=optimizer, scheduler=scheduler,
                       use_progress_bar=config.use_progress_bar)

        epochs.append(config.epochs_per_step * (step + 1))

        # evaluate the gflownet; only the final evaluation requests a plot
        gfn.eval()
        metric.append(
            eval_step(config, gfn, create_env_func, plot=(step == last_step))
        )

    print(metric)
    # wandb.run is None when no run was initialised; skip logging in that
    # case instead of failing after training has already finished.
    if wandb.run is not None:
        wandb.run.summary['metric'] = metric
        wandb.run.summary['epochs'] = epochs

if __name__ == '__main__':
    # Read the experiment options from the command line.
    parser = get_argument_parser()
    config = parser.parse_args()

    # Tag the run with its environment, criterion and seed so runs can be
    # filtered in the wandb UI.
    run_tags = [
        f'{value}'
        for value in (config.env, config.criterion, config.seed)
    ]
    wandb.init(project=WANDB_PROJECT_NAME, tags=run_tags)

    # Record the full set of options alongside the run.
    wandb.config.update(config)

    main(config)
