import gym
from gym import spaces
import numpy as np
from os import path
from scipy.stats import triang
import jax.numpy as jnp
from functools import partial
import jax
import matplotlib.pyplot as plt
import os
from rsm_utils import triangular, make_unsafe_spaces, contained_in_any


def angle_normalize(x):
    """Wrap an angle given in radians into the interval ``[-pi, pi)``."""
    full_turn = 2 * np.pi
    # Shift by pi so the modulo maps onto [0, 2*pi), then shift back.
    shifted = x + np.pi
    return (shifted % full_turn) - np.pi


class vDebugEnv(gym.Env):
    """Small 2-D debug environment with linear dynamics and bounded noise.

    Observations are ``[x, y]`` inside the box ``[-0.5, 0.5]^2`` and the single
    action is clipped to ``[-1, 1]``.  ``step`` and ``reset`` work on a packed
    state ``[step_counter, x, y]`` and take an explicit JAX PRNG key;
    ``v_next``, ``v_step`` and ``v_reset`` are their ``jax.vmap``-ed versions.
    """

    def __init__(self):
        self.has_render = False
        self.name = "debug"

        safe = np.array([0.1, 0.1], np.float32)
        self.target_spaces = [spaces.Box(low=-safe, high=safe, dtype=np.float32)]
        self.init_spaces = [
            spaces.Box(
                low=np.array([-0.2, -0.1]),
                high=np.array([-0.1, 0.1]),
                dtype=np.float32,
            ),
            spaces.Box(
                low=np.array([0.1, -0.1]),
                high=np.array([0.2, 0.1]),
                dtype=np.float32,
            ),
        ]
        self.action_space = spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32)
        self.observation_space = spaces.Box(
            low=-0.5 * np.ones(2, dtype=np.float32),
            high=0.5 * np.ones(2, dtype=np.float32),
            dtype=np.float32,
        )
        # Per-dimension noise magnitude (an earlier setting was [0.01, 0.005]).
        self.noise = np.array([0.02, 0.005])
        # Two unsafe boxes: one at the low-x edge (y <= 0) and one at the
        # high-x edge (y >= 0) of the observation space.
        self.unsafe_spaces = [
            spaces.Box(
                low=self.observation_space.low,
                high=np.array([self.observation_space.low[0] + 0.1, 0.0]),
                dtype=np.float32,
            ),
            spaces.Box(
                low=np.array([self.observation_space.high[0] - 0.1, 0.0]),
                high=self.observation_space.high,
                dtype=np.float32,
            ),
        ]

        self._jax_rng = jax.random.PRNGKey(777)
        self.v_next = jax.vmap(self.next, in_axes=(0, 0), out_axes=0)
        self.v_step = jax.jit(jax.vmap(self.step))
        self.v_reset = jax.jit(jax.vmap(self.reset))

    @property
    def noise_bounds(self):
        """Return ``(lower, upper)`` noise bounds, i.e. ``(-noise, noise)``."""
        return -self.noise, self.noise

    @property
    def observation_dim(self):
        """Number of observation dimensions."""
        return self.observation_space.shape[0]

    @property
    def action_dim(self):
        """Number of action dimensions."""
        return self.action_space.shape[0]

    @partial(jax.jit, static_argnums=(0,))
    def next(self, state, action):
        """Noise-free transition for a single ``[x, y]`` state.

        The action is clipped to ``[-1, 1]`` and the result is clipped to the
        observation space.
        """
        action = jnp.clip(action, -1, 1)

        new_y = 0.9 * state[1] + action[0] * 0.3
        new_x = 0.9 * state[0] + new_y * 0.1
        # jnp.clip (not np.clip): the operands are traced values under jit.
        new_y = jnp.clip(
            new_y, self.observation_space.low[1], self.observation_space.high[1]
        )
        new_x = jnp.clip(
            new_x, self.observation_space.low[0], self.observation_space.high[0]
        )
        return jnp.array([new_x, new_y])

    def add_noise(self, state):
        """Return ``state`` plus scaled noise, advancing the internal PRNG key.

        ``triangular`` comes from ``rsm_utils``; presumably it samples a
        triangular distribution on ``[-1, 1]`` -- confirm against that module.
        """
        self._jax_rng, rng = jax.random.split(self._jax_rng, 2)
        noise = triangular(rng, (self.observation_space.shape[0],))
        noise = noise * self.noise
        return state + noise

    @partial(jax.jit, static_argnums=(0,))
    def step(self, state, action, rng):
        """Advance one packed state ``[step_counter, x, y]`` by one noisy step.

        Returns ``(next_packed, next_state, reward, done)``.  ``done`` is true
        once the incoming step counter has reached 200.
        """
        step = state[0]
        state = state[1:3]
        next_state = self.next(state, action)

        noise = triangular(rng, (self.observation_space.shape[0],))
        noise = noise * self.noise
        next_state = next_state + noise
        # jnp.clip (not np.clip): next_state is a traced value under jit.
        next_state = jnp.clip(
            next_state, self.observation_space.low, self.observation_space.high
        )

        # Containment is evaluated on the pre-transition state: -1 for each
        # unsafe box and +1 for each target box that contains it.
        reward = 0
        for unsafe in self.unsafe_spaces:
            contain = jnp.all(
                jnp.logical_and(state >= unsafe.low, state <= unsafe.high)
            )
            reward += -jnp.float32(contain)
        for target in self.target_spaces:
            contain = jnp.all(
                jnp.logical_and(state >= target.low, state <= target.high)
            )
            reward += jnp.float32(contain)

        # Shaping term: penalize the normalized distance of the next state
        # from the origin.
        reward -= 2 * jnp.mean(jnp.abs(next_state / self.observation_space.high))
        done = step >= 200
        next_packed = jnp.array([step + 1, next_state[0], next_state[1]])
        return next_packed, next_state, reward, done

    @partial(jax.jit, static_argnums=(0,))
    def reset(self, rng):
        """Sample a fresh episode start; returns ``(packed_state, obs)``.

        The observation is drawn uniformly from the whole observation space
        (not from ``init_spaces``) and the step counter starts at 0.
        """
        obs = jax.random.uniform(
            rng,
            shape=(self.observation_space.shape[0],),
            minval=self.observation_space.low,
            maxval=self.observation_space.high,
        )
        state = jnp.array([0, obs[0], obs[1]])
        return state, obs

    @property
    def lipschitz_constant(self):
        """Largest column sum of a fixed 2x3 coefficient matrix."""
        A = np.max(np.sum(np.array([[1, 0.045, 0.45], [0, 0.9, 0.5]]), axis=0))
        return A

    @property
    def lipschitz_constant_linf(self):
        """Largest row sum of the same fixed 2x3 coefficient matrix."""
        A = np.max(np.sum(np.array([[1, 0.045, 0.45], [0, 0.9, 0.5]]), axis=1))
        return A

    @property
    def delta(self):
        """Return ``0.1`` plus the noise magnitude of the first dimension."""
        return 0.1 + self.noise[0]

    def integrate_noise(self, a: list, b: list):
        """Probability mass of the noise distribution inside boxes ``[a, b]``.

        ``a[i]`` and ``b[i]`` hold the lower/upper bounds for dimension ``i``
        (arrays of equal length; ``a[0].shape[0]`` sets the batch size).  Each
        marginal is a symmetric triangular distribution on the noise bounds,
        and the marginal masses are multiplied over the two dimensions.
        """
        dims = 2
        pmass = np.ones(a[0].shape[0])
        for i in range(dims):
            loc = self.noise_bounds[0][i]
            scale = self.noise_bounds[1][i] - self.noise_bounds[0][i]
            marginal_pmass = triang.cdf(b[i], c=0.5, loc=loc, scale=scale) - triang.cdf(
                a[i], c=0.5, loc=loc, scale=scale
            )
            pmass *= marginal_pmass
        return pmass


