import torch 
import wandb 
import matplotlib.pyplot as plt 

import argparse 
import sys 
sys.path.append('.') 

# gflownets and general utils 
from var_red_gfn.gflownet import GFlowNet  
from var_red_gfn.utils import compute_marginal_dist 

def get_argument_parser():
    """Build the command-line parser for the GFlowNet training script.

    The parser only declares the options; the caller is responsible for
    invoking ``parse_args``. Options cover the policy network and training
    loop, the choice of target environment (``--env``), per-environment
    settings, the random seed, and evaluation/visualization settings.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(description="GFlowNet Training Script")

    # GFlowNet architecture and training parameters
    parser.add_argument("--hidden_dim", type=int, default=64, help="Hidden dimension of the policy network")
    parser.add_argument("--num_layers", type=int, default=2, help="Number of layers in the policy network")
    parser.add_argument("--epochs", type=int, default=2000, help="Number of training epochs")
    parser.add_argument("--epochs_eval", type=int, default=100, help="Number of epochs for evaluation")
    parser.add_argument('--epochs_per_step', type=int, default=25, help='number of epochs per step')
    parser.add_argument('--num_steps', type=int, default=25, help='number of steps')
    parser.add_argument("--use_scheduler", action="store_true", help="Use learning rate scheduler")
    parser.add_argument("--alpha", type=float, default=0.5, help="Trade-off parameter in GFlowNet")
    parser.add_argument("--criterion", type=str, default="tb", help="Loss function for training",
                        choices=['kl', 'tsallis', 'renyi', 'rev_kl', 'tb', 'db', 'cb', 'dbc'])
    parser.add_argument('--off_policy', action='store_true', help='whether to pursue off-policy')
    parser.add_argument("--cv", type=str, default=None, help="What control variate to use for training")
    parser.add_argument("--device", type=str, default="cpu", help="Device to use (cpu or cuda)")

    # Target domain; selects which environment-specific options below are read
    parser.add_argument('--env', type=str, default='sets', help='Target domain',
                        choices=['sets', 'sequences', 'phylogenetics', 'grids', 'gmms', 'banana'])

    # Environment parameters

    # Generic
    parser.add_argument("--batch_size", type=int, default=128, help="Batch size for training")

    # Sets
    parser.add_argument("--set_size", type=int, default=16, help="Number of elements in the set")
    parser.add_argument("--src_size", type=int, default=32, help="Number of source vectors")

    # Sequences
    parser.add_argument('--seq_size', type=int, default=8, help='size of the sequence')
    parser.add_argument('--wse_size', type=int, default=10, help='warehouse size')

    # Phylogenetic inference
    parser.add_argument('--num_leaves', type=int, default=7, help='number of biological species')
    parser.add_argument('--num_nb', type=int, default=4, help='number of nucleotides (hypothetical)')
    parser.add_argument('--num_sites', type=int, default=25, help='number of observed sites')
    parser.add_argument('--temperature', type=float, default=1., help='temperature of the target')

    # GMMs
    parser.add_argument('--sigma', type=float, default=1e-1, help='marginal variance')
    parser.add_argument('--num_comp', type=int, default=4, help='number of components for the policy')

    # Reward and seed
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reward generation")

    # Visualization parameters
    parser.add_argument("--num_back_traj", type=int, default=8, help="Number of back-trajectories for evaluation")
    parser.add_argument('--use_progress_bar', action='store_true', help='use progress bar')

    return parser

def create_gfn(config):
    """Assemble the GFlowNet whose policies match ``config.env``.

    The forward/backward policy classes are imported lazily from the
    model module of the selected environment, so only that module is
    loaded. ``gmms`` and ``banana`` share the same policy classes.

    Args:
        config: namespace providing ``env``, ``device``, ``criterion``,
            ``alpha`` and the architecture fields read by the chosen branch.

    Returns:
        GFlowNet built from the forward policy, the backward policy, and
        the ``criterion``/``device``/``alpha`` settings of ``config``.

    Raises:
        Exception: if ``config.env`` has no associated policy classes.
    """
    env_name = config.env

    # Pick the policy classes and the positional arguments of the forward policy.
    if env_name == 'sets':
        from var_red_gfn.models.sets import ForwardPolicy, BackwardPolicy
        forward_args = (config.src_size, config.hidden_dim, config.num_layers)
    elif env_name == 'sequences':
        from var_red_gfn.models.sequences import ForwardPolicy, BackwardPolicy
        forward_args = (config.seq_size, config.wse_size, config.hidden_dim, config.num_layers)
    elif env_name == 'phylogenetics':
        from var_red_gfn.models.phylogenetics import ForwardPolicy, BackwardPolicy
        forward_args = (config.hidden_dim, config.num_leaves)
    elif env_name in ('gmms', 'banana'):
        from var_red_gfn.models.gmms import ForwardPolicy, BackwardPolicy
        # The leading 2 is passed as the first positional argument, as for both continuous targets.
        forward_args = (2, config.hidden_dim, config.num_comp)
    else:
        raise Exception(f'env: {config.env}')

    pf = ForwardPolicy(*forward_args, device=config.device)
    pb = BackwardPolicy(config.device)

    return GFlowNet(
        pf,
        pb,
        criterion=config.criterion,
        device=config.device,
        alpha=config.alpha,
    )

def create_env(config, log_reward=None): 
    match config.env: 
        case 'sets': 
            from var_red_gfn.gym.sets import Set 
            return Set(config.src_size, config.set_size, config.batch_size, log_reward, device=config.device) 
        case 'sequences': 
            from var_red_gfn.gym.sequences import Sequences
            return Sequences(config.seq_size, config.wse_size, config.batch_size, log_reward, device=config.device) 
        case 'phylogenetics': 
            from var_red_gfn.gym.phylogenetics import Trees 
            return Trees(config.num_leaves, config.batch_size, log_reward, device=config.device) 
        case 'gmms' | 'banana': 
            from var_red_gfn.gym.gmms import GaussianMixture 
            return GaussianMixture(dim=2, batch_size=config.batch_size, 
                            log_reward=log_reward, device=config.device) 
        
def create_log_reward(config, gflownet): 
    match config.env: 
        case 'sets': 
            from var_red_gfn.gym.sets import Set, LogReward 
            log_reward = LogReward(config.src_size, config.seed, device=config.device) 
            sets = Set(config.src_size, config.set_size, config.batch_size, log_reward, device=config.device) 
            sets = gflownet.sample(sets) 
            log_reward.shift = sets.log_reward().max()  
            return log_reward 
        case 'sequences': 
            from var_red_gfn.gym.sequences import LogReward 
            return LogReward(config.wse_size, config.seq_size, config.seed, device=config.device) 
        case 'phylogenetics': 
            from var_red_gfn.gym.phylogenetics import Trees, LogReward 
            tree = Trees(config.num_leaves, batch_size=1, log_reward=None, device=config.device) 
            with gflownet.off_policy(): 
                tree = gflownet.sample(tree, seed=42) 
            # Simulate JC69 
            Q = 3e-1 * torch.ones((config.num_nb, config.num_nb), device=config.device) 
            Q[torch.arange(config.num_nb), torch.arange(config.num_nb)] -= Q.sum(dim=-1) 
            pi = torch.ones((config.num_nb,), device=config.device) / config.num_nb   
            sites = Trees.sample_from_phylogeny(tree, Q, config.num_sites, pi, device=config.device)
            sites = sites[:, :config.num_leaves] 
            # Tree's likelihood using Felsenstein's algorithm 
            log_reward = LogReward(pi, sites, Q, config.temperature) 
            # Shift the reward for enhanced numerical stability 
            env = Trees(config.num_leaves, batch_size=config.batch_size, log_reward=log_reward, device=config.device) 
            values = gflownet.sample(env).log_reward() 
            log_reward.shift = values.max() 
            return log_reward 
        case 'gmms': 
            from var_red_gfn.gym.gmms import LogReward
            mu = torch.arange(3, device=config.device)
            mu = torch.meshgrid(mu, mu)
            mu = torch.cat([m.unsqueeze(0) for m in mu], dim=0).flatten(start_dim=1) 
            mu = mu.type(torch.get_default_dtype()).t() 
            return LogReward(mu, config.sigma) 
        case 'banana': 
            from var_red_gfn.gym.gmms import LogRewardBanana 
            return LogRewardBanana(device=config.device) 

def eval_step(config, gfn, create_env_func, plot):
    """Evaluate ``gfn`` off-policy, log the metric to wandb and return it.

    For the ``gmms`` and ``banana`` environments the metric is the value
    returned by ``GaussianMixture.estimate_js`` (logged as ``'js'``). For
    every other environment it is the L1 distance between the learned and
    target distributions returned by ``compute_marginal_dist`` (logged as
    ``'l1'``).

    Args:
        config: namespace providing ``env``, ``epochs_eval``,
            ``use_progress_bar`` and, for discrete environments,
            ``num_back_traj``.
        gfn: the GFlowNet to evaluate; must provide ``off_policy()``.
        create_env_func: zero-argument callable returning a fresh environment.
        plot: whether to produce a plot. For discrete environments a scatter
            of learned vs. target probabilities is written to ``figure.png``;
            for ``gmms``/``banana`` the flag is forwarded to ``estimate_js``.

    Returns:
        The logged metric (``js`` or the L1 distance).
    """
    with gfn.off_policy():
        if config.env in ('gmms', 'banana'):
            from var_red_gfn.gym.gmms import GaussianMixture
            js = GaussianMixture.estimate_js(gfn, create_env_func,
                        num_batches=config.epochs_eval,
                        log_reward=create_env_func()._log_reward,
                        use_progress_bar=config.use_progress_bar,
                        plot=plot)
            wandb.log({'js': js}, commit=True)
            return js

        learned_dist, target_dist = compute_marginal_dist(gfn, create_env_func,
                                num_batches=config.epochs_eval,
                                num_back_traj=config.num_back_traj,
                                use_progress_bar=config.use_progress_bar)

        # Compute the distance once and reuse it for logging and the return value.
        l1 = (learned_dist - target_dist).abs().sum()
        wandb.log({'l1': l1})

    if plot:
        # Draw on a dedicated figure and close it afterwards: drawing on
        # pyplot's implicit current figure would overlay the points of every
        # previous evaluation and keep the figures alive.
        fig, ax = plt.subplots()
        try:
            ax.scatter(learned_dist.cpu(), target_dist.cpu(), rasterized=True)
            fig.savefig('figure.png')
        finally:
            plt.close(fig)

    return l1
