import torch 
import torch.nn as nn 
import numpy as np 
from copy import deepcopy 

def huber_loss(diff, delta=9.):
    """Return the mean Huber loss of ``diff`` measured against zero.

    Elementwise the penalty is ``0.5 * diff**2`` when ``|diff| <= delta`` and
    ``delta * (|diff| - 0.5 * delta)`` otherwise. The two pieces meet continuously
    (with matching slope) at ``|diff| == delta``. Beyond that point the penalty grows
    linearly, which limits the gradient magnitude of large residuals to ``delta``.

    Args:
        diff: tensor of residuals (any shape).
        delta: threshold between the quadratic and the linear regime.

    Returns:
        0-dim tensor holding the mean over all elements of ``diff``.
    """
    abs_diff = diff.abs()
    quadratic = .5 * diff * diff
    # Linear branch must use |diff| (not diff**2), otherwise it is neither linear nor
    # continuous with the quadratic branch for delta != 1.
    linear = delta * (abs_diff - .5 * delta)
    # torch.where (rather than mask multiplication) avoids inf * 0 = nan.
    return torch.where(abs_diff > delta, linear, quadratic).mean()

class GFlowNet(nn.Module):
    """GFlowNet wrapper pairing a forward policy ``pf`` with a backward policy ``pb``.

    Calling the module rolls out trajectories on ``batch_state``. It returns the loss
    selected by ``criterion`` together with per-step trajectory statistics.

    Args:
        pf: forward policy. ``pf(batch_state)`` must return a sequence whose first four
            items are (actions, pf value, flow value, sampling-policy value); see
            ``_sample_traj``. ``_sample_traj_comp``, ``sample`` and ``marginal_prob`` call
            it with other signatures.
        pb: backward policy. ``pb(batch_state, actions)[1]`` is stored as the pb value.
        criterion: one of 'tb', 'db', 'subtb', 'kl', 'rev_kl', 'renyi', 'tsallis',
            'jeffrey' or 'dbc'.
        device: device for ``log_z`` and the internal trajectory buffers.
        alpha: order used by the Renyi / Tsallis criteria.
        use_cv: stored as ``self.use_cv``; not read inside this class.
    """

    # criterion -> name of the method turning a sampled trajectory batch into a loss.
    # 'dbc' is handled separately in ``forward`` because it uses its own sampler.
    _TRAJECTORY_LOSSES = {
        'tb': '_trajectory_balance',
        'db': '_detailed_balance',
        'subtb': '_subtrajectory_balance',
        'kl': '_kl_div',
        'rev_kl': '_rev_kl_div',
        'renyi': '_renyi_div',
        'tsallis': '_tsallis_div',
        'jeffrey': '_jeffrey_div',
    }

    def __init__(self, pf, pb, criterion='tb', device='cpu', alpha=.5, use_cv=True):
        super(GFlowNet, self).__init__()
        self.pf = pf
        self.pb = pb
        # Learnable scalar added to the trajectory-balance residual.
        self.log_z = nn.Parameter(torch.randn((1,), dtype=torch.get_default_dtype(), device=device).squeeze(), requires_grad=True)

        self.criterion = criterion
        self.device = device
        self._alpha = alpha
        self.use_cv = use_cv

    def forward(self, batch_state):
        """Sample trajectories from ``batch_state`` and compute the configured loss.

        ``batch_state`` is mutated in place by the rollout.

        Returns:
            ``(loss, traj_stats)``. 'tb', 'db' and 'subtb' yield a scalar loss. The
            divergence criteria yield one unreduced value per sample. For 'dbc' the second
            item is the stacked (pf, pb) tensor from ``_detailed_balance_comp``.

        Raises:
            ValueError: if ``criterion`` is unknown. This is checked before any sampling,
                so ``batch_state`` is left untouched.
        """
        if self.criterion == 'dbc':
            return self._detailed_balance_comp(batch_state)
        method_name = self._TRAJECTORY_LOSSES.get(self.criterion)
        if method_name is None:
            valid = ', '.join(['dbc', *self._TRAJECTORY_LOSSES])
            raise ValueError(f'{self.criterion} should be one of: {valid}')
        traj_stats, F_traj, last_idx = self._sample_traj(batch_state)
        loss = getattr(self, method_name)(batch_state, traj_stats, F_traj, last_idx)
        return loss, traj_stats

    def _sample_traj(self, batch_state):
        """Roll out ``pf`` on ``batch_state`` until every sample has stopped.

        Returns:
            traj_stats: (3, batch, max_len) tensor. Row 0 holds ``pf``'s second output,
                row 1 ``pb``'s second output and row 2 ``pf``'s fourth output, one column
                per step. Columns after a sample stopped stay zero.
            F_traj: (batch, max_len + 1) tensor of ``pf``'s third output for each state a
                step was taken from. Column ``last_idx`` holds ``batch_state.log_reward()``
                and later columns stay zero.
            last_idx: (batch,) long tensor, one past the last step column written for each
                sample (the column of ``F_traj`` holding the log-reward).
        """
        # dim 0: pf, dim 1: pb, dim 2: pf_exp
        traj_stats = torch.zeros((3, batch_state.batch_size, batch_state.max_trajectory_length), device=self.device)
        F_traj = torch.zeros((batch_state.batch_size, batch_state.max_trajectory_length + 1), device=self.device)
        i = 0
        last_idx = torch.zeros((batch_state.batch_size,), dtype=torch.long, device=self.device)

        is_stopped = torch.zeros((batch_state.batch_size,), dtype=bool, device=self.device)

        while (batch_state.stopped < 1).any():
            # Sample the actions
            out = self.pf(batch_state)
            actions, pf, F, sp = out[0], out[1], out[2], out[3]

            # Apply the actions
            batch_state.apply(actions)

            # Corresponding backward actions
            out = self.pb(batch_state, actions)
            pb = out[1]

            # Save values only for samples still running before this step
            traj_stats[0, ~is_stopped, i] = pf[~is_stopped]
            traj_stats[1, ~is_stopped, i] = pb[~is_stopped]
            traj_stats[2, ~is_stopped, i] = sp[~is_stopped]  # sampling policy
            F_traj[~is_stopped, i] = F[~is_stopped]
            # Check whether it already stopped
            is_stopped = batch_state.stopped.bool()
            i += 1
            # Counts steps that did not stop the sample; the stopping step is added by +1 below.
            last_idx += (1 - batch_state.stopped).long()

        F_traj[batch_state.batch_ids, last_idx + 1] = batch_state.log_reward()
        return traj_stats, F_traj, last_idx + 1

    @staticmethod
    def _step_mask(last_idx, num_steps):
        """Return a (batch, num_steps) bool mask that is True for steps ``k < last_idx``.

        With ``last_idx`` as returned by ``_sample_traj``, these are exactly the step
        columns that were written during the rollout.
        """
        steps = torch.arange(num_steps, device=last_idx.device)
        return steps.unsqueeze(0) < last_idx.unsqueeze(1)

    def _trajectory_balance(self, batch_state, traj_stats, F_traj, last_idx):
        """Huber-penalised TB residual: sum(pf - pb) - log_reward + log_z, per sample."""
        loss = (traj_stats[0] - traj_stats[1]).sum(dim=1) - F_traj[batch_state.batch_ids, last_idx] + self.log_z
        return huber_loss(loss, delta=1.)  # (loss*loss).mean()

    def _detailed_balance(self, batch_state, traj_stats, F_traj, last_idx):
        """Huber-penalised detailed-balance residual of every real transition.

        Padded steps past each sample's end are zeroed. There ``F_traj`` jumps from the
        stored log-reward to zero padding, which is not a real transition. The mean is
        still taken over the full (batch, max_len) grid.
        """
        loss = traj_stats[0] + F_traj[:, :-1] - traj_stats[1] - F_traj[:, 1:]
        mask = self._step_mask(last_idx, loss.shape[1])
        loss = torch.where(mask, loss, torch.zeros_like(loss))
        return huber_loss(loss, delta=1.)  # (loss*loss).mean()

    @staticmethod
    def _subtrajectory_losses(traj_stats, F_traj, last_idx):
        """Per-sample lambda-weighted (lambda = 0.9) mean squared SubTB residual.

        For every position pair ``i < j`` the residual is
        ``sum_{k=i}^{j-1} (pf_k - pb_k) + F_i - F_j``, weighted by ``0.9 ** (j - i)``.
        Pairs whose end ``j`` lies past the sample's ``last_idx`` are excluded, because
        ``F_traj`` there is zero padding rather than a flow or log-reward value.

        Returns:
            (batch,) tensor: weighted mean of squared residuals over each sample's valid pairs.
        """
        num_steps = traj_stats.shape[-1]
        device = traj_stats.device
        i, j = torch.triu_indices(num_steps + 1, num_steps + 1, offset=1, device=device)

        # Prefix sums with a leading zero column, so cum[..., j] - cum[..., i] = sum_{k=i}^{j-1}.
        cum = torch.cumsum(traj_stats, dim=-1)
        zero_col = torch.zeros((*cum.shape[:-1], 1), dtype=cum.dtype, device=device)
        cum = torch.cat([zero_col, cum], dim=-1)

        pf = cum[0, :, j] - cum[0, :, i]
        pb = cum[1, :, j] - cum[1, :, i]
        residual = pf - pb + F_traj[:, i] - F_traj[:, j]

        lamb = .9 ** (j - i).view(1, -1)
        valid = j.view(1, -1) <= last_idx.view(-1, 1)
        weights = lamb * valid.to(lamb.dtype)
        sq = torch.where(valid, residual * residual, torch.zeros_like(residual))
        norm = weights.sum(dim=1).clamp_min(torch.finfo(weights.dtype).tiny)
        return (sq * weights).sum(dim=1) / norm

    def _subtrajectory_balance(self, batch_state, traj_stats, F_traj, last_idx):
        """Huber penalty over the per-sample SubTB losses from ``_subtrajectory_losses``."""
        loss = self._subtrajectory_losses(traj_stats, F_traj, last_idx)
        return huber_loss(loss, delta=1.)

    def _kl_div(self, batch_state, traj_stats, F_traj, last_idx):
        """Per-sample sum(pf - pb) - log_reward (not reduced over the batch)."""
        div = (traj_stats[0] - traj_stats[1]).sum(dim=1) - F_traj[batch_state.batch_ids, last_idx]
        return div

    def _rev_kl_div(self, batch_state, traj_stats, F_traj, last_idx):
        """Per-sample sum(pb - pf) + log_reward (the negation of ``_kl_div``'s value)."""
        # KL(p_{B} || p_{F}) = E [ p_{B} / p_{F} \log p_{B} / p_{F} ]
        # exp. wrt p_{F} (note the use of importance weights)
        div = (traj_stats[1] - traj_stats[0]).sum(dim=1) + F_traj[batch_state.batch_ids, last_idx]
        return div

    def _jeffrey_div(self, batch_state, traj_stats, F_traj, last_idx):
        """Sum of the ``_kl_div`` and ``_rev_kl_div`` per-sample values."""
        return self._kl_div(batch_state, traj_stats, F_traj, last_idx) + \
            self._rev_kl_div(batch_state, traj_stats, F_traj, last_idx)

    def _renyi_div(self, batch_state, traj_stats, F_traj, last_idx):
        """Per-sample (alpha - 1) * (sum(pf - pb) - log_reward)."""
        # R(p_{F} || p_{B}) = 1 / (\alpha - 1) \log E [ (p_{F} / p_{B})^{\alpha - 1} ]
        # exp. wrt p_{F}
        pf_pb = (traj_stats[0] - traj_stats[1]).sum(dim=1) - F_traj[batch_state.batch_ids, last_idx]
        return (self._alpha - 1) * pf_pb

    def _tsallis_div(self, batch_state, traj_stats, F_traj, last_idx):
        """Per-sample (alpha - 1) * (sum(pf - pb) - log_reward); same value as ``_renyi_div``."""
        # T(p_{F} || p_{B}) = 1 / (\alpha - 1) (E [ (p_{F} / p_{B})^{\alpha - 1} ] - 1)
        # exp. wrt p_{F}
        div = (traj_stats[0] - traj_stats[1]).sum(dim=1) - F_traj[batch_state.batch_ids, last_idx]
        div = (self._alpha - 1) * div
        return div

    def _sample_traj_comp(self, batch_state):
        """Roll out ``pf`` (with ``return_ps=True``) and record per-step statistics.

        Returns:
            pf_traj, pb_traj: (batch, max_len) per-step ``pf`` / ``pb`` outputs.
            ps_traj: (batch, max_len + 1). Column k holds ``pf``'s ``ps`` output for step k;
                the final column is never written.
            R_traj: (batch, max_len + 1) ``batch_state.log_reward()`` of the initial state
                (column 0) and of each state reached afterwards.
        """
        pf_traj = torch.zeros((batch_state.batch_size, batch_state.max_trajectory_length), device=self.device)
        pb_traj = torch.zeros((batch_state.batch_size, batch_state.max_trajectory_length), device=self.device)
        ps_traj = torch.zeros((batch_state.batch_size, batch_state.max_trajectory_length + 1), device=self.device)
        R_traj = torch.zeros((batch_state.batch_size, batch_state.max_trajectory_length + 1), device=self.device)
        idx = 0
        R_traj[:, idx] = batch_state.log_reward()
        is_stopped = batch_state.stopped == 1.
        while not is_stopped.all():
            actions, pf, ps = self.pf(batch_state, return_ps=True)
            batch_state.apply(actions)
            _, pb = self.pb(batch_state, actions)
            pf_traj[~is_stopped, idx] = pf[~is_stopped]
            pb_traj[~is_stopped, idx] = pb[~is_stopped]
            ps_traj[~is_stopped, idx] = ps[~is_stopped]
            idx += 1
            R_traj[~is_stopped, idx] = batch_state.log_reward()[~is_stopped]
            is_stopped = (batch_state.stopped == 1.)
        return pf_traj, pb_traj, ps_traj, R_traj

    def _detailed_balance_comp(self, batch_state):
        """Squared DB-style violation that uses per-state log-rewards and ``ps`` terms.

        Returns:
            ``(mean over batch of summed squared violations, stacked (pf, pb) of shape
            (2, batch, max_len))``.
        """
        pf, pb, ps, R_traj = self._sample_traj_comp(batch_state)
        violation = ((R_traj[:, :-1] + pf - ps[:, :-1]) - (pb + R_traj[:, 1:] - ps[:, 1:]))
        loss = (violation * violation).sum(dim=1)
        return loss.mean(), torch.cat([pf.unsqueeze(0), pb.unsqueeze(0)], dim=0)

    @torch.no_grad()
    def sample(self, batch_state, seed=None):
        """Apply ``pf``'s actions until every sample has stopped; returns the mutated state.

        When ``seed`` is given, ``pf.set_seed(seed)`` / ``pf.unset_seed()`` wrap every step.
        """
        while (batch_state.stopped < 1).any():
            if seed is not None: self.pf.set_seed(seed)
            out = self.pf(batch_state)
            actions = out[0]
            batch_state.apply(actions)
            if seed is not None: self.pf.unset_seed()
        return batch_state

    @torch.no_grad()
    def marginal_prob(self, batch_state, copy_env=False):
        """Walk one backward trajectory with ``pb`` and return per-sample sum(pf - pb).

        Used as an importance-sampling estimate of the marginal (log-)probability of
        the states in ``batch_state``.

        Args:
            batch_state: state batch to start from. It is mutated via ``backward``
                unless ``copy_env`` is True.
            copy_env: if True, operate on a ``deepcopy`` of ``batch_state``.

        Returns:
            (batch,) tensor.
        """
        # Use importance sampling to estimate the marginal probabilities
        if copy_env:
            batch_state = deepcopy(batch_state)
        forward_log_traj = torch.zeros((batch_state.batch_size, batch_state.max_trajectory_length), device=self.device)
        backward_log_traj = torch.zeros((batch_state.batch_size, batch_state.max_trajectory_length), device=self.device)

        idx = 0

        is_initial = torch.zeros(batch_state.batch_size, device=self.device, dtype=bool)
        while not is_initial.all():
            # Estimate the backward probabilities
            back_out = self.pb(batch_state)
            actions, backward_log_prob = back_out[0], back_out[1]

            forward_actions = batch_state.backward(actions)

            # Estimate the forward probabilities
            forward_out = self.pf(batch_state, actions=forward_actions)
            forward_log_prob = forward_out[1]

            forward_log_traj[~is_initial, idx] = forward_log_prob[~is_initial]
            backward_log_traj[~is_initial, idx] = backward_log_prob[~is_initial]

            is_initial = batch_state.is_initial.bool()
            idx += 1

        marginal_log = (forward_log_traj - backward_log_traj).sum(dim=1)
        return marginal_log

    def sample_many_backward(self, batch_states, num_trajectories):
        """Return a (batch, num_trajectories) tensor of independent ``marginal_prob`` estimates.

        Each estimate is computed on a deep copy, so ``batch_states`` is not mutated.
        """
        marginal_log = torch.zeros((batch_states.batch_size, num_trajectories), device=self.device)
        for idx in range(num_trajectories):
            marginal_log[:, idx] = self.marginal_prob(batch_states, copy_env=True)
        return marginal_log

    class _EpsilonCtx:
        """Temporarily set ``gflownet.pf.eps`` to ``_eps`` and restore the old value on exit."""

        _eps = None  # value installed on ``pf.eps`` inside the ``with`` block; set by subclasses

        def __init__(self, gflownet):
            self.gflownet = gflownet

        def __enter__(self):
            self.curr_eps = self.gflownet.pf.eps
            self.gflownet.pf.eps = self._eps
            return self

        def __exit__(self, *unused_args):
            # Returns None, so exceptions raised inside the block propagate.
            self.gflownet.pf.eps = self.curr_eps

    class OffPolicyCtx(_EpsilonCtx):
        """Set ``pf.eps`` to 1 for the duration of the ``with`` block."""

        _eps = 1.

    class OnPolicyCtx(_EpsilonCtx):
        """Set ``pf.eps`` to 0 for the duration of the ``with`` block."""

        _eps = 0.

    def off_policy(self):
        """Context manager setting ``pf.eps = 1.`` (restored on exit)."""
        return self.OffPolicyCtx(self)

    def on_policy(self):
        """Context manager setting ``pf.eps = 0.`` (restored on exit)."""
        return self.OnPolicyCtx(self)