class vLDSEnv(gym.Env):
    """2-D linear dynamical system environment with bounded noise.

    Observations are ``[x, y]`` inside the box ``[-0.5, 0.5]^2`` and the single
    action is clipped to ``[-1, 1]``.  ``step`` and ``reset`` work on a packed
    state ``[step_counter, x, y]`` and take an explicit JAX PRNG key;
    ``v_next``, ``v_step`` and ``v_reset`` are their ``jax.vmap``-ed versions.
    """

    def __init__(self):
        self.has_render = False
        self.name = "lds"

        safe = np.array([0.1, 0.1], np.float32)
        self.target_spaces = [spaces.Box(low=-safe, high=safe, dtype=np.float32)]
        self.init_spaces = [
            spaces.Box(
                low=np.array([-0.2, -0.1]),
                high=np.array([-0.1, 0.1]),
                dtype=np.float32,
            ),
            spaces.Box(
                low=np.array([0.1, -0.1]),
                high=np.array([0.2, 0.1]),
                dtype=np.float32,
            ),
        ]
        self.action_space = spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32)
        self.observation_space = spaces.Box(
            low=-0.5 * np.ones(2, dtype=np.float32),
            high=0.5 * np.ones(2, dtype=np.float32),
            dtype=np.float32,
        )
        # Per-dimension noise magnitude (an earlier setting was [0.01, 0.005]).
        self.noise = np.array([0.02, 0.005])
        # Two unsafe boxes: one at the low-x edge (y <= 0) and one at the
        # high-x edge (y >= 0) of the observation space.
        self.unsafe_spaces = [
            spaces.Box(
                low=self.observation_space.low,
                high=np.array([self.observation_space.low[0] + 0.1, 0.0]),
                dtype=np.float32,
            ),
            spaces.Box(
                low=np.array([self.observation_space.high[0] - 0.1, 0.0]),
                high=self.observation_space.high,
                dtype=np.float32,
            ),
        ]

        self._jax_rng = jax.random.PRNGKey(777)
        self.v_next = jax.vmap(self.next, in_axes=(0, 0), out_axes=0)
        self.v_step = jax.jit(jax.vmap(self.step))
        self.v_reset = jax.jit(jax.vmap(self.reset))

    @property
    def noise_bounds(self):
        """Return ``(lower, upper)`` noise bounds, i.e. ``(-noise, noise)``."""
        return -self.noise, self.noise

    @property
    def observation_dim(self):
        """Number of observation dimensions."""
        return self.observation_space.shape[0]

    @property
    def action_dim(self):
        """Number of action dimensions."""
        return self.action_space.shape[0]

    @partial(jax.jit, static_argnums=(0,))
    def next(self, state, action):
        """Noise-free transition for a single ``[x, y]`` state.

        ``y`` decays by ``tau`` and is driven by the clipped action; ``x``
        accumulates a tenth of the new ``y``.  The result is clipped to the
        observation space.
        """
        action = jnp.clip(action, -1, 1)

        tau = 0.9
        new_y = state[1] * tau + action[0] * 0.3
        new_x = state[0] * 1.0 + new_y * 0.1
        # jnp.clip (not np.clip): the operands are traced values under jit.
        new_y = jnp.clip(
            new_y, self.observation_space.low[1], self.observation_space.high[1]
        )
        new_x = jnp.clip(
            new_x, self.observation_space.low[0], self.observation_space.high[0]
        )
        return jnp.array([new_x, new_y])

    def add_noise(self, state):
        """Return ``state`` plus scaled noise, advancing the internal PRNG key.

        ``triangular`` comes from ``rsm_utils``; presumably it samples a
        triangular distribution on ``[-1, 1]`` -- confirm against that module.
        """
        self._jax_rng, rng = jax.random.split(self._jax_rng, 2)
        noise = triangular(rng, (self.observation_space.shape[0],))
        noise = noise * self.noise
        return state + noise

    @partial(jax.jit, static_argnums=(0,))
    def step(self, state, action, rng):
        """Advance one packed state ``[step_counter, x, y]`` by one noisy step.

        Returns ``(next_packed, next_state, reward, done)``.  ``done`` is true
        once the incoming step counter has reached 200.
        """
        step = state[0]
        state = state[1:3]
        next_state = self.next(state, action)

        noise = triangular(rng, (self.observation_space.shape[0],))
        noise = noise * self.noise
        next_state = next_state + noise
        # jnp.clip (not np.clip): next_state is a traced value under jit.
        next_state = jnp.clip(
            next_state, self.observation_space.low, self.observation_space.high
        )

        # Containment is evaluated on the pre-transition state: -1 for each
        # unsafe box and +1 for each target box that contains it.
        reward = 0
        for unsafe in self.unsafe_spaces:
            contain = jnp.all(
                jnp.logical_and(state >= unsafe.low, state <= unsafe.high)
            )
            reward += -jnp.float32(contain)
        for target in self.target_spaces:
            contain = jnp.all(
                jnp.logical_and(state >= target.low, state <= target.high)
            )
            reward += jnp.float32(contain)

        # Shaping term: penalize the normalized distance of the next state
        # from the origin.
        reward -= 2 * jnp.mean(jnp.abs(next_state / self.observation_space.high))
        done = step >= 200
        next_packed = jnp.array([step + 1, next_state[0], next_state[1]])
        return next_packed, next_state, reward, done

    @partial(jax.jit, static_argnums=(0,))
    def reset(self, rng):
        """Sample a fresh episode start; returns ``(packed_state, obs)``.

        The observation is drawn uniformly from the whole observation space
        (not from ``init_spaces``) and the step counter starts at 0.
        """
        obs = jax.random.uniform(
            rng,
            shape=(self.observation_space.shape[0],),
            minval=self.observation_space.low,
            maxval=self.observation_space.high,
        )
        state = jnp.array([0, obs[0], obs[1]])
        return state, obs

    @property
    def lipschitz_constant(self):
        """Largest column sum of a fixed 2x3 coefficient matrix."""
        A = np.max(np.sum(np.array([[1, 0.045, 0.45], [0, 0.9, 0.5]]), axis=0))
        return A

    @property
    def lipschitz_constant_linf(self):
        """Largest row sum of the same fixed 2x3 coefficient matrix."""
        A = np.max(np.sum(np.array([[1, 0.045, 0.45], [0, 0.9, 0.5]]), axis=1))
        return A

    @property
    def delta(self):
        """Return ``0.1`` plus the noise magnitude of the first dimension."""
        return 0.1 + self.noise[0]

    def integrate_noise(self, a: list, b: list):
        """Probability mass of the noise distribution inside boxes ``[a, b]``.

        ``a[i]`` and ``b[i]`` hold the lower/upper bounds for dimension ``i``
        (arrays of equal length; ``a[0].shape[0]`` sets the batch size).  Each
        marginal is a symmetric triangular distribution on the noise bounds,
        and the marginal masses are multiplied over the two dimensions.
        """
        dims = 2
        pmass = np.ones(a[0].shape[0])
        for i in range(dims):
            loc = self.noise_bounds[0][i]
            scale = self.noise_bounds[1][i] - self.noise_bounds[0][i]
            marginal_pmass = triang.cdf(b[i], c=0.5, loc=loc, scale=scale) - triang.cdf(
                a[i], c=0.5, loc=loc, scale=scale
            )
            pmass *= marginal_pmass
        return pmass


# Module-level parameters read by create_2_link_mat and vHumanoidBalance2.plot.
l = 0.1  # link length (used for the inertia terms and the plotted link endpoints)
m = 0.05  # presumably the per-link mass -- units not stated in this file
g = 9.81  # scales the sin(phi) terms; presumably gravitational acceleration
delta = 0.01  # integration step size used in create_2_link_mat


def create_2_link_mat(state, action):
    """Advance the two-link model by one explicit integration step.

    ``state`` is split in half into the angles and the angular velocities;
    ``action`` is scaled by a fixed gain and added to the clipped internal
    torques.  Returns the new angles and velocities concatenated in the same
    layout as ``state``.
    """
    inertia = 0.1
    torque_gain = 2.0
    damping = 0.3
    angles, rates = jnp.split(state, 2)

    # Entries of the mass matrix (symmetric, since cos is even).
    m00 = inertia + m * jnp.square(l) + jnp.square(l) * m
    m11 = inertia + m * jnp.square(l)
    coupling = m * l * l
    mass = jnp.array(
        [
            [m00, coupling * jnp.cos(angles[0] - angles[1])],
            [coupling * jnp.cos(angles[1] - angles[0]), m11],
        ]
    )

    # Velocity-dependent coupling matrix.
    coriolis = jnp.array(
        [
            [0, -coupling * rates[1] * jnp.sin(angles[1] - angles[0])],
            [-coupling * rates[0] * jnp.sin(angles[0] - angles[1]), 0],
        ]
    )

    grav0 = (m * l + l * m) * g
    grav1 = (m * l) * g
    tau = jnp.array([-grav0 * jnp.sin(angles[0]), -grav1 * jnp.sin(angles[1])])

    # NOTE(review): the coupling matrix is applied to the angles (not the
    # rates), exactly as in the original implementation -- confirm intent.
    internal = jnp.clip(jnp.dot(-coriolis, angles) - tau, -1.2, 1.2)
    accel = jnp.dot(jnp.linalg.inv(mass), internal + torque_gain * action)

    new_rates = (1 - damping) * rates + delta * accel
    new_angles = angles + delta * new_rates
    return jnp.concatenate([new_angles, new_rates])


