import jax.numpy as jnp
import flax.linen as nn
from jax.lax import cond, stop_gradient
import jax.scipy.linalg as jla




class DP(nn.Module):
    """Linear predictor ``x @ Wp`` with ``Wp`` built from a running feature correlation.

    Presumably the DirectPred predictor (the state collection is named
    ``'direct_pred'``). Nothing here is learned by gradient: during training a
    running correlation of the inputs is maintained, and every
    ``update_freq`` calls ``Wp`` is recomputed as ``U diag(p) U^T``. ``U`` and
    the eigenvalues come from the running correlation, and ``p`` is the
    processed eigenvalues raised to the power ``alpha``.

    All state lives in the mutable ``'direct_pred'`` collection:
    ``run_corr``, ``Wp``, ``tick``, ``ev`` and ``U``.
    """

    # Recompute Wp once every `update_freq` training calls (1 = every call).
    update_freq: int = 1
    # Exponent applied to the (processed) eigenvalues; must be >= 0.
    alpha: float = 0.5
    # EMA weight on the previous value: run_corr <- tau*run_corr + (1-tau)*corr.
    tau: float = 0.3
    # If True, eigenvalues are divided by their maximum before the power.
    normalize: bool = True
    # Eigenvalue threshold used when pc_cutoff_method == 'value'.
    pc_thresh: float = 0.0
    pc_cutoff_method: str = 'dim'  # or 'value'
    # Number of leading components kept for the 'dim' cutoff; <= 0 keeps all.
    pc_num_components: int = -1
    # Lower clip applied to the eigenvalues before any further processing.
    eigval_floor: float = 0.0

    @nn.compact
    def __call__(self, x, train=True, is_target_net=False):
        """Apply ``x @ Wp`` and, in training mode, update the predictor state.

        Args:
          x: input features whose last axis is the feature dimension. Any
            leading axes are treated as samples when estimating the correlation.
          train: when True (and ``is_target_net`` is False), update the running
            correlation and periodically recompute ``Wp``.
          is_target_net: when True, the state is left untouched.

        Returns:
          ``jnp.matmul(x, Wp)`` using the current (possibly just refreshed) ``Wp``.
        """
        dim = x.shape[-1]

        run_corr = self.variable('direct_pred', 'run_corr',
                                 lambda s: jnp.zeros(s, jnp.float32), (dim, dim))

        Wp = self.variable('direct_pred', 'Wp',
                           lambda s: jnp.eye(s, dtype=jnp.float32), dim)

        tick = self.variable('direct_pred', 'tick', lambda a: a, 0)

        # Holds the gains `p` returned by `_update_Wp` (eigenvalues after
        # clipping/normalization/cutoff and the `alpha` power), not the raw spectrum.
        ev = self.variable('direct_pred', 'ev',
                           lambda s: jnp.ones(s, jnp.float32), dim)

        U = self.variable('direct_pred', 'U',
                          lambda s: jnp.eye(s, dtype=jnp.float32), dim)

        if train and not is_target_net:
            # Flatten leading axes so rank > 2 inputs still give a (dim, dim) matrix;
            # for a rank-2 (batch, dim) input this is a no-op.
            samples = x.reshape(-1, dim)
            corr = stop_gradient(jnp.matmul(samples.T, samples) / samples.shape[0])

            if not self.is_initializing():
                # Keep the stored dtype fixed so both `cond` branches below agree.
                run_corr.value = (
                    self.tau * run_corr.value + (1 - self.tau) * corr
                ).astype(run_corr.value.dtype)
                tick.value = (tick.value + 1) % self.update_freq
                # Recompute (U, ev, Wp) only when the tick wraps to 0; otherwise
                # pass the current state through unchanged.
                U.value, ev.value, Wp.value = cond(
                    tick.value == 0,
                    lambda state: self._update_Wp(run_corr.value),
                    lambda state: state,
                    (U.value, ev.value, Wp.value))

        return jnp.matmul(x, Wp.value)

    def _update_Wp(self, corr):
        """Build the predictor weights from a symmetric correlation matrix.

        Args:
          corr: symmetric ``(dim, dim)`` correlation estimate.

        Returns:
          A tuple ``(U, p, Wp)``:
          - ``U`` holds the eigenvectors (as columns), sorted by descending
            processed eigenvalue.
          - ``p`` holds the per-component gains: processed eigenvalues raised to
            ``alpha``, and 0 for dropped components.
          - ``Wp`` equals ``U @ diag(p) @ U.T``.

        Raises:
          ValueError: if ``alpha`` is negative or ``pc_cutoff_method`` is not
            ``'dim'`` or ``'value'``.
        """
        # Configuration checks run at trace time; use exceptions rather than
        # `assert`, which is stripped under `python -O`.
        if self.alpha < 0.0:
            raise ValueError(f'alpha must be non-negative, got {self.alpha}')
        if self.pc_cutoff_method not in ('dim', 'value'):
            raise ValueError(
                f"pc_cutoff_method must be 'dim' or 'value', "
                f"got {self.pc_cutoff_method!r}")

        # compute the eigen decomposition of the correlation matrix
        s, U = jnp.linalg.eigh(corr)
        s = jnp.real(s)
        U = jnp.real(U)

        # Positional clip bounds work with both the old (a_min/a_max) and the
        # new (min/max) jnp.clip signatures.
        if self.normalize:
            s = jnp.clip(s, self.eigval_floor)
            s_max = jnp.max(s)
            # Guard against a zero/negative maximum (e.g. an all-zero batch),
            # which would otherwise produce NaN or sign-flipped eigenvalues.
            s = s / jnp.where(s_max > 0, s_max, 1.0)
        else:
            s = jnp.nan_to_num(jnp.clip(s, self.eigval_floor, 1e6))

        # sort the eigenvalues and eigenvectors in descending order
        order = jnp.argsort(s)[::-1]
        s = s[order]
        U = U[:, order]

        # drop components according to the configured cutoff
        if self.pc_cutoff_method == 'dim':
            dim = s.shape[0]
            eff_dim = self.pc_num_components if self.pc_num_components > 0 else dim
            s = jnp.where(jnp.arange(dim) < eff_dim, s, 0.)
        else:  # 'value'
            s = jnp.where(s > self.pc_thresh, s, 0.)

        p = jnp.where(jnp.not_equal(s, 0.), jnp.power(s, self.alpha), 0.)

        # U @ diag(p) @ U.T without materializing diag(p): scale U's columns by p.
        Wp = jnp.matmul(U * p, U.T)

        return U, p, Wp






