"""
Model classes for SALT.
"""

import jax.numpy as np
import jax.random as jr
from jax import vmap, lax, jit
import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions
from jax.config import config
config.update("jax_enable_x64", True)

from jax.tree_util import register_pytree_node_class

from ssm.arhmm.base import AutoregressiveHMM
from ssm.hmm.initial import StandardInitialCondition
from ssm.hmm.transitions import StationaryTransitions
from ssm.arhmm.emissions import AutoregressiveEmissions

from salt.emissions import SALTEmissions

# Tensor-decomposition variants accepted by ``SALT(mode=...)``; the mode is lower-cased before lookup.
supported_modes = ["cp", "tucker"]

@register_pytree_node_class
class SALT(AutoregressiveHMM):
    """Switching Autoregressive Low-rank Tensor (SALT) model.

    An ``AutoregressiveHMM`` assembled from a ``StandardInitialCondition``,
    ``StationaryTransitions`` and ``SALTEmissions`` whose per-state
    autoregressive parameters are given as output/input/lag factors plus a
    core tensor (CP or Tucker mode).
    """

    def __init__(self,
                 num_states: int,
                 num_emission_dims: int=1,
                 num_lags: int=1,
                 core_tensor_dims: tuple=(1, 1, 1), # output, input, lag
                 initial_state_probs: np.ndarray=None,
                 transition_matrix: np.ndarray=None,
                 emission_output_factors: np.ndarray=None,
                 emission_input_factors: np.ndarray=None,
                 emission_lag_factors: np.ndarray=None,
                 emission_core_tensors: np.ndarray=None,
                 emission_biases: np.ndarray=None,
                 emission_covariance_matrix_sqrts: np.ndarray=None,
                 seed: jr.PRNGKey=None,
                 mode: str='cp',
                 l2_penalty: float=1e-4,
                 dtype=np.float64,
                 sigma: float=1.0):
        """Switching Autoregressive Low-rank Tensor Model (SALT).

        Any parameter left as ``None`` is filled with a default. Sampled
        defaults are drawn from Normal(0, sigma) and cast to ``dtype``.
        User-supplied arrays are passed through unchanged.

        Args:
            num_states (int): number of discrete latent states
            num_emission_dims (int, optional): number of emission dims.
                Defaults to 1.
            num_lags (int, optional): number of previous timesteps on which to autoregress.
                Only used to shape sampled lag factors; ignored when
                ``emission_lag_factors`` is given. Defaults to 1.
            core_tensor_dims (tuple, optional): ranks of SALT for (output, input, lag) factors.
                Any length-3 sequence is accepted. In 'cp' mode all three must be equal.
                Defaults to (1, 1, 1).
            initial_state_probs (np.ndarray, optional): initial state probabilities.
                Defaults to uniform over ``num_states``.
            transition_matrix (np.ndarray, optional): transition matrix.
                Defaults to a uniform (num_states, num_states) matrix.
            emission_output_factors (np.ndarray, optional): emission output factors.
                Sampled with shape (num_states, num_emission_dims, core_tensor_dims[0]) if None.
            emission_input_factors (np.ndarray, optional): emission input factors.
                Sampled with shape (num_states, num_emission_dims, core_tensor_dims[1]) if None.
            emission_lag_factors (np.ndarray, optional): emission lag factors.
                Sampled with shape (num_states, num_lags, core_tensor_dims[2]) if None.
            emission_core_tensors (np.ndarray, optional): emission core tensors with shape
                (num_states, *core_tensor_dims). If None: sampled in 'tucker' mode. In
                'cp' mode it is set to ones on the superdiagonal [i, i, i] and zeros elsewhere.
            emission_biases (np.ndarray, optional): emission biases.
                Sampled with shape (num_states, num_emission_dims) if None.
            emission_covariance_matrix_sqrts (np.ndarray, optional): square root of emission
                covariances. Defaults to identity matrices, shape
                (num_states, num_emission_dims, num_emission_dims).
            seed (jr.PRNGKey, optional): random key used to sample any missing parameters.
                Defaults to None, which falls back to ``jr.PRNGKey(0)``.
            mode (str, optional): ['cp', 'tucker']-SALT (case-insensitive).
                Defaults to 'cp'.
            l2_penalty (float, optional): non-negative l2 regularization strength.
                Defaults to 1e-4.
            dtype (optional): precision of generated default parameters.
                Defaults to np.float64.
            sigma (float, optional): scale for sampling SALT parameters.
                Defaults to 1.0.

        Raises:
            ValueError: if ``l2_penalty`` is negative, ``mode`` is unsupported,
                ``core_tensor_dims`` does not have 3 entries, or the 'cp' ranks differ.
        """
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if l2_penalty < 0:
            raise ValueError(
                f"Invalid penalty: 'l2_penalty' must be non-negative, got {l2_penalty}"
            )

        mode = mode.lower()
        if mode not in supported_modes:
            raise ValueError(
                f"'mode' should be from {supported_modes}"
            )

        # Normalize to a tuple so lists are accepted. The core-tensor shape
        # below is built by tuple concatenation.
        core_tensor_dims = tuple(core_tensor_dims)
        if len(core_tensor_dims) != 3:
            raise ValueError(
                f"'core_tensor_dims' should have 3 entries (output, input, lag), got {core_tensor_dims}"
            )
        if mode == "cp":
            if not (core_tensor_dims[0] == core_tensor_dims[1] == core_tensor_dims[2]):
                raise ValueError(
                    f"'core_tensor_dims' should have same dimensions for mode {mode}"
                )

        # Without a key, ``jr.split(None, ...)`` would fail as soon as any
        # parameter needs to be sampled. Use a fixed, reproducible default key.
        if seed is None:
            seed = jr.PRNGKey(0)

        def _sample_normal(key, shape):
            """Draw Normal(0, sigma) samples of ``shape`` cast to ``dtype``.

            Returns the samples and the key to use for the next draw.
            """
            sample_key, next_key = jr.split(key, 2)
            samples = tfd.Normal(0, sigma).sample(seed=sample_key, sample_shape=shape)
            return samples.astype(dtype), next_key

        if initial_state_probs is None:
            initial_state_probs = np.ones(num_states).astype(dtype) / num_states

        if transition_matrix is None:
            transition_matrix = np.ones((num_states, num_states)).astype(dtype) / num_states

        # Keys are consumed in a fixed order (output, input, lag, core, biases),
        # so a given seed always yields the same parameters.
        if emission_output_factors is None:
            emission_output_factors, seed = _sample_normal(
                seed, (num_states, num_emission_dims, core_tensor_dims[0]))

        if emission_input_factors is None:
            emission_input_factors, seed = _sample_normal(
                seed, (num_states, num_emission_dims, core_tensor_dims[1]))

        if emission_lag_factors is None:
            emission_lag_factors, seed = _sample_normal(
                seed, (num_states, num_lags, core_tensor_dims[2]))

        if emission_core_tensors is None:
            if mode == 'tucker':
                emission_core_tensors, seed = _sample_normal(
                    seed, (num_states,) + core_tensor_dims)
            elif mode == 'cp':
                # Fixed core: ones on the superdiagonal [i, i, i], zeros elsewhere.
                # All three ranks are equal in cp mode, as validated above.
                idx = np.arange(core_tensor_dims[0])
                emission_core_tensors = np.zeros((num_states,) + core_tensor_dims).astype(dtype)
                emission_core_tensors = emission_core_tensors.at[:, idx, idx, idx].set(1)

        if emission_biases is None:
            emission_biases, seed = _sample_normal(
                seed, (num_states, num_emission_dims))

        if emission_covariance_matrix_sqrts is None:
            emission_covariance_matrix_sqrts = np.tile(
                np.eye(num_emission_dims), (num_states, 1, 1)).astype(dtype)

        initial_condition = StandardInitialCondition(num_states, initial_probs=initial_state_probs)
        transitions = StationaryTransitions(num_states, transition_matrix=transition_matrix)
        emissions = SALTEmissions(num_states,
                                  mode,
                                  l2_penalty=l2_penalty,
                                  input_factors=emission_input_factors,
                                  output_factors=emission_output_factors,
                                  lag_factors=emission_lag_factors,
                                  core_tensors=emission_core_tensors,
                                  biases=emission_biases,
                                  covariance_matrix_sqrts=emission_covariance_matrix_sqrts)
        super(SALT, self).__init__(num_states,
                                   initial_condition,
                                   transitions,
                                   emissions)

    @property
    def num_lags(self):
        """Number of autoregressive lags, as reported by the emission model."""
        return self._emissions.num_lags

    def tree_flatten(self):
        """Split the model into pytree children and static auxiliary data.

        Returns:
            children: (initial condition, transitions, emissions).
            aux_data: the number of discrete states.
        """
        children = (self._initial_condition,
                    self._transitions,
                    self._emissions)
        aux_data = self._num_states
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild a model from ``tree_flatten`` output without re-running ``SALT.__init__``."""
        obj = object.__new__(cls)
        # Skip SALT.__init__ (which validates inputs and samples parameters) and
        # call the base-class constructor with the stored components. Naming SALT
        # explicitly instead of ``cls`` keeps this correct for subclasses: for them,
        # super(cls, obj) would resolve to SALT.__init__ and misread the children.
        super(SALT, obj).__init__(aux_data, *children)
        return obj