class vHumanoidBalance2:
    """Two-link balancing environment built on ``create_2_link_mat``.

    Observations are ``[phi_0, phi_1, phi_dot_0, phi_dot_1]`` and the two
    actions are clipped to ``[-1, 1]``.  ``step`` and ``reset`` work on a
    packed state ``[step_counter, *obs]`` and take an explicit JAX PRNG key;
    ``v_next``, ``v_step`` and ``v_reset`` are their ``jax.vmap``-ed versions.
    """

    name = "human2"

    def __init__(self):
        self.n = 2
        self._fig_id = 0

        high = np.array([0.4, 0.4, 0.35, 0.35], dtype=np.float32)
        self.observation_space = spaces.Box(low=-high, high=high, dtype=np.float32)
        self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)

        init = np.array([0.2, 0.2, 0.2, 0.2], np.float32)
        self.init_spaces = [spaces.Box(low=-init, high=init, dtype=np.float32)]

        # make_unsafe_spaces comes from rsm_utils; presumably it builds the
        # boxes between the given inner bound and the observation-space edge.
        self.unsafe_spaces = make_unsafe_spaces(
            self.observation_space, np.array([0.35, 0.35, 0.3, 0.3], np.float32)
        )

        # Only the two angle dimensions carry noise.
        self.noise = np.array([0.001, 0.001, 0.0, 0.0])

        safe = np.array([0.1, 0.1, 0.1, 0.1], np.float32)
        self.target_spaces = [spaces.Box(low=-safe, high=safe, dtype=np.float32)]

        self._jax_rng = jax.random.PRNGKey(777)
        self.v_next = jax.vmap(self.next, in_axes=(0, 0), out_axes=0)
        self.v_step = jax.jit(jax.vmap(self.step))
        self.v_reset = jax.jit(jax.vmap(self.reset))

    def plot(self, state=None):
        """Save a stick-figure plot of the two links to ``plots_kin/<id>.png``.

        Args:
            state: Sequence whose first two entries are the link angles.  When
                omitted, ``self.state`` is used; this class never assigns that
                attribute itself, so the caller must have set it.
        """
        if state is None:
            state = self.state
        os.makedirs("plots_kin", exist_ok=True)
        filename = "plots_kin/" + str(self._fig_id) + ".png"
        self._fig_id = self._fig_id + 1
        plt.figure()

        t1 = float(state[0])
        t2 = float(state[1])
        x0, y0 = 0, 0
        # np.cos/np.sin are used because this module does not import `math`.
        x1 = l * np.cos(t1)
        y1 = l * np.sin(t1)

        # Coordinate points of the manipulator
        x2 = x1 + l * np.cos(t2)
        y2 = y1 + l * np.sin(t2)

        # Axes are swapped so that a zero angle points straight up.
        plt.plot([y0, y1], [x0, x1])
        plt.plot([y1, y2], [x1, x2])

        plt.xlim([-2, 2])
        plt.ylim([-2, 2])
        plt.savefig(filename)
        plt.close()

    @partial(jax.jit, static_argnums=(0,))
    def next(self, state, action):
        """Noise-free transition, clipped to the observation space."""
        action = jnp.clip(action, -1, 1)
        new_state = create_2_link_mat(state, action)
        new_state = jnp.clip(
            new_state, self.observation_space.low, self.observation_space.high
        )
        return new_state

    def add_noise(self, state):
        """Return ``state`` plus scaled noise, advancing the internal PRNG key.

        ``triangular`` comes from ``rsm_utils``; presumably it samples a
        triangular distribution on ``[-1, 1]`` -- confirm against that module.
        """
        self._jax_rng, rng = jax.random.split(self._jax_rng, 2)
        noise = triangular(rng, (self.observation_space.shape[0],))
        noise = noise * self.noise
        return state + noise

    @partial(jax.jit, static_argnums=(0,))
    def step(self, state, action, rng):
        """Advance one packed state ``[step_counter, *obs]`` by one noisy step.

        Returns ``(next_packed, next_state, reward, done)``.  ``done`` becomes
        true when the incremented counter exceeds 200; the stored counter is
        capped at 200.
        """
        step = state[0]
        state = state[1:5]
        next_state = self.next(state, action)

        noise = triangular(rng, (self.observation_space.shape[0],))
        noise = noise * self.noise
        next_state = next_state + noise
        next_state = jnp.clip(
            next_state, self.observation_space.low, self.observation_space.high
        )

        next_step = step + 1

        # Containment is evaluated on the pre-transition state.
        reward = 0
        for unsafe in self.unsafe_spaces:
            contain = jnp.all(
                jnp.logical_and(state >= unsafe.low, state <= unsafe.high)
            )
            reward += -10 * jnp.float32(contain)
        for target in self.target_spaces:
            contain = jnp.all(
                jnp.logical_and(state >= target.low, state <= target.high)
            )
            # Dense shaping: L1 distance of the next state to the target
            # center, capped at 2.
            center = 0.5 * (target.low + target.high)
            dist = jnp.sum(jnp.abs(center - next_state))
            dist = jnp.clip(dist, 0, 2)
            reward += 2 * (2.0 - dist)
            reward += jnp.float32(contain)

        done = next_step > 200
        next_step = jnp.minimum(next_step, 200)
        next_packed = jnp.array(
            [next_step, next_state[0], next_state[1], next_state[2], next_state[3]]
        )
        return next_packed, next_state, reward, done

    @property
    def noise_bounds(self):
        """Return ``(lower, upper)`` noise bounds for the first two dimensions."""
        return -self.noise[0:2], self.noise[0:2]

    @property
    def observation_dim(self):
        """Number of observation dimensions."""
        return self.observation_space.shape[0]

    @property
    def action_dim(self):
        """Number of action dimensions."""
        return self.action_space.shape[0]

    @property
    def lipschitz_constant(self):
        """Hard-coded Lipschitz constant of the dynamics."""
        return 1.06

    @property
    def lipschitz_constant_linf(self):
        """Hard-coded L-infinity Lipschitz constant of the dynamics."""
        return 1.06

    @property
    def delta(self):
        """Return ``0.1`` plus the noise magnitude of the first dimension."""
        return 0.1 + self.noise[0]

    def integrate_noise(self, a: list, b: list):
        """Probability mass of the noise distribution inside boxes ``[a, b]``.

        ``a[i]`` and ``b[i]`` hold the lower/upper bounds for dimension ``i``
        (arrays of equal length; ``a[0].shape[0]`` sets the batch size).  Each
        marginal is a symmetric triangular distribution on the noise bounds,
        and the marginal masses are multiplied over the two noisy dimensions.
        """
        dims = 2
        pmass = np.ones(a[0].shape[0])
        for i in range(dims):
            loc = self.noise_bounds[0][i]
            scale = self.noise_bounds[1][i] - self.noise_bounds[0][i]
            marginal_pmass = triang.cdf(b[i], c=0.5, loc=loc, scale=scale) - triang.cdf(
                a[i], c=0.5, loc=loc, scale=scale
            )
            pmass *= marginal_pmass
        return pmass

    @partial(jax.jit, static_argnums=(0,))
    def reset(self, rng):
        """Sample a fresh episode start; returns ``(packed_state, obs)``.

        The observation is drawn uniformly from the whole observation space
        and the step counter starts at 0.
        """
        obs = jax.random.uniform(
            rng,
            shape=(self.observation_space.shape[0],),
            minval=self.observation_space.low,
            maxval=self.observation_space.high,
        )
        state = jnp.array([0, obs[0], obs[1], obs[2], obs[3]])
        return state, obs


