import torch 
import torch.autograd as autograd 
import numpy as np 

def generic_control_variate(loss_func, score_func1, model, weights=None, score_func2=None):
    r'''
    Compute the autograd pieces of a variance-reduced estimate of \nabla E[f(\tau)].

    The gradient is estimated as

        E[\nabla f(\tau) + f(\tau) \nabla \log p_F(\tau)],

    where f(\tau) is ``loss_func`` and \log p_F is the score function. All
    gradients are taken with PyTorch autograd with respect to
    ``model.parameters()``.

    For the first (pathwise) term, \nabla score_func1 is returned so the caller
    can use it as a control variate and estimate instead

        E[\nabla f(\tau) - a \nabla score_func1],

    with ``a`` chosen to minimise the variance of the resulting estimator.
    Estimating ``a`` from the same batch introduces a small bias. Empirically it
    is mostly negligible: for Renyi- and Tsallis-like divergences the estimator
    is already biased, and SGD is inherently noisy.

    For the second (score-function) term, a leave-one-out baseline is subtracted
    from ``loss_func``. Neither technique depends non-linearly on the gradients
    of the computed functions, so both can be evaluated efficiently with
    autograd.

    Args:
        loss_func: per-sample values f(\tau); dim 0 is the batch.
        score_func1: per-sample quantity whose gradient is the control variate.
        model: module whose parameters are differentiated.
        weights: optional per-sample weights (used detached); defaults to ones.
        score_func2: per-sample score used in the REINFORCE term; defaults to
            ``score_func1``.

    Returns:
        ``(grad_loss, grad_log_prob, grad_reinforce, params)``. The first three
        are tuples of per-parameter gradients of, respectively,
        ``mean(weights * loss_func)``, ``mean(score_func1)`` and
        ``mean(baselined_loss * weights * score_func2)``. Entries are ``None``
        for parameters a quantity does not depend on. ``params`` is
        ``model.parameters()``, in the same order.
    '''
    if weights is None:
        weights = torch.ones_like(loss_func)
    if score_func2 is None:
        score_func2 = score_func1
    weights = weights.detach()

    # First term: pathwise gradient and the control-variate direction.
    grad_loss = autograd.grad(
        (weights * loss_func).mean(), model.parameters(),
        retain_graph=True, allow_unused=True,
    )
    grad_log_prob = autograd.grad(
        score_func1.mean(), model.parameters(),
        retain_graph=True, allow_unused=True,
    )

    # Second term: REINFORCE with a leave-one-out baseline. The baseline is
    # the mean of the *other* samples, computed in O(B) as (sum - x_i) / (B - 1).
    values = loss_func.detach()
    batch_size = values.shape[0]
    if batch_size > 1:
        baseline = (values.sum(dim=0, keepdim=True) - values) / (batch_size - 1)
        est = values - baseline
    else:
        # With a single sample the leave-one-out baseline is undefined (it
        # would divide by zero and produce NaN); fall back to no baseline.
        est = values
    grad_reinforce = autograd.grad(
        (est * (weights * score_func2)).mean(), model.parameters(),
        allow_unused=True,
    )

    return grad_loss, grad_log_prob, grad_reinforce, model.parameters()

def gradient_renyi_div(gflownet, loss, traj_stats):
    r'''
    Set ``p.grad`` on ``gflownet.pf`` for a Renyi-divergence-style objective.

    With f = exp(loss), the assigned gradient is

        \nabla E[f] / ((alpha - 1) E[f]),

    i.e. the gradient of log E[f] / (alpha - 1), negated when alpha < 0.
    \nabla E[f] is estimated as E[\nabla f + f \nabla \log p_F], with
    ``traj_stats[0].sum(dim=1)`` as the score (presumably the per-trajectory
    log p_F; TODO confirm). When ``gflownet.use_cv`` is set,
    ``generic_control_variate`` supplies a control variate for the pathwise term
    and a leave-one-out baseline for the score term.

    Args:
        gflownet: object exposing ``use_cv``, ``pf`` (a module) and ``_alpha``.
        loss: per-sample log-values (presumably shape (batch,); TODO confirm).
        traj_stats: sequence whose element 0 is summed over dim 1 to give the
            score.
    '''
    # Shift by the detached max for numerical stability. The factor exp(-m)
    # cancels between \nabla E[f] and E[f].
    m = loss.max().detach()
    loss_func = (loss - m).exp()
    # Detached so that the assigned gradients carry no autograd history.
    denom = loss_func.mean().detach() * (gflownet._alpha - 1)
    if gflownet._alpha < 0:
        denom = -denom

    if gflownet.use_cv:
        grad_loss, grad_log_prob, grad_reinforce, parameters = generic_control_variate(
            loss_func, traj_stats[0].sum(dim=1), gflownet.pf
        )
        for p, gl, glp, gr in zip(parameters, grad_loss, grad_log_prob, grad_reinforce):
            if gl is None and gr is None:
                continue  # parameter does not influence this objective
            if gl is not None and glp is not None:
                # Variance-minimising coefficient. The clamp keeps a = 0 (not
                # 0/0 = NaN) when the control variate vanishes for this tensor.
                a = (gl * glp).sum() / (glp * glp).sum().clamp_min(torch.finfo(glp.dtype).tiny)
                gl = gl - a * glp
            p.grad = sum(g for g in (gl, gr) if g is not None) / denom
    else:
        ngrads = autograd.grad(
            (loss_func + loss_func.detach() * traj_stats[0].sum(dim=1)).mean(),
            gflownet.pf.parameters(),
            allow_unused=True,
        )
        for p, ngrad in zip(gflownet.pf.parameters(), ngrads):
            # allow_unused=True yields None for parameters outside the graph.
            p.grad = None if ngrad is None else ngrad / denom

