import flax
import jax
import jax.numpy as jnp
from functools import partial

import tensorflow as tf

from rsm_utils import (
    jax_save,
    jax_load,
    lipschitz_l1_jax,
    martingale_loss,
    triangular,
    IBPMLP,
    MLP,
    create_train_state,
    clip_grad_norm,
    lipschitz_linf_jax,
    get_pmass_grid,
    compute_expected_l,
    jv_contains,
)
import numpy as np

from vppo_jax import vPPO


class RSMLearner:
    """Trains a policy network together with an RSM (martingale certificate) network.

    Holds three flax train states: ``p_state`` (policy), ``c_state`` (value network
    used by the PPO pre-training) and ``v_state`` (RSM). Provides PPO pre-training,
    a jitted per-batch training step, rollout evaluation and save/load helpers.
    """

    def __init__(
        self,
        l_hidden,
        p_hidden,
        env,
        lip_lambda,
        p_lip,
        v_lip,
        eps,
        reach_prob,
        v_activation="relu",
        norm="l1",
        p_lr=0.00005,
        c_lr=0.0005,
        c_ema=0.9,
        v_lr=0.0005,
        v_ema=0.9,
        n_step=1,
    ) -> None:
        """
        Builds the RSM, value and policy networks and their train states.

        :param l_hidden: List of sizes of the hidden layers of the RSM
        :param p_hidden: List of sizes of the hidden layers of the policy
        :param env: The environment
        :param lip_lambda: Regularization factor that is multiplied with the Lipschitz loss
        :param p_lip: Desired maximum value of the Lipschitz bound of the policy
            (i.e., if the Lipschitz bound of the policy is greater than this value the regularization loss will be >0)
        :param v_lip: Desired maximum value of the Lipschitz bound of the RSM
            (i.e., if the Lipschitz bound of the RSM is greater than this value the regularization loss will be >0)
        :param eps: Desired expected decrease used during training
            (i.e., if the expected decrease is less than this value, the loss will be greater than 0)
        :param reach_prob: Desired reach-avoid probability.
            (i.e., if the estimated reach-avoid probability is less than this value, the loss will be greater than 0)
        :param v_activation: Activation function of the RSM neural network (relu or tanh)
        :param norm: "l1" or "linf"
        :param p_lr: Learning rate of the policy
        :param c_lr: Learning rate of the value network of the PPO pre-training
        :param c_ema: Exponential moving average factor of the value network of the PPO pre-training
        :param v_lr: Learning rate of the RSM neural network
        :param v_ema: Exponential moving average factor RSM neural network
        :param n_step: Number of environment-policy steps for which the learner should train
        """
        self.env = env
        self.n_step = n_step
        self.eps = jnp.float32(eps)
        self.reach_prob = jnp.float32(reach_prob)
        assert norm in ["l1", "linf"]
        self.norm = norm
        self.estimate_expected_via_ibp = True
        self.norm_fn = lipschitz_l1_jax if norm == "l1" else lipschitz_linf_jax
        action_dim = self.env.action_space.shape[0]
        obs_dim = self.env.observation_dim
        # Number of sums for the expectation computation; fewer when the
        # observation is not 2-D (presumably to bound the grid size).
        pmass_n = 10 if obs_dim == 2 else 6
        self._cached_pmass_grid = get_pmass_grid(self.env, pmass_n)

        v_net = MLP(l_hidden + [1], activation=v_activation, softplus_output=True)
        c_net = MLP(l_hidden + [1], activation="relu", softplus_output=False)
        p_net = MLP(p_hidden + [action_dim], activation="relu")

        # Interval-bound-propagation twins of the RSM and policy architectures.
        self.v_ibp = IBPMLP(
            l_hidden + [1], activation=v_activation, softplus_output=True
        )
        self.p_ibp = IBPMLP(
            p_hidden + [action_dim], activation="relu", softplus_output=False
        )
        self.v_state = create_train_state(
            v_net, jax.random.PRNGKey(1), obs_dim, v_lr, ema=v_ema
        )
        self.c_state = create_train_state(
            c_net, jax.random.PRNGKey(3), obs_dim, c_lr, ema=c_ema
        )
        self.p_state = create_train_state(
            p_net,
            jax.random.PRNGKey(2),
            obs_dim,
            p_lr,
        )
        self.p_lip = jnp.float32(p_lip)
        self.v_lip = jnp.float32(v_lip)
        self.lip_lambda = jnp.float32(lip_lambda)

        self.rng = jax.random.PRNGKey(777)
        self._debug_init = []
        self._debug_unsafe = []

    def pretrain_policy(
        self,
        num_iters=10,
        std_start=0.3,
        std_end=0.03,
        lip_start=0.0,
        lip_end=0.1,
        normalize_r=False,
        normalize_a=True,
        save_every=None,
        verbose=True,
    ):
        """
        Runs the PPO pre-training

        :param num_iters: Number of PPO iterations
        :param std_start: Standard deviation of the Gaussian distribution of the policy at the beginning of the training
        :param std_end: Standard deviation of the Gaussian distribution of the policy at the end of the training
        :param lip_start: Lipschitz regularization factor at the beginning of the training
        :param lip_end: Lipschitz regularization factor at the end of the training
        :param normalize_r: Flag (True/False) indicating whether to keep a rolling mean/std to normalize the rewards
        :param normalize_a: Flag (True/False) indicating whether to normalize the advantage values before the policy optimization
        :param save_every: If not None, then the PPO will save the weights after every n iterations
        :param verbose: Flag (True/False) indicating whether to print stats during the training process
        :return: None; ``self.p_state`` and ``self.c_state`` are replaced by the PPO-trained states
        """
        ppo = vPPO(
            self.p_state,
            self.c_state,
            self.env,
            self.p_lip,
            norm=self.norm,
            normalize_r=normalize_r,
            normalize_a=normalize_a,
        )
        ppo.run(num_iters, std_start, std_end, lip_start, lip_end, save_every, verbose)

        # Copy from PPO
        self.p_state = ppo.p_state
        self.c_state = ppo.c_state

    def evaluate_rl(self):
        """Rolls out the policy's mean action on 512 parallel environments.

        Accumulates per-rollout reward until the environment reports done, then
        counts how many final observations lie inside any of ``env.target_spaces``.

        :return: ``(text, res_dict)``; ``text`` is the printed summary line and
            ``res_dict`` holds ``mean_r``/``std_r``/``min_r``/``max_r``,
            ``num_end_in_target`` and ``num_traj``.
        """
        n = 512
        rng = jax.random.PRNGKey(2)
        rng, r = jax.random.split(rng)
        r = jax.random.split(r, n)
        state, obs = self.env.v_reset(r)
        total_reward = jnp.zeros(n)
        done = jnp.zeros(n, dtype=jnp.bool_)
        # NOTE(review): the loop exits as soon as ANY rollout reports done, so the
        # `(1.0 - done)` mask never zeroes a reward (done is all False while the
        # body runs). This equals "run until all are done" only if every rollout
        # terminates on the same step (e.g. a fixed horizon) -- confirm.
        while not np.any(done):
            action_mean = self.p_state.apply_fn(self.p_state.params, obs)
            rng, r = jax.random.split(rng)
            r = jax.random.split(r, n)
            state, obs, reward, next_done = self.env.v_step(state, action_mean, r)
            total_reward += reward * (1.0 - done)
            done = next_done

        # Start from "nothing in target" so an env without target spaces reports 0
        # instead of crashing on `None.astype(...)`.
        contains = jnp.zeros(obs.shape[0], dtype=jnp.bool_)
        for target_space in self.env.target_spaces:
            contains = jnp.logical_or(contains, jv_contains(target_space, obs))

        # Summing a boolean array counts the True entries in JAX's default integer
        # dtype; requesting int64 is truncated (with a warning) when x64 is off.
        num_end_in_target = jnp.sum(contains)
        num_traj = contains.shape[0]

        mean_r = np.mean(total_reward)
        std_r = np.std(total_reward)
        min_r = np.min(total_reward)
        max_r = np.max(total_reward)

        text = f"Rollouts (n={n}): {mean_r:0.1f} +- {std_r:0.1f} [{min_r:0.1f}, {max_r:0.1f}] ({100*num_end_in_target/num_traj:0.2f}% end in target)"
        print(text)
        res_dict = {
            "mean_r": mean_r,
            "std_r": std_r,
            "min_r": min_r,
            "max_r": max_r,
            "num_end_in_target": num_end_in_target,
            "num_traj": num_traj,
        }
        return text, res_dict

    def _sample_from_spaces(self, rng, n, spaces):
        """Draws exactly ``n`` points uniformly from a list of boxes.

        Every box receives ``n // len(spaces)`` points and the first
        ``n % len(spaces)`` boxes receive one extra, so no samples are dropped
        when ``n`` is not divisible by the number of boxes.

        :param rng: JAX PRNG key
        :param n: Total number of samples (a Python int)
        :param spaces: Non-empty sequence of objects exposing ``low`` and ``high`` bounds
        :return: Array of shape ``(n, observation_dim)``, rows grouped by box in order
        :raises ValueError: if ``spaces`` is empty
        """
        if len(spaces) == 0:
            raise ValueError("Cannot sample from an empty list of spaces")
        rngs = jax.random.split(rng, len(spaces))
        base, extra = divmod(n, len(spaces))
        batch = [
            jax.random.uniform(
                rngs[i],
                (base + (1 if i < extra else 0), self.env.observation_dim),
                minval=space.low,
                maxval=space.high,
            )
            for i, space in enumerate(spaces)
        ]
        return jnp.concatenate(batch, axis=0)

    @partial(jax.jit, static_argnums=(0, 2))
    def sample_init(self, rng, n):
        """Generates n random samples of the initial states"""
        return self._sample_from_spaces(rng, n, self.env.init_spaces)

    @partial(jax.jit, static_argnums=(0, 2))
    def sample_unsafe(self, rng, n):
        """Generates n random samples of the unsafe states"""
        return self._sample_from_spaces(rng, n, self.env.unsafe_spaces)

    @partial(jax.jit, static_argnums=(0, 2))
    def sample_target(self, rng, n):
        """Generates n random samples of the target states"""
        return self._sample_from_spaces(rng, n, self.env.target_spaces)

    @partial(jax.jit, static_argnums=(0,))
    def train_step(self, v_state, p_state, state, rng, current_delta, lipschitz_k):
        """Train for a single step.

        Differentiates one combined loss with respect to both the RSM and the
        policy parameters and applies one optimizer update to each train state.

        :param v_state: Train state of the RSM network
        :param p_state: Train state of the policy network
        :param state: Batch of states on which the decrease condition is trained
        :param rng: JAX PRNG key for the jitter and the init/unsafe/target samples
        :param current_delta: Scale of the jitter added to ``state``
            (each entry is perturbed by ``current_delta * U(-0.5, 0.5)``)
        :param lipschitz_k: Added to ``self.eps`` to form the margin passed to ``martingale_loss``
        :return: ``(v_state, p_state, metrics)`` -- the updated train states and a dict with
            ``loss``, ``dec_loss`` and ``train_violations`` (the latter two from the last unrolled step)
        """
        rngs = jax.random.split(rng, 5)
        init_samples = self.sample_init(rngs[1], 256)
        unsafe_samples = self.sample_unsafe(rngs[2], 256)
        target_samples = self.sample_target(rngs[3], 64)
        # Adds a bit of randomization to the grid
        s_random = jax.random.uniform(rngs[4], state.shape, minval=-0.5, maxval=0.5)
        state = state + current_delta * s_random

        def loss_fn(l_params, p_params, state):
            loss = 0.0
            # Bound up front so the aux outputs exist even when n_step < 1.
            dec_loss = jnp.float32(0.0)
            violations = jnp.float32(0.0)
            for _ in range(self.n_step):
                l = v_state.apply_fn(l_params, state)
                a = p_state.apply_fn(p_params, state)

                if self.estimate_expected_via_ibp:
                    pmass, batched_grid_lb, batched_grid_ub = self._cached_pmass_grid
                    exp_l_next = compute_expected_l(
                        self.env,
                        self.v_ibp.apply,
                        l_params,
                        state,
                        a,
                        pmass,
                        batched_grid_lb,
                        batched_grid_ub,
                    )
                else:
                    # Monte-Carlo estimate from 16 noisy successors per state.
                    # NOTE(review): rngs[0] is reused for every unrolled step, so all
                    # steps share one noise draw -- confirm this is intended.
                    s_next = self.env.v_next(state, a)
                    s_next = jnp.expand_dims(
                        s_next, axis=1
                    )  # broadcast dim 1 with random noise
                    noise = triangular(
                        rngs[0], (s_next.shape[0], 16, self.env.observation_dim)
                    )
                    noise = noise * self.env.noise
                    s_next_random = s_next + noise
                    l_next_fn = jax.vmap(v_state.apply_fn, in_axes=(None, 0))
                    l_next = l_next_fn(l_params, s_next_random)
                    exp_l_next = jnp.mean(l_next, axis=1)

                exp_l_next = exp_l_next.flatten()
                l = l.flatten()
                # Fraction of states whose expected next value does not decrease.
                violations = jnp.mean((exp_l_next >= l).astype(jnp.float32))

                dec_loss = martingale_loss(l, exp_l_next, lipschitz_k + self.eps)
                loss += dec_loss
                state = self.env.v_next(state, a)

            # Hinge penalties on the Lipschitz bounds of both networks.
            K_l = self.norm_fn(l_params)
            K_p = self.norm_fn(p_params)
            lip_loss_l = jnp.maximum(K_l - self.v_lip, 0)
            lip_loss_p = jnp.maximum(K_p - self.p_lip, 0)
            loss += self.lip_lambda * (lip_loss_l + lip_loss_p)

            # `self` is a static jit argument, so reach_prob is concrete here.
            if float(self.reach_prob) < 1.0:
                # Train RA objectives
                l_at_init = v_state.apply_fn(l_params, init_samples)
                l_at_unsafe = v_state.apply_fn(l_params, unsafe_samples)
                l_at_target = v_state.apply_fn(l_params, target_samples)

                max_at_init = jnp.max(l_at_init)
                min_at_init = jnp.min(l_at_init)
                min_at_unsafe = jnp.min(l_at_unsafe)
                min_at_target = jnp.min(l_at_target)

                # Maximize this term to at least 1/(1-reach prob)
                loss += -jnp.minimum(min_at_unsafe, 1 / (1 - self.reach_prob))

                # Minimize the max at init to below 1
                loss += jnp.maximum(max_at_init, 1)

                # Global minimum should be inside target
                loss += jnp.maximum(min_at_target - min_at_init, 0)
                loss += jnp.maximum(min_at_target - min_at_unsafe, 0)

            return loss, (dec_loss, violations)

        grad_fn = jax.value_and_grad(loss_fn, has_aux=True, argnums=(0, 1))
        (loss, (dec_loss, violations)), (l_grad, p_grad) = grad_fn(
            v_state.params, p_state.params, state
        )
        # Apply gradient clipping to stabilize training
        # p_grad = clip_grad_norm(p_grad, 1)
        # l_grad = clip_grad_norm(l_grad, 1)
        v_state = v_state.apply_gradients(grads=l_grad)
        p_state = p_state.apply_gradients(grads=p_grad)
        metrics = {"loss": loss, "dec_loss": dec_loss, "train_violations": violations}
        return v_state, p_state, metrics

    def train_epoch(
        self, train_ds, current_delta=0, lipschitz_k=0, train_v=True, train_p=True
    ):
        """Train for a single epoch.

        Runs ``train_step`` on every batch of ``train_ds.as_numpy_iterator()``.

        :param train_ds: Dataset exposing ``as_numpy_iterator()`` over state batches
        :param current_delta: Jitter scale forwarded to ``train_step``
        :param lipschitz_k: Decrease-margin offset forwarded to ``train_step``
        :param train_v: Whether to keep the RSM update of each step
        :param train_p: Whether to keep the policy update of each step
        :return: Dict mapping each metric name to its mean over the epoch's batches
        :raises ValueError: if ``train_ds`` yields no batches
        """
        current_delta = jnp.float32(current_delta)
        lipschitz_k = jnp.float32(lipschitz_k)
        batch_metrics = []

        for state in train_ds.as_numpy_iterator():
            state = jnp.array(state)
            self.rng, rng = jax.random.split(self.rng, 2)

            new_v_state, new_p_state, metrics = self.train_step(
                self.v_state, self.p_state, state, rng, current_delta, lipschitz_k
            )
            # train_step always updates both networks; commit only the trained ones.
            if train_p:
                self.p_state = new_p_state
            if train_v:
                self.v_state = new_v_state
            batch_metrics.append(metrics)

        if not batch_metrics:
            raise ValueError("train_ds yielded no batches; cannot compute epoch metrics")

        # compute mean of metrics across each batch in epoch.
        batch_metrics_np = jax.device_get(batch_metrics)
        return {
            k: np.mean([metrics[k] for metrics in batch_metrics_np])
            for k in batch_metrics_np[0]
        }

    def save(self, filename):
        """Saves the policy, value and RSM train states to ``filename`` via ``jax_save``."""
        jax_save(
            {"policy": self.p_state, "value": self.c_state, "martingale": self.v_state},
            filename,
        )

    def load(self, filename, force_load_all=True):
        """Loads train states previously written by ``save``.

        :param filename: File to read via ``jax_load``
        :param force_load_all: If True, a ``KeyError`` while loading all three states
            is re-raised. If False, falls back to loading policy + value, and then
            the policy alone (for checkpoints written without those entries).
        :raises KeyError: if loading fails and ``force_load_all`` is True
        """
        try:
            params = jax_load(
                {
                    "policy": self.p_state,
                    "value": self.c_state,
                    "martingale": self.v_state,
                },
                filename,
            )
            self.p_state = params["policy"]
            self.v_state = params["martingale"]
            self.c_state = params["value"]
        except KeyError:
            if force_load_all:
                raise
            # Legacy load
            try:
                params = {"policy": self.p_state, "value": self.c_state}
                params = jax_load(params, filename)
                self.p_state = params["policy"]
                self.c_state = params["value"]
            except KeyError:
                params = {"policy": self.p_state}
                params = jax_load(params, filename)
                self.p_state = params["policy"]