class vInvertedPendulum(gym.Env):
    """Inverted-pendulum environment with bounded observation noise.

    Observations are ``[theta, theta_dot]`` and the single action is clipped
    to ``[-1, 1]``.  ``step`` and ``reset`` work on a packed state
    ``[step_counter, theta, theta_dot]`` and take an explicit JAX PRNG key;
    ``v_next``, ``v_step`` and ``v_reset`` are their ``jax.vmap``-ed versions.
    """

    def __init__(self):
        self.name = "vpend"

        init = np.array([0.3, 0.3], np.float32)
        self.init_spaces = [spaces.Box(low=-init, high=init, dtype=np.float32)]
        # NOTE(review): with init = [-1, 1] the box below gets low = [1, -1]
        # and high = [-1, 1], i.e. low > high in the first dimension --
        # confirm whether [1, 1] was intended.
        init = np.array([-1, 1], np.float32)
        self.init_spaces_train = [spaces.Box(low=-init, high=init, dtype=np.float32)]

        self.action_space = spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32)
        self.noise = np.array([0.02, 0.01])

        safe = np.array([0.2, 0.2], np.float32)
        self.target_spaces = [spaces.Box(low=-safe, high=safe, dtype=np.float32)]
        # Tighter target box used only for the reward in step().
        safe = np.array([0.1, 0.1], np.float32)
        self.target_space_train = spaces.Box(low=-safe, high=safe, dtype=np.float32)

        # Same bound as in AAAI; alternatives noted by the author were
        # [0.7, 0.7] and [1.5, 1.5] ("make it fail").
        observation_space = np.array([0.5, 0.5], np.float32)
        self.observation_space = spaces.Box(
            low=-observation_space, high=observation_space, dtype=np.float32
        )

        # Two unsafe boxes: one at the low-theta edge (theta_dot <= 0) and one
        # at the high-theta edge (theta_dot >= 0) of the observation space.
        self.unsafe_spaces = [
            spaces.Box(
                low=self.observation_space.low,
                high=np.array([self.observation_space.low[0] + 0.1, 0.0]),
                dtype=np.float32,
            ),
            spaces.Box(
                low=np.array([self.observation_space.high[0] - 0.1, 0.0]),
                high=self.observation_space.high,
                dtype=np.float32,
            ),
        ]

        self._jax_rng = jax.random.PRNGKey(777)
        self.v_next = jax.vmap(self.next, in_axes=(0, 0), out_axes=0)
        self.v_step = jax.jit(jax.vmap(self.step))
        self.v_reset = jax.jit(jax.vmap(self.reset))

    @partial(jax.jit, static_argnums=(0,))
    def next(self, state, action):
        """Noise-free transition for a single ``[theta, theta_dot]`` state.

        The velocity is damped, driven by gravity and the scaled action,
        limited to ``max_speed`` and then both components are clipped to the
        observation space.
        """
        th, thdot = state  # th := theta
        max_speed = 5
        dt = 0.01
        # These locals intentionally shadow the module-level g, m and l.
        g = 10
        m = 0.15
        l = 0.3  # was 0.5 before
        b = 0.1

        u = 1.5 * jnp.clip(action, -1, 1)[0]
        newthdot = (1 - b) * thdot + (
            -3 * g * 0.5 / (2 * l) * jnp.sin(th + jnp.pi) + 3.0 / (m * l**2) * u
        ) * dt
        newthdot = jnp.clip(newthdot, -max_speed, max_speed)
        newth = th + newthdot * dt

        newth = jnp.clip(
            newth, self.observation_space.low[0], self.observation_space.high[0]
        )
        newthdot = jnp.clip(
            newthdot, self.observation_space.low[1], self.observation_space.high[1]
        )
        return jnp.array([newth, newthdot])

    def add_noise(self, state):
        """Return ``state`` plus scaled noise, advancing the internal PRNG key.

        ``triangular`` comes from ``rsm_utils``; presumably it samples a
        triangular distribution on ``[-1, 1]`` -- confirm against that module.
        """
        self._jax_rng, rng = jax.random.split(self._jax_rng, 2)
        noise = triangular(rng, (self.observation_space.shape[0],))
        noise = noise * self.noise
        return state + noise

    @partial(jax.jit, static_argnums=(0,))
    def step(self, state, action, rng):
        """Advance one packed state ``[step_counter, theta, theta_dot]``.

        Returns ``(next_packed, next_state, reward, done)``.  ``done`` is true
        once the incoming step counter has reached 200.
        """
        step = state[0]
        state = state[1:3]
        next_state = self.next(state, action)

        noise = triangular(rng, (self.observation_space.shape[0],))
        noise = noise * self.noise
        next_state = next_state + noise
        # jnp.clip (not np.clip): next_state is a traced value under jit.
        next_state = jnp.clip(
            next_state, self.observation_space.low, self.observation_space.high
        )

        # Containment is evaluated on the pre-transition state: -1 for each
        # unsafe box and +1 when inside the training target box.
        reward = 0
        for unsafe in self.unsafe_spaces:
            contain = jnp.all(
                jnp.logical_and(state >= unsafe.low, state <= unsafe.high)
            )
            reward += -jnp.float32(contain)
        contain = jnp.all(
            jnp.logical_and(
                state >= self.target_space_train.low,
                state <= self.target_space_train.high,
            )
        )
        reward += jnp.float32(contain)

        # Quadratic cost on the wrapped angle and the angular velocity of the
        # next state.
        th, thdot = next_state
        costs = angle_normalize(th) ** 2 + 0.1 * thdot**2
        reward -= costs
        done = step >= 200
        next_packed = jnp.array([step + 1, next_state[0], next_state[1]])
        return next_packed, next_state, reward, done

    @property
    def observation_dim(self):
        """Number of observation dimensions."""
        return self.observation_space.shape[0]

    @property
    def action_dim(self):
        """Number of action dimensions."""
        return self.action_space.shape[0]

    @property
    def noise_bounds(self):
        """Return ``(lower, upper)`` noise bounds, i.e. ``(-noise, noise)``."""
        return -self.noise, self.noise

    @property
    def lipschitz_constant(self):
        """Hard-coded Lipschitz constant of the dynamics."""
        return 1.78

    @property
    def lipschitz_constant_linf(self):
        """Hard-coded L-infinity Lipschitz constant of the dynamics."""
        return 1.78

    @property
    def delta(self):
        """Return ``0.1`` plus the noise magnitude of the first dimension."""
        return 0.1 + self.noise[0]

    def integrate_noise(self, a: list, b: list):
        """Probability mass of the noise distribution inside boxes ``[a, b]``.

        ``a[i]`` and ``b[i]`` hold the lower/upper bounds for dimension ``i``
        (arrays of equal length; ``a[0].shape[0]`` sets the batch size).  Each
        marginal is a symmetric triangular distribution on the noise bounds,
        and the marginal masses are multiplied over the two dimensions.
        """
        dims = 2
        pmass = np.ones(a[0].shape[0])
        for i in range(dims):
            loc = self.noise_bounds[0][i]
            scale = self.noise_bounds[1][i] - self.noise_bounds[0][i]
            marginal_pmass = triang.cdf(b[i], c=0.5, loc=loc, scale=scale) - triang.cdf(
                a[i], c=0.5, loc=loc, scale=scale
            )
            pmass *= marginal_pmass
        return pmass

    @partial(jax.jit, static_argnums=(0,))
    def reset(self, rng):
        """Sample a fresh episode start; returns ``(packed_state, obs)``.

        The observation is drawn uniformly from the whole observation space
        (not from ``init_spaces``) and the step counter starts at 0.
        """
        obs = jax.random.uniform(
            rng,
            shape=(self.observation_space.shape[0],),
            minval=self.observation_space.low,
            maxval=self.observation_space.high,
        )
        state = jnp.array([0, obs[0], obs[1]])
        return state, obs


