from gym import register
from gym.envs.mujoco.hopper import HopperEnv
import numpy as np
from .safe_env_spec import SafeEnv, interval_barrier, nonneg_barrier
import torch


class SafeHopperEnv(HopperEnv, SafeEnv):
    """Hopper environment augmented with a safety signal and a barrier function.

    Observations are ``qpos[1:]`` followed by ``qvel`` (see ``_get_obs``).  The
    safety helpers index that layout directly: entry 0 is used as the height,
    entry 1 as the angle, and entries 2 onward are bounded in magnitude by 100.
    """

    def reset_model(self):
        """Reset the simulator to ``init_qpos``/``init_qvel`` and return the observation.

        The initial state is applied as-is; this method adds no perturbation.
        """
        self.set_state(self.init_qpos, self.init_qvel)
        return self._get_obs()

    def step(self, a):
        """Advance one step, treating termination by the parent env as unsafe.

        When the parent ``step`` reports ``done``, ``info['episode.unsafe']`` is
        set to 1 and the reward is replaced by -100.  Otherwise the parent's
        results are passed through untouched (the key is not added to ``info``).

        Returns:
            The ``(next_state, reward, done, info)`` tuple.
        """
        next_state, reward, done, info = super().step(a)
        if done:
            # NOTE(review): the parent's termination test may use different
            # thresholds than `is_state_safe` (e.g. a tighter angle bound) --
            # confirm the two definitions of "unsafe" are meant to differ.
            info['episode.unsafe'] = 1
            reward = -100
        return next_state, reward, done, info

    def _get_obs(self):
        """Return the observation: ``qpos`` without its first entry, then ``qvel``."""
        return np.concatenate([
            self.sim.data.qpos.flat[1:],
            self.sim.data.qvel.flat,
        ])

    def is_state_safe(self, states):
        """Return a boolean mask that is True where a state is safe.

        A state is safe when its height exceeds 0.7, its angle lies strictly
        within (-0.3, 0.3), and every remaining entry has magnitude below 100.
        These are the same bounds that ``barrier_fn`` encodes.

        Args:
            states: tensor whose last dimension follows the ``_get_obs`` layout
                (must support ``.abs()`` and ``.all(dim=-1)``; presumably a
                ``torch.Tensor`` -- confirm against callers).

        Returns:
            Boolean tensor with the last dimension reduced away.
        """
        height = states[..., 0]
        angle = states[..., 1]
        rest_bounded = (states[..., 2:].abs() < 100).all(dim=-1)
        # The previous implementation negated this mask, so it answered True
        # for states *outside* the bounds -- the opposite of the method name.
        return rest_bounded & (height > 0.7) & (angle.abs() < 0.3)

    def barrier_fn(self, states):
        """Return the elementwise maximum of the three barrier terms.

        The terms are: an interval barrier on the angle over [-0.3, 0.3], a
        non-negativity barrier on ``height - 0.7``, and an interval barrier over
        [-100, 100] on the remaining entries averaged over the last dimension.

        Args:
            states: tensor whose last dimension follows the ``_get_obs`` layout.

        Returns:
            Tensor with the last dimension reduced away.
        """
        height = states[..., 0]
        angle = states[..., 1]
        barriers = [
            interval_barrier(angle, -0.3, 0.3),
            nonneg_barrier(height - 0.7),
            # Averaged (not max-reduced) across the remaining coordinates, so a
            # single out-of-range entry is diluted by the in-range ones.
            interval_barrier(states[..., 2:], -100, 100).mean(dim=-1),
        ]
        return barriers[0].max(barriers[1]).max(barriers[2])


# Make the environment constructible by id, with episodes capped at 1000 steps.
register(
    'SafeHopper-v2',
    entry_point=SafeHopperEnv,
    max_episode_steps=1000,
)
