import numpy as np
from env.robosuite.robosuite import make
from env.robosuite.robosuite import load_controller_config
from env.collector.gym_wrapper import GymStackWrapper, GymLiftWrapper, GymLiftCausalWrapper
import gym

# Gym wrapper class to apply to the raw environment, keyed by task name.
# LiftEnv looks the class up with the `task` string and calls it on the
# environment returned by `make`.
# NOTE(review): "LiftCausal" maps to GymLiftWrapper while GymLiftCausalWrapper
# is registered under "CausalPick" -- confirm this pairing is intended.
WRAPPER = {
    "LiftCausal": GymLiftWrapper,
    "StackCausal": GymStackWrapper,
    "CausalPick": GymLiftCausalWrapper,
}

class LiftEnv(gym.Env):
    """Gym environment around a wrapped robosuite task with a sparse reward.

    The underlying environment is created with ``make`` for a 'Kinova3' robot
    using the 'OSC_POSITION' controller, then wrapped with the class that
    ``WRAPPER`` registers for ``task``. ``step`` discards the wrapped
    environment's reward and returns a binary one derived from a single
    observation entry (see ``step``).

    Attributes:
        controller: Name of the controller configuration ('OSC_POSITION').
        env: The wrapped environment that all calls are forwarded to.
        action_dim: Size of the action vector (4 for 'OSC_POSITION').
        state_dim: Size of the observation vector declared in
            ``observation_space`` (33).
        target_position: Array ``[0, 0, 1]``; stored but not used by any
            method of this class.
        observation_space: Unbounded float32 box of shape ``(state_dim,)``.
        action_space: Float32 box in ``[-5, 5]`` of shape ``(action_dim,)``.
    """

    # Value of the `spurious_type` keyword passed to `make` for each test mode.
    _SPURIOUS_TYPE_BY_MODE = {'IID': 'xnr', 'OOD': 'xpr'}

    # Observation entry that `step` treats as a height
    # (presumably the object's z-coordinate -- TODO confirm against the wrapper).
    _HEIGHT_INDEX = 21
    # Reference height subtracted from the observed height.
    _BASE_HEIGHT = 0.8
    # Minimum height above the reference that yields a reward of 1.
    _SUCCESS_LIFT = 0.1
    # Observed height above which `step` prints a diagnostic line.
    _REPORT_HEIGHT = 0.88

    def __init__(self, test_mode='IID', stage='train', task="LiftCausal", horizon=200, control_freq=5, seed=100):
        """Create, wrap and seed the underlying environment.

        Args:
            test_mode: 'IID' or 'OOD'; selects the `spurious_type` passed to
                ``make`` ('xnr' or 'xpr' respectively).
            stage: Accepted for interface compatibility; not used here.
            task: Task name passed to ``make``; must also be a key of
                ``WRAPPER``.
            horizon: Forwarded to ``make`` as ``horizon``.
            control_freq: Forwarded to ``make`` as ``control_freq``.
            seed: Seed forwarded to the wrapped environment's ``seed``.

        Raises:
            ValueError: If ``test_mode`` is neither 'IID' nor 'OOD'.
            KeyError: If ``task`` has no entry in ``WRAPPER``.
        """
        # Explicit check instead of `assert`: assertions are removed under
        # `python -O`, which would leave `spurious_type` undefined.
        if test_mode not in self._SPURIOUS_TYPE_BY_MODE:
            raise ValueError('test_mode must be IID or OOD')
        spurious_type = self._SPURIOUS_TYPE_BY_MODE[test_mode]

        self.controller = 'OSC_POSITION'
        env = make(
            task,
            'Kinova3',
            horizon=horizon,
            control_freq=control_freq,
            has_renderer=False,
            has_offscreen_renderer=False,
            ignore_done=False,
            use_camera_obs=False,
            use_object_obs=True,
            controller_configs=load_controller_config(default_controller=self.controller),
            spurious_type=spurious_type,
        )
        self.env = WRAPPER[task](env)
        self.action_dim = 4 if self.controller == 'OSC_POSITION' else 7
        self.state_dim = 33

        self.env.seed(seed)
        self.target_position = np.array([0, 0, 1])  # Target position to lift the object to
        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.state_dim,), dtype=np.float32)
        self.action_space = gym.spaces.Box(low=-5., high=5., shape=(self.action_dim,), dtype=np.float32)

    def random_action(self):
        """Sample a random 4-element action.

        The first three entries are drawn uniformly from [-1, 1); the last
        entry is an integer drawn from {-1, 0, 1}. Sampling uses NumPy's
        global random state.

        Returns:
            A 1-D array of length 4.
        """
        actions_xyz = np.random.uniform(-1., 1., size=(3,))
        # randint(0, 3) yields 0, 1 or 2; shifting by one gives -1, 0 or 1.
        actions_gripper = np.random.randint(0, 3, size=(1,)) - 1
        return np.concatenate([actions_xyz, actions_gripper], axis=0)

    def reset(self):
        """Reset the wrapped environment and return its observation."""
        return self.env.reset()

    def step(self, action):
        """Advance the wrapped environment and compute the sparse reward.

        The reward returned by the wrapped environment is discarded. The
        returned reward is 1.0 when the height entry of the observation is at
        least ``_SUCCESS_LIFT`` above ``_BASE_HEIGHT``, and 0.0 otherwise.

        Args:
            action: Action forwarded unchanged to the wrapped environment.

        Returns:
            Tuple ``(obs, reward, done, info)`` where ``obs``, ``done`` and
            ``info`` come from the wrapped environment.
        """
        obs, _, done, info = self.env.step(action)
        height = obs[self._HEIGHT_INDEX]
        # Diagnostic output once the height gets close to the success level.
        if height > self._REPORT_HEIGHT:
            print('height: ', height - self._BASE_HEIGHT)
        adjusted_reward = 1. if height - self._BASE_HEIGHT >= self._SUCCESS_LIFT else 0.
        return obs, adjusted_reward, done, info

    def render(self):
        """Render via the wrapped environment and return whatever it returns."""
        return self.env.render()

    def close(self):
        """Close the wrapped environment."""
        self.env.close()

    def seed(self, seed):
        """Seed the wrapped environment and return whatever it returns.

        Only the wrapped environment's ``seed`` is called; this method does
        not touch NumPy's global random state used by ``random_action``.
        """
        return self.env.seed(seed)