class vCollisionAvoidanceEnv(gym.Env):
    """2-D point-mass environment with two obstacles that override the action.

    Observations are ``[x, y]`` inside ``[-1, 1]^2`` and the two actions are
    clipped to ``[-1, 1]``.  ``step`` and ``reset`` work on a packed state
    ``[step_counter, x, y]`` and take an explicit JAX PRNG key; ``v_next``,
    ``v_step`` and ``v_reset`` are their ``jax.vmap``-ed versions.
    """

    name = "vcavoid"

    def __init__(self):
        self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        self.observation_space = spaces.Box(
            low=-np.ones(2, dtype=np.float32),
            high=np.ones(2, dtype=np.float32),
            dtype=np.float32,
        )
        # Per-dimension noise magnitude (an earlier setting was [0.05, 0.05]).
        self.noise = np.array([0.02, 0.02])
        safe = np.array([0.2, 0.2], np.float32)  # was 0.1 before
        self.target_spaces = [spaces.Box(low=-safe, high=safe, dtype=np.float32)]

        # make_unsafe_spaces comes from rsm_utils; presumably it builds the
        # boxes between the given inner bound and the observation-space edge.
        self.init_spaces_train = make_unsafe_spaces(
            self.observation_space, np.array([0.9, 0.9], np.float32)
        )
        self.init_spaces = [
            spaces.Box(
                low=np.array([-1, -0.6]),
                high=np.array([-0.9, 0.6]),
                dtype=np.float32,
            ),
            spaces.Box(
                low=np.array([0.9, -0.6]),
                high=np.array([1.0, 0.6]),
                dtype=np.float32,
            ),
        ]

        # Unsafe strips at the top and bottom edges, centered on x = 0.
        self.unsafe_spaces = [
            spaces.Box(
                low=np.array([-0.3, 0.7]), high=np.array([0.3, 1.0]), dtype=np.float32
            ),
            spaces.Box(
                low=np.array([-0.3, -1.0]), high=np.array([0.3, -0.7]), dtype=np.float32
            ),
        ]
        self._jax_rng = jax.random.PRNGKey(777)
        self.v_next = jax.vmap(self.next, in_axes=(0, 0), out_axes=0)
        self.v_step = jax.jit(jax.vmap(self.step))
        self.v_reset = jax.jit(jax.vmap(self.reset))

    @property
    def noise_bounds(self):
        """Return ``(lower, upper)`` noise bounds, i.e. ``(-noise, noise)``."""
        return -self.noise, self.noise

    @partial(jax.jit, static_argnums=(0,))
    def next(self, state, action):
        """Noise-free transition for a single ``[x, y]`` state.

        Within distance 0.3 of an obstacle at ``(0, 1)`` or ``(0, -1)`` the
        action is blended towards a fixed force ``(0, 1)`` resp. ``(0, -1)``;
        the blended action then moves the state by a step of 0.1.
        """
        action = jnp.clip(action, -1, 1)

        obstacle1 = jnp.array((0, 1))
        force1 = jnp.array((0, 1))
        dist1 = jnp.linalg.norm(obstacle1 - state)
        # Blend weight: 0 at the obstacle, 1 at distance >= 0.3.
        dist1 = jnp.clip(dist1 / 0.3, 0, 1)
        action = action * dist1 + (1 - dist1) * force1

        obstacle2 = jnp.array((0, -1))
        force2 = jnp.array((0, -1))
        dist2 = jnp.linalg.norm(obstacle2 - state)
        dist2 = jnp.clip(dist2 / 0.3, 0, 1)
        action = action * dist2 + (1 - dist2) * force2

        # The RASM paper used a step size of 0.2 here.
        state = state + action * 0.1
        state = jnp.clip(state, self.observation_space.low, self.observation_space.high)

        return state

    def add_noise(self, state):
        """Return ``state`` plus scaled noise, advancing the internal PRNG key.

        ``triangular`` comes from ``rsm_utils``; presumably it samples a
        triangular distribution on ``[-1, 1]`` -- confirm against that module.
        """
        self._jax_rng, rng = jax.random.split(self._jax_rng, 2)
        noise = triangular(rng, (self.observation_space.shape[0],))
        noise = noise * self.noise
        return state + noise

    @property
    def lipschitz_constant(self):
        """Hard-coded Lipschitz constant of the dynamics."""
        return 1.06

    @property
    def lipschitz_constant_linf(self):
        """Hard-coded L-infinity Lipschitz constant of the dynamics."""
        return 1.06

    @property
    def delta(self):
        """Return ``0.1`` plus the noise magnitude of the first dimension."""
        return 0.1 + self.noise[0]

    @partial(jax.jit, static_argnums=(0,))
    def step(self, state, action, rng):
        """Advance one packed state ``[step_counter, x, y]`` by one noisy step.

        Returns ``(next_packed, next_state, reward, done)``.  ``done`` is true
        once the incoming step counter has reached 200.
        """
        step = state[0]
        state = state[1:3]
        next_state = self.next(state, action)

        noise = triangular(rng, (self.observation_space.shape[0],))
        noise = noise * self.noise
        next_state = next_state + noise
        # jnp.clip (not np.clip): next_state is a traced value under jit.
        next_state = jnp.clip(
            next_state, self.observation_space.low, self.observation_space.high
        )

        reward = 0
        for unsafe in self.unsafe_spaces:
            # Containment is evaluated on the pre-transition state.
            contain = jnp.all(
                jnp.logical_and(state >= unsafe.low, state <= unsafe.high)
            )
            reward += -jnp.float32(contain)
            # Proximity penalty: grows as the next state comes within an L1
            # distance of 0.5 of the unsafe box center.
            center = 0.5 * (unsafe.low + unsafe.high)
            dist = jnp.sum(jnp.abs(center - next_state))
            dist = jnp.clip(dist, 0, 0.5)
            reward -= 1 * (0.5 - dist)

        for target in self.target_spaces:
            contain = jnp.all(
                jnp.logical_and(state >= target.low, state <= target.high)
            )
            reward += jnp.float32(contain)

        # Shaping term: penalize the normalized distance of the next state
        # from the origin.
        reward -= 2 * jnp.mean(jnp.abs(next_state / self.observation_space.high))
        done = step >= 200
        next_packed = jnp.array([step + 1, next_state[0], next_state[1]])
        return next_packed, next_state, reward, done

    @property
    def observation_dim(self):
        """Number of observation dimensions."""
        return self.observation_space.shape[0]

    @property
    def action_dim(self):
        """Number of action dimensions."""
        return self.action_space.shape[0]

    def integrate_noise(self, a: list, b: list):
        """Probability mass of the noise distribution inside boxes ``[a, b]``.

        ``a[i]`` and ``b[i]`` hold the lower/upper bounds for dimension ``i``
        (arrays of equal length; ``a[0].shape[0]`` sets the batch size).  Each
        marginal is a symmetric triangular distribution on the noise bounds,
        and the marginal masses are multiplied over the two dimensions.
        """
        dims = 2
        pmass = np.ones(a[0].shape[0])
        for i in range(dims):
            loc = self.noise_bounds[0][i]
            scale = self.noise_bounds[1][i] - self.noise_bounds[0][i]
            marginal_pmass = triang.cdf(b[i], c=0.5, loc=loc, scale=scale) - triang.cdf(
                a[i], c=0.5, loc=loc, scale=scale
            )
            pmass *= marginal_pmass
        return pmass

    @partial(jax.jit, static_argnums=(0,))
    def reset(self, rng):
        """Sample a fresh episode start; returns ``(packed_state, obs)``.

        The observation is drawn uniformly from the whole observation space
        and the step counter starts at 0.
        """
        obs = jax.random.uniform(
            rng,
            shape=(self.observation_space.shape[0],),
            minval=self.observation_space.low,
            maxval=self.observation_space.high,
        )
        state = jnp.array([0, obs[0], obs[1]])
        return state, obs