def gradient_tsallis_div(gflownet, loss, traj_stats):
    r'''
    Set ``p.grad`` on ``gflownet.pf`` for a Tsallis-divergence-style objective.

    With f = exp(loss), the assigned gradient is \nabla E[f] / (alpha - 1),
    negated when alpha < 0. \nabla E[f] is estimated as
    E[\nabla f + f \nabla \log p_F], with ``traj_stats[0].sum(dim=1)`` as the
    score (presumably the per-trajectory log p_F; TODO confirm). When
    ``gflownet.use_cv`` is set, ``generic_control_variate`` supplies a control
    variate for the pathwise term and a leave-one-out baseline for the score
    term.

    Args:
        gflownet: object exposing ``use_cv``, ``pf`` (a module) and ``_alpha``.
        loss: per-sample log-values (presumably shape (batch,); TODO confirm).
        traj_stats: sequence whose element 0 is summed over dim 1 to give the
            score.
    '''
    # No max-shift here, unlike the Renyi case. The objective is linear in
    # E[exp(loss)], so rescaling exp(loss) would rescale the gradient instead
    # of cancelling out.
    loss_func = loss.exp()
    denom = gflownet._alpha - 1
    if gflownet._alpha < 0:
        denom = -denom

    if gflownet.use_cv:
        grad_loss, grad_log_prob, grad_reinforce, parameters = generic_control_variate(
            loss_func, traj_stats[0].sum(dim=1), gflownet.pf
        )
        for p, gl, glp, gr in zip(parameters, grad_loss, grad_log_prob, grad_reinforce):
            if gl is None and gr is None:
                continue  # parameter does not influence this objective
            if gl is not None and glp is not None:
                # Variance-minimising coefficient. The clamp keeps a = 0 (not
                # 0/0 = NaN) when the control variate vanishes for this tensor.
                a = (gl * glp).sum() / (glp * glp).sum().clamp_min(torch.finfo(glp.dtype).tiny)
                gl = gl - a * glp
            p.grad = sum(g for g in (gl, gr) if g is not None) / denom
    else:
        grads = autograd.grad(
            (loss_func + loss_func.detach() * traj_stats[0].sum(dim=1)).mean(),
            gflownet.pf.parameters(),
            allow_unused=True,
        )
        for p, grad in zip(gflownet.pf.parameters(), grads):
            # allow_unused=True yields None for parameters outside the graph.
            p.grad = None if grad is None else grad / denom
            
def rev_kl_cv_rloo(gflownet, loss, traj_stats):
    r'''
    Reverse-KL gradient with a leave-one-out (RLOO) baseline on the sample weights.

    The quantity of interest is - E[ p_B / p_F \nabla \log p_F ], which can have
    high variance. As a first approximation, the self-normalised weights
    softmax(loss + sum(traj_stats[0] - traj_stats[2], dim=1)) are centred by
    subtracting, for each sample, the mean weight of the other samples.
    ``p.grad`` is then set to the gradient of sum(centred_weight * loss) for
    every parameter of ``gflownet.pf`` (``None`` where unused).

    NOTE(review): how this surrogate relates to the stated target depends on how
    the caller defines ``loss``; TODO confirm.
    '''
    log_weights = loss + (traj_stats[0] - traj_stats[2]).sum(dim=1)
    weights = (log_weights - torch.logsumexp(log_weights, dim=0)).exp().detach()

    num_samples = loss.shape[0]
    if num_samples > 1:
        # Equivalent to (I - (1 - I) / (n - 1)) @ weights, computed in O(n).
        baseline = (weights.sum(dim=0, keepdim=True) - weights) / (num_samples - 1)
        centred = weights - baseline
    else:
        # A leave-one-out baseline is undefined for one sample (division by
        # zero would yield NaN); fall back to the raw weights.
        centred = weights

    grad = autograd.grad((centred * loss).sum(), gflownet.pf.parameters(), allow_unused=True)

    for p, g in zip(gflownet.pf.parameters(), grad):
        p.grad = g

