import jax
import jax.numpy as np
from jax import grad, value_and_grad, jit, vmap
from jax import random

from jax.tree_util import tree_map, tree_multimap, tree_flatten, tree_unflatten, tree_reduce

import numpy as onp

def l2_parameter_loss(x):
    """
    Squared l2 norm of a Pytree: the sum of the squared entries of every leaf.

    Starts from 0. and adds each leaf's sum of squares in flattening order,
    so a Pytree without leaves gives 0.
    """
    leaves, _ = tree_flatten(x)
    total = 0.
    for leaf in leaves:
        total = total + (leaf ** 2).sum()
    return total

def parameter_count(x):
    """
    Computes total number of parameters over a Pytree

    Every leaf contributes the product of its shape dimensions; a leaf with
    an empty shape (a scalar) contributes 1. The result is a Python int, and
    a Pytree without leaves gives 0.
    """
    def leaf_size(leaf):
        # Force an integer dtype: onp.prod of an empty shape otherwise returns
        # the float 1.0, which would turn the whole count into a float.
        # onp.shape also accepts leaves that have no .shape attribute
        # (e.g. plain Python numbers).
        return int(onp.prod(onp.shape(leaf), dtype=onp.int64))

    return tree_reduce(lambda total, leaf: total + leaf_size(leaf), x, 0)

def weighted_parameter_loss(params, means, variances, damp=1e-6):
    """
    Uses a quadratic regularizer around the given means with provided diagonal variance

    Computes 0.5 * sum((p - mu)**2 / (var + damp)) over all leaves, where the
    leaves of `params`, `means` and `variances` are paired in flattening
    order. `damp` is added to every variance before dividing.

    Raises ValueError if the three Pytrees do not have the same number of
    leaves.
    """
    flat_params, _ = tree_flatten(params)
    flat_means, _ = tree_flatten(means)
    flat_variances, _ = tree_flatten(variances)
    if not (len(flat_params) == len(flat_means) == len(flat_variances)):
        # zip() would silently drop the unmatched leaves and return a loss
        # that ignores part of the parameters.
        raise ValueError(
            "params, means and variances must have the same number of leaves, got "
            f"{len(flat_params)}, {len(flat_means)} and {len(flat_variances)}"
        )
    return 0.5 * sum(
        np.sum(np.square(p - mu) / (var + damp))
        for p, mu, var in zip(flat_params, flat_means, flat_variances)
    )


def add_l2_reg(loss, l2_weight):
    """
    Wraps a loss function so that an l2 penalty on the parameters is added.

    `loss` is called as loss(params, state, inputs, labels, net_apply) and
    must return a (loss_value, new_state) pair. The returned function has the
    same signature and returns
    (l2_weight * l2_parameter_loss(params) + loss_value, new_state).
    """
    def reg_loss_fn(params, state, inputs, labels, net_apply):
        # Evaluate the wrapped loss first; its state is passed through untouched.
        unreg_loss, new_state = loss(params, state, inputs, labels, net_apply)
        penalty = l2_weight * l2_parameter_loss(params)
        return penalty + unreg_loss, new_state

    return reg_loss_fn