class vMazeEnv(gym.Env):
    """2-D maze-style environment with jitted, vmappable JAX dynamics.

    Observations live in the unit square ``[0, 1]^2`` and actions are 2-D
    vectors clipped to ``[-1, 1]``.  The "packed" state consumed by
    :meth:`step` and produced by :meth:`reset` is ``[step_counter, x, y]``.

    ``v_next``, ``v_step`` and ``v_reset`` are batched (``jax.vmap``) versions
    of :meth:`next`, :meth:`step` and :meth:`reset`.
    """

    name = "vmaze"

    def __init__(self):
        self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        self.observation_space = spaces.Box(
            low=np.zeros(2, dtype=np.float32),
            high=np.ones(2, dtype=np.float32),
            dtype=np.float32,
        )
        # ``noise`` scales the noise in add_noise()/noise_bounds, while
        # ``noise_train`` scales the noise injected by step().
        # was 0.05 before
        self.noise = np.array([0.001, 0.001])  # was 0.02 before
        self.noise_train = np.array([0.0, 0.0])  # was 0.02 before
        # self.noise = np.array([0.01, 0.01])  # was 0.02 before
        # self.noise_train = np.array([0.05, 0.05])  # was 0.02 before
        self.target_spaces = [
            spaces.Box(
                low=np.array([0.2, 0.7]),
                high=np.array([0.3, 0.8]),
                dtype=np.float32,
            ),
        ]

        # reset() samples from ``init_spaces_train``; ``init_spaces`` is a
        # smaller box contained in it.
        self.init_spaces_train = [
            spaces.Box(
                low=np.array([0.1, 0.1]),
                high=np.array([0.4, 0.4]),
                dtype=np.float32,
            ),
        ]
        self.init_spaces = [
            spaces.Box(
                low=np.array([0.2, 0.2]),
                high=np.array([0.3, 0.3]),
                dtype=np.float32,
            ),
        ]

        # (center, size) pairs; not referenced by any other method of this
        # class.
        obstacle_size = 0.15
        self._obstacle_list = [
            ((0.0, 0.5), obstacle_size),
            ((0.1, 0.5), obstacle_size),
            ((0.2, 0.5), obstacle_size),
            ((0.3, 0.5), obstacle_size),
            ((0.4, 0.5), obstacle_size),
            ((0.5, 0.45), obstacle_size),
            ((0.5, 0.55), obstacle_size),
        ]

        # Ordered (box, base_reward) waypoints used by step() for reward
        # shaping: being inside box i pays its base reward plus a bonus for
        # getting closer to the center of box i + 1.
        self._reward_boxes = [
            (spaces.Box(low=np.array([0.1, 0.1]), high=np.array([0.4, 0.4])), 0.0),
            (spaces.Box(low=np.array([0.4, 0.2]), high=np.array([0.6, 0.3])), 1),
            (spaces.Box(low=np.array([0.6, 0.2]), high=np.array([0.8, 0.3])), 2),
            (spaces.Box(low=np.array([0.6, 0.3]), high=np.array([0.8, 0.45])), 3),
            (spaces.Box(low=np.array([0.6, 0.45]), high=np.array([0.8, 0.6])), 4),
            (spaces.Box(low=np.array([0.6, 0.6]), high=np.array([0.8, 0.7])), 5),
            (spaces.Box(low=np.array([0.6, 0.7]), high=np.array([0.8, 0.8])), 6),
            (spaces.Box(low=np.array([0.45, 0.7]), high=np.array([0.6, 0.8])), 7),
            (spaces.Box(low=np.array([0.3, 0.7]), high=np.array([0.45, 0.8])), 8),
            (spaces.Box(low=np.array([0.15, 0.7]), high=np.array([0.3, 0.8])), 9),
            (spaces.Box(low=np.array([0.2, 0.7]), high=np.array([0.3, 0.8])), 10),
        ]
        # Four border strips of width 0.1 plus two interior boxes.
        self.unsafe_spaces = [
            spaces.Box(low=np.array([0, 0]), high=np.array([0.1, 1.0])),
            spaces.Box(low=np.array([0, 0]), high=np.array([1.0, 0.1])),
            spaces.Box(low=np.array([0, 0.9]), high=np.array([1.0, 1.0])),
            spaces.Box(low=np.array([0.9, 0.0]), high=np.array([1.0, 1.0])),
            spaces.Box(low=np.array([0, 0.45]), high=np.array([0.45, 0.55])),
            spaces.Box(low=np.array([0.45, 0.4]), high=np.array([0.55, 0.6])),
        ]
        self._jax_rng = jax.random.PRNGKey(777)
        self.v_next = jax.vmap(self.next, in_axes=(0, 0), out_axes=0)
        self.v_step = jax.jit(jax.vmap(self.step))
        self.v_reset = jax.jit(jax.vmap(self.reset))

    @property
    def noise_bounds(self):
        """Return ``(-self.noise, self.noise)`` as (lower, upper) bounds."""
        return -self.noise, self.noise

    @partial(jax.jit, static_argnums=(0,))
    def next(self, state, action):
        """Noise-free transition from ``state`` under ``action``.

        The action is clipped to ``[-1, 1]`` and zeroed whenever ``state``
        lies inside any unsafe box, so such states do not move.  Otherwise the
        state advances by ``0.05 * action`` and is clipped to the observation
        space.
        """
        action = jnp.clip(action, -1, 1)

        for unsafe_space in self.unsafe_spaces:
            contain = jnp.all(
                jnp.logical_and(state >= unsafe_space.low, state <= unsafe_space.high)
            )
            # Multiplying by 0 inside an unsafe box freezes the state there.
            action = action * (1.0 - jnp.float32(contain))

        state = state + action * 0.05
        state = jnp.clip(state, self.observation_space.low, self.observation_space.high)

        return state

    def add_noise(self, state):
        """Return ``state`` plus noise scaled by ``self.noise``.

        Stateful: advances the internal ``self._jax_rng`` key on every call.
        The noise sample comes from ``rsm_utils.triangular`` (presumably a
        unit-scale triangular sample -- confirm against rsm_utils).
        """
        self._jax_rng, rng = jax.random.split(self._jax_rng, 2)
        noise = triangular(rng, (self.observation_space.shape[0],))
        noise = noise * self.noise
        return state + noise

    @property
    def lipschitz_constant(self):
        """Hard-coded constant 1.06 (derivation not shown in this file)."""
        return 1.06

    @property
    def lipschitz_constant_linf(self):
        """Hard-coded constant 1.06 (derivation not shown in this file)."""
        return 1.06

    @property
    def delta(self):
        """Return ``0.1`` plus the first component of ``self.noise``."""
        return 0.1 + self.noise[0]

    @partial(jax.jit, static_argnums=(0,))
    def step(self, state, action, rng):
        """Advance one (noisy) step and compute the shaped reward.

        Args:
            state: packed state ``[step_counter, x, y]``.
            action: 2-D action (clipped inside :meth:`next`).
            rng: JAX PRNG key for the transition noise.

        Returns:
            ``(next_packed, next_state, reward, done)`` where ``next_packed``
            is ``[step_counter + 1, x', y']``, ``next_state`` is ``[x', y']``
            and ``done`` is ``step_counter >= 200``.

        Note that all containment tests below use the *current* state, while
        distance-based shaping uses the *next* state.
        """
        step = state[0]
        state = state[1:3]
        next_state = self.next(state, action)

        noise = triangular(rng, (self.observation_space.shape[0],))
        noise = noise * self.noise_train
        next_state = next_state + noise
        # jnp.clip (not np.clip): next_state is a traced value under jit.
        next_state = jnp.clip(
            next_state, self.observation_space.low, self.observation_space.high
        )

        reward = 0
        reward_factor = 1
        # Penalty of -1 for every unsafe box containing the current state.
        for unsafe in self.unsafe_spaces:
            contain = jnp.all(
                jnp.logical_and(state >= unsafe.low, state <= unsafe.high)
            )
            reward += -reward_factor * jnp.float32(contain)
            # center = 0.5 * (unsafe.low + unsafe.high)
            # dist = jnp.sum(jnp.abs(center - next_state))
            # dist = jnp.clip(dist, 0, 1.0)
            # reward -= 1 * (1.0 - dist)

        # Waypoint shaping over consecutive reward boxes; the last box only
        # ever acts as the "next" waypoint.
        for (box, rs), (box_next, _) in zip(
            self._reward_boxes, self._reward_boxes[1:]
        ):
            contain = jnp.all(jnp.logical_and(state >= box.low, state <= box.high))
            center = (box.low + box.high) / 2
            center_next = (box_next.low + box_next.high) / 2
            center_to_center = jnp.linalg.norm(center - center_next)
            # Distance to the next waypoint, normalized by the waypoint
            # spacing and clipped to [0, 1].
            dist_to_next = jnp.linalg.norm(center_next - next_state)
            dist_to_next = dist_to_next / center_to_center
            dist_to_next = jnp.clip(dist_to_next, 0, 1)
            reward += reward_factor * (1 - dist_to_next) * jnp.float32(contain)
            reward += reward_factor * rs * jnp.float32(contain)

        # Bonus of +2 for every target box containing the current state.
        for target in self.target_spaces:
            contain = jnp.all(
                jnp.logical_and(state >= target.low, state <= target.high)
            )
            reward += 2 * reward_factor * jnp.float32(contain)
            # center = 0.5 * (target.low + target.high)
            # dist = jnp.sum(jnp.abs(center - next_state))
            # dist = jnp.clip(dist, 0, 1.0)
            # reward += 0.2 * (1 - dist) * jnp.float32(state[0] > 0.5)
            # reward += 0.8 * (1 - dist) * jnp.float32(state[1] > 0.5)

        # reward += 0.2 * jnp.float32(state[0] > 0.4)
        # reward += 0.2 * jnp.float32(state[1] > 0.5)
        # reward += 0.4 * jnp.float32(jnp.logical_and(state[0] < 0.5, state[1] > 0.5))

        # reward -= 2 * jnp.mean(jnp.abs(next_state / self.observation_space.high))
        done = step >= 200
        next_packed = jnp.array([step + 1, next_state[0], next_state[1]])
        return next_packed, next_state, reward, done

    @property
    def observation_dim(self):
        """Length of the first axis of the observation space shape."""
        return self.observation_space.shape[0]

    @property
    def action_dim(self):
        """Length of the first axis of the action space shape."""
        return self.action_space.shape[0]

    def integrate_noise(self, a: list, b: list):
        """Probability mass of the noise falling inside the boxes ``[a, b]``.

        ``a[i]`` / ``b[i]`` hold the lower / upper limits along dimension
        ``i`` (one entry per box).  Along each dimension the noise is modelled
        as a symmetric triangular distribution (``c=0.5``) supported on
        ``noise_bounds``; the per-dimension masses are multiplied together.

        Returns an array with ``a[0].shape[0]`` entries.
        """
        # Evaluate the property once instead of three times per dimension.
        lower, upper = self.noise_bounds
        pmass = np.ones(a[0].shape[0])
        for i in range(len(lower)):
            loc = lower[i]
            scale = upper[i] - lower[i]
            marginal_pmass = triang.cdf(b[i], c=0.5, loc=loc, scale=scale) - triang.cdf(
                a[i], c=0.5, loc=loc, scale=scale
            )
            pmass *= marginal_pmass
        return pmass

    @partial(jax.jit, static_argnums=(0,))
    def reset(self, rng):
        """Sample an initial state from one of ``self.init_spaces_train``.

        A box is picked uniformly at random, then a point is drawn uniformly
        inside it.

        Returns:
            ``(state, obs)`` with ``state = [0, obs[0], obs[1]]``.
        """
        # obs = jax.random.uniform(
        #     rng,
        #     shape=(self.observation_space.shape[0],),
        #     minval=self.observation_space.low,
        #     maxval=self.observation_space.high,
        # )
        lowers = jnp.stack([init.low for init in self.init_spaces_train], 0)
        high = jnp.stack([init.high for init in self.init_spaces_train], 0)
        # Independent keys for the box choice and the point inside the box.
        rng1, rng2 = jax.random.split(rng, 2)
        index = jax.random.randint(
            rng1, shape=(), minval=0, maxval=len(self.init_spaces_train)
        )
        obs = jax.random.uniform(
            rng2, shape=(lowers.shape[1],), minval=lowers[index], maxval=high[index]
        )
        state = jnp.array([0, obs[0], obs[1]])
        return state, obs