def rev_kl_cv_reinforce(gflownet, loss, traj_stats):
    r'''
    Reverse-KL gradient with a REINFORCE-style control variate.

    Targets the gradient of -sum_i w_i * traj_stats[0][i].sum(), where w are the
    detached, self-normalised weights
    softmax(loss + sum(traj_stats[0] - traj_stats[2], dim=1)). The gradient of
    mean(traj_stats[2].sum(dim=1)) serves as the control variate. It is scaled
    per parameter tensor by the variance-minimising coefficient
    a = <cv, tq> / <cv, cv>.

    NOTE(review): the control variate is presumably zero-mean because
    traj_stats[2] is a log-probability of the sampling distribution; TODO confirm.

    Parameters whose target gradient is unavailable are left untouched. When only
    the control variate is missing, the uncorrected estimate is used.
    '''
    log_weights = loss + (traj_stats[0] - traj_stats[2]).sum(dim=1)
    weights = (log_weights - torch.logsumexp(log_weights, dim=0)).exp().detach()

    control_variate = autograd.grad(
        traj_stats[2].sum(dim=1).mean(), gflownet.pf.parameters(),
        allow_unused=True, retain_graph=True,
    )
    target_quantity = autograd.grad(
        (weights * traj_stats[0].sum(dim=1)).sum(),
        gflownet.pf.parameters(), allow_unused=True,
    )

    for p, cv, tq in zip(gflownet.pf.parameters(), control_variate, target_quantity):
        if tq is None:
            continue  # parameter does not affect the target quantity
        if cv is not None:
            # The clamp keeps a = 0 (not 0/0 = NaN) when cv vanishes.
            a = (cv * tq).sum() / (cv * cv).sum().clamp_min(torch.finfo(cv.dtype).tiny)
            tq = tq - a * cv
        p.grad = -tq

def gradient_rev_kl_div(gflownet, loss, traj_stats):
    '''
    Set ``p.grad`` on ``gflownet.pf`` for the reverse-KL objective.

    With ``gflownet.use_cv`` the work is delegated to ``rev_kl_cv_reinforce``.
    Otherwise every parameter receives the gradient of
    -sum_i w_i * traj_stats[0][i].sum(), where the detached weights are
    w = softmax(loss + sum(traj_stats[0] - traj_stats[2], dim=1)).
    Parameters outside the graph get ``None``.
    '''
    if gflownet.use_cv:
        # rev_kl_cv_rloo(gflownet, loss, traj_stats) is the RLOO alternative.
        rev_kl_cv_reinforce(gflownet, loss, traj_stats)
        return

    params = list(gflownet.pf.parameters())
    log_w = loss + (traj_stats[0] - traj_stats[2]).sum(dim=1)
    w = (log_w - torch.logsumexp(log_w, dim=0)).exp().detach()
    surrogate = -(w * traj_stats[0].sum(dim=1)).sum()
    for param, g in zip(params, autograd.grad(surrogate, params, allow_unused=True)):
        param.grad = g

def gradient_kl_div(gflownet, loss, traj_stats):
    r'''
    Set ``p.grad`` on ``gflownet.pf`` for a KL-style objective E[loss].

    Without ``gflownet.use_cv``, the gradient is estimated as
    E[\nabla loss + loss \nabla \log p_F], with ``traj_stats[0].sum(dim=1)`` as
    the score. With ``gflownet.use_cv``, samples are reweighted by
    exp(sum(traj_stats[0] - traj_stats[2], dim=1)) (presumably an off-policy
    importance weight; TODO confirm). ``generic_control_variate`` then supplies a
    control variate from ``traj_stats[2]`` for the pathwise term and a
    leave-one-out baseline for the score term.
    '''
    if gflownet.use_cv:
        opiw = (traj_stats[0] - traj_stats[2]).sum(dim=1).exp()
        grad_loss, grad_log_prob, grad_reinforce, parameters = generic_control_variate(
            loss, score_func1=traj_stats[2].sum(dim=1), model=gflownet.pf,
            score_func2=traj_stats[0].sum(dim=1), weights=opiw,
        )

        for p, gl, glp, gr in zip(parameters, grad_loss, grad_log_prob, grad_reinforce):
            if gl is None and gr is None:
                continue  # parameter does not influence this objective
            if gl is not None and glp is not None:
                # Variance-minimising coefficient. The clamp keeps a = 0 (not
                # 0/0 = NaN) when the control variate vanishes for this tensor.
                a = (gl * glp).sum() / (glp * glp).sum().clamp_min(torch.finfo(glp.dtype).tiny)
                gl = gl - a * glp
            # The score term may exist without the pathwise term (and vice
            # versa), because score_func2 differs from the loss/score_func1.
            p.grad = sum(g for g in (gl, gr) if g is not None)
    else:
        grads = autograd.grad(
            (loss + loss.detach() * traj_stats[0].sum(dim=1)).mean(),
            gflownet.pf.parameters(), allow_unused=True,
        )
        for p, grad in zip(gflownet.pf.parameters(), grads):
            p.grad = grad
