import numpy as np
from gym.wrappers import RescaleAction


def get_env(domain):
    """Build the environment registered under ``domain`` and wrap it.

    The chosen environment class is instantiated with no arguments. If its
    action space bounds are not already exactly -1 and 1 everywhere, it is
    wrapped in gym's ``RescaleAction`` to map actions to [-1, 1]. The result
    is then wrapped in the project's ``TorchWrapper``.

    Args:
        domain: Name of the environment. Must be one of ``'hopper'``,
            ``'cheetah-no-flip'``, ``'navigation1'`` or ``'navigation2'``.

    Returns:
        The ``TorchWrapper``-wrapped environment.

    Raises:
        KeyError: If ``domain`` is not a supported name. The message lists
            the supported names. ``KeyError`` is kept so existing callers
            that catch it keep working.
    """
    # NOTE(review): these imports are function-local, presumably so the
    # environment modules are only loaded when an env is actually built —
    # confirm before hoisting them to module level.
    from .env.torch_wrapper import TorchWrapper
    from .env.hopper_no_bonus import HopperNoBonusEnv
    from .env.cheetah_no_flip import CheetahNoFlipEnv
    from .env.navigation1 import Navigation1
    from .env.navigation2 import Navigation2
    envs = {
        'hopper': HopperNoBonusEnv,
        'cheetah-no-flip': CheetahNoFlipEnv,
        'navigation1': Navigation1,
        'navigation2': Navigation2,
    }
    try:
        env_cls = envs[domain]
    except KeyError:
        # Replace the bare "KeyError: 'xyz'" with an actionable message.
        raise KeyError(
            f"Unknown domain {domain!r}; expected one of {sorted(envs)}"
        ) from None
    env = env_cls()
    action_space = env.action_space
    # Only add the rescaling wrapper when the bounds are not already [-1, 1];
    # an already-normalized action space is passed through unchanged.
    if not (np.all(action_space.low == -1.0) and np.all(action_space.high == 1.0)):
        env = RescaleAction(env, -1.0, 1.0)
    return TorchWrapper(env)