class vDroneEnv(gym.Env):
    """3-D point environment with jitted, vmappable JAX dynamics.

    Observations live in the cube ``[-1, 1]^3`` and actions are 3-D vectors
    clipped to ``[-1, 1]``.  The "packed" state consumed by :meth:`step` and
    produced by :meth:`reset` is ``[step_counter, x, y, z]``.

    ``v_next``, ``v_step`` and ``v_reset`` are batched (``jax.vmap``) versions
    of :meth:`next`, :meth:`step` and :meth:`reset`.
    """

    name = "vdrone"

    def __init__(self):
        self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)
        self.observation_space = spaces.Box(
            low=-np.ones(3, dtype=np.float32),
            high=np.ones(3, dtype=np.float32),
            dtype=np.float32,
        )
        # was 0.05 before
        # self.noise = np.array([0.02, 0.02, 0.02])  # was 0.02 before
        self.noise = np.array([0.002, 0.002, 0.002])  # was 0.02 before
        # safe = np.array([0.2, 0.2, 0.2], np.float32)  # was 0.1 before
        safe = np.array([0.1, 0.1, 0.1], np.float32)  # was 0.1 before
        # Target: cube of half-width ``safe`` centered at the origin.
        self.target_spaces = [spaces.Box(low=-safe, high=safe, dtype=np.float32)]

        self.init_spaces = [
            spaces.Box(
                low=np.array([-0.8, -0.2, -0.5]),
                high=np.array([-0.6, 0.2, 0.5]),
                dtype=np.float32,
            ),
            spaces.Box(
                low=np.array([0.6, -0.2, -0.5]),
                high=np.array([0.8, 0.2, 0.5]),
                dtype=np.float32,
            ),
        ]
        # reset() samples from these boxes.  NOTE(review): they are built by
        # rsm_utils.make_unsafe_spaces, whose semantics are not visible
        # here -- confirm what region it produces.
        self.init_spaces_train = make_unsafe_spaces(
            self.observation_space, np.array([0.1, 0.1, 0.1])
        )
        # Two unsafe boxes at the +y and -y faces of the cube.
        self.unsafe_spaces = []
        self.unsafe_spaces.append(
            spaces.Box(
                low=np.array([-0.3, 0.7, -0.3]),
                high=np.array([0.3, 1.0, 0.3]),
                dtype=np.float32,
            )
        )
        self.unsafe_spaces.append(
            spaces.Box(
                low=np.array([-0.3, -1.0, -0.3]),
                high=np.array([0.3, -0.7, 0.3]),
                dtype=np.float32,
            )
        )
        self._jax_rng = jax.random.PRNGKey(777)
        self.v_next = jax.vmap(self.next, in_axes=(0, 0), out_axes=0)
        self.v_step = jax.jit(jax.vmap(self.step))
        self.v_reset = jax.jit(jax.vmap(self.reset))

    @property
    def noise_bounds(self):
        """Return ``(-self.noise, self.noise)`` as (lower, upper) bounds."""
        return -self.noise, self.noise

    @partial(jax.jit, static_argnums=(0,))
    def next(self, state, action):
        """Noise-free transition from ``state`` under ``action``.

        The action is clipped to ``[-1, 1]``, the state advances by
        ``0.2 * action`` and is then clipped to the observation space.
        """
        action = jnp.clip(action, -1, 1)

        # obstacle1 = jnp.array((0, 1, 0))
        # force1 = jnp.array((0, 1, 0))
        # dist1 = jnp.linalg.norm(obstacle1 - state)
        # dist1 = jnp.clip(dist1 / 0.3, 0, 1)
        # action = action * dist1 + (1 - dist1) * force1
        #
        # obstacle2 = jnp.array((0, -1, 0))
        # force2 = jnp.array((0, -1, 0))
        # dist2 = jnp.linalg.norm(obstacle2 - state)
        # dist2 = jnp.clip(dist2 / 0.3, 0, 1)
        # action = action * dist2 + (1 - dist2) * force2

        state = state + action * 0.2
        state = jnp.clip(state, self.observation_space.low, self.observation_space.high)

        return state

    def add_noise(self, state):
        """Return ``state`` plus noise scaled by ``self.noise``.

        Stateful: advances the internal ``self._jax_rng`` key on every call.
        The noise sample comes from ``rsm_utils.triangular`` (presumably a
        unit-scale triangular sample -- confirm against rsm_utils).
        """
        self._jax_rng, rng = jax.random.split(self._jax_rng, 2)
        noise = triangular(rng, (self.observation_space.shape[0],))
        noise = noise * self.noise
        return state + noise

    @partial(jax.jit, static_argnums=(0,))
    def step(self, state, action, rng):
        """Advance one (noisy) step and compute the reward.

        Args:
            state: packed state ``[step_counter, x, y, z]``.
            action: 3-D action (clipped inside :meth:`next`).
            rng: JAX PRNG key for the transition noise.

        Returns:
            ``(next_packed, next_state, reward, done)`` where ``next_packed``
            is ``[step_counter + 1, x', y', z']``, ``next_state`` is the new
            3-D position and ``done`` is ``step_counter >= 200``.

        Containment tests use the *current* state, while the distance-based
        terms use the *next* state.
        """
        step = state[0]
        state = state[1:4]
        next_state = self.next(state, action)

        noise = triangular(rng, (self.observation_space.shape[0],))
        noise = noise * self.noise
        next_state = next_state + noise
        # jnp.clip (not np.clip): next_state is a traced value under jit.
        next_state = jnp.clip(
            next_state, self.observation_space.low, self.observation_space.high
        )

        reward = 0
        for unsafe in self.unsafe_spaces:
            contain = jnp.all(
                jnp.logical_and(state >= unsafe.low, state <= unsafe.high)
            )
            # -1 when the current state is inside the unsafe box ...
            reward += -jnp.float32(contain)
            # ... plus a penalty that grows as the next state gets within an
            # L1 distance of 0.5 from the box center.
            center = 0.5 * (unsafe.low + unsafe.high)
            dist = jnp.sum(jnp.abs(center - next_state))
            dist = jnp.clip(dist, 0, 0.5)
            reward -= 1 * (0.5 - dist)

        # +1 for every target box containing the current state.
        for target in self.target_spaces:
            contain = jnp.all(
                jnp.logical_and(state >= target.low, state <= target.high)
            )
            reward += jnp.float32(contain)

        # Penalize the mean absolute next state, normalized by the upper
        # observation bound.
        reward -= 2 * jnp.mean(jnp.abs(next_state / self.observation_space.high))
        done = step >= 200
        next_packed = jnp.array([step + 1, next_state[0], next_state[1], next_state[2]])
        return next_packed, next_state, reward, done

    @property
    def observation_dim(self):
        """Length of the first axis of the observation space shape."""
        return self.observation_space.shape[0]

    @property
    def lipschitz_constant(self):
        """Hard-coded constant 1.2 (derivation not shown in this file)."""
        return 1.2

    @property
    def lipschitz_constant_linf(self):
        """Hard-coded constant 1.2 (derivation not shown in this file)."""
        return 1.2

    @property
    def delta(self):
        """Return ``0.1`` plus the first component of ``self.noise``."""
        return 0.1 + self.noise[0]

    @property
    def action_dim(self):
        """Length of the first axis of the action space shape."""
        return self.action_space.shape[0]

    def integrate_noise(self, a: list, b: list):
        """Probability mass of the noise falling inside the boxes ``[a, b]``.

        ``a[i]`` / ``b[i]`` hold the lower / upper limits along dimension
        ``i`` (one entry per box).  Along each dimension the noise is modelled
        as a symmetric triangular distribution (``c=0.5``) supported on
        ``noise_bounds``; the per-dimension masses are multiplied together.

        Returns an array with ``a[0].shape[0]`` entries.
        """
        # Evaluate the property once instead of three times per dimension.
        lower, upper = self.noise_bounds
        pmass = np.ones(a[0].shape[0])
        for i in range(len(lower)):
            loc = lower[i]
            scale = upper[i] - lower[i]
            marginal_pmass = triang.cdf(b[i], c=0.5, loc=loc, scale=scale) - triang.cdf(
                a[i], c=0.5, loc=loc, scale=scale
            )
            pmass *= marginal_pmass
        return pmass

    @partial(jax.jit, static_argnums=(0,))
    def reset(self, rng):
        """Sample an initial state from one of ``self.init_spaces_train``.

        A box is picked uniformly at random, then a point is drawn uniformly
        inside it.

        Returns:
            ``(state, obs)`` with ``state = [0, obs[0], obs[1], obs[2]]``.
        """
        # obs = jax.random.uniform(
        #     rng,
        #     shape=(self.observation_space.shape[0],),
        #     minval=self.observation_space.low,
        #     maxval=self.observation_space.high,
        # )
        lowers = jnp.stack([init.low for init in self.init_spaces_train], 0)
        high = jnp.stack([init.high for init in self.init_spaces_train], 0)
        # Independent keys for the box choice and the point inside the box.
        rng1, rng2 = jax.random.split(rng, 2)
        index = jax.random.randint(
            rng1, shape=(), minval=0, maxval=len(self.init_spaces_train)
        )
        obs = jax.random.uniform(
            rng2, shape=(lowers.shape[1],), minval=lowers[index], maxval=high[index]
        )
        state = jnp.array([0, obs[0], obs[1], obs[2]])
        return state, obs


if __name__ == "__main__":
    # Script entry point: evaluate the vMazeEnv reward for a zero action on a
    # regular N x N grid of states and save it as a scatter "heat map" to
    # reward_map.png.

    # Disabled smoke test for the batched reset/step API:
    # env = vLDSEnv()
    # env = vHumanoidBalance2()
    # rngs = jax.random.split(jax.random.PRNGKey(1), 5)
    # init, obs = env.v_reset(rngs)
    # print("init=", init)
    # print("obs[0]=", obs)
    # rngs = jax.random.split(jax.random.PRNGKey(2), 5)
    # action = jax.random.uniform(
    #     jax.random.PRNGKey(3), shape=(5, 2), minval=-1, maxval=1
    # )
    # # action = jnp.ones((5, 2))
    # state, obs, reward, done = env.v_step(init, action, rngs)
    # print("next=", state)
    # print("obs[1]=", obs)

    import seaborn as sns

    env = vMazeEnv()
    N = 100
    x1 = jnp.linspace(env.observation_space.low[0], env.observation_space.high[0], N)
    x2 = jnp.linspace(env.observation_space.low[1], env.observation_space.high[1], N)

    # "ij" indexing followed by ravel() enumerates the grid with x1 as the
    # outer and x2 as the inner index.
    mesh_x1, mesh_x2 = jnp.meshgrid(x1, x2, indexing="ij")
    flat_x1 = mesh_x1.ravel()
    flat_x2 = mesh_x2.ravel()
    n_points = flat_x1.shape[0]

    # Packed states [step_counter=0, x1, x2], zero actions and the same PRNG
    # key for every grid point.
    states = jnp.stack([jnp.zeros(n_points), flat_x1, flat_x2], axis=1)
    actions = jnp.zeros((n_points, env.action_space.shape[0]))
    key = jax.random.PRNGKey(0)
    rngs = jnp.broadcast_to(key, (n_points,) + key.shape)

    # One batched call instead of N * N individual jitted step() calls.
    _, _, rewards, _ = env.v_step(states, actions, rngs)

    grid_x1 = np.asarray(flat_x1)
    grid_x2 = np.asarray(flat_x2)
    grid_y = np.asarray(rewards)

    sns.set()
    fig, ax = plt.subplots(figsize=(6, 6))
    sc = ax.scatter(grid_x1, grid_x2, marker="s", c=grid_y)

    # Disabled overlays: outline reward boxes (yellow), unsafe boxes (red)
    # and target boxes (green) on top of the reward map.
    # def outline(box, color):
    #     x = [box.low[0], box.high[0], box.high[0], box.low[0], box.low[0]]
    #     y = [box.low[1], box.low[1], box.high[1], box.high[1], box.low[1]]
    #     ax.plot(x, y, color=color, alpha=0.5, zorder=7)
    #
    # if hasattr(env, '_reward_boxes'):
    #     for box, rs in env._reward_boxes:
    #         outline(box, "yellow")
    # for unsafe in env.unsafe_spaces:
    #     outline(unsafe, "red")
    # for target_space in env.target_spaces:
    #     outline(target_space, "green")

    ax.set_xlim([env.observation_space.low[0], env.observation_space.high[0]])
    ax.set_ylim([env.observation_space.low[1], env.observation_space.high[1]])
    fig.tight_layout()
    fig.savefig("reward_map.png")
    plt.close(fig)