from algorithms.sac import SAC
from algorithms.sac_mt import MultitaskSAC
from environments import make_env, make_multitask_env
from setup import AttrDict, parse_arguments, set_seed, set_device, setup_logger


def get_config():
    """Build the default experiment configuration and pass it to ``parse_arguments``.

    Every default listed below is copied onto a fresh ``AttrDict`` as an
    attribute, in the order listed. The config is then handed to
    ``parse_arguments``, and whatever that returns is returned here.
    ``parse_arguments`` presumably applies command-line overrides (see
    ``setup``); that is not verified from this file.
    """
    defaults = {
        # Run / experiment settings.
        "algo": "sac",
        "env_id": "HalfCheetah-v2",
        "expr_name": "default",
        "seed": 0,
        "use_gpu": True,
        "pixel_obs": False,
        # SAC hyperparameters.
        "num_steps": 1_000_000,
        "start_step": 10_000,
        "lr": 3e-4,
        "gamma": 0.99,
        "tau": 0.005,
        "alpha": 0.2,
        "automatic_entropy_tuning": False,
        "target_entropy": "auto",
        "replay_size": 1_000_000,
        "batch_size": 256,
        "hidden_size": 256,
        "repr_size": 256,
        "updates_per_step": 1,
        "target_update_freq": 1,
        "eval_freq": 10,
        "num_eval_episodes": 10,
        "checkpoint_freq": 10,
    }

    cfg = AttrDict()
    # setattr(cfg, k, v) is exactly what `cfg.k = v` does, so insertion
    # order and AttrDict semantics match an explicit assignment list.
    for name, value in defaults.items():
        setattr(cfg, name, value)
    return parse_arguments(cfg)


def main():
    """Build the configured SAC variant and train it.

    Supported ``config.algo`` values:
      * ``"sac"``: single-task ``SAC`` on ``make_env(config.env_id, ...)``.
      * ``"sac_multitask"``: ``MultitaskSAC`` with a separate evaluation
        environment. Its id is ``config.env_id`` with every "train" replaced
        by "test", and its seed is the training seed plus 1000.

    Raises:
        ValueError: if ``config.algo`` is not one of the supported values.
    """
    config = get_config()

    # Reject an unknown algorithm before seeding, device selection or
    # logger setup run, so a bad --algo fails fast with no side effects.
    if config.algo not in ("sac", "sac_multitask"):
        raise ValueError(f"Unsupported algorithm: {config.algo!r}")

    set_seed(config.seed)
    set_device(config.use_gpu)

    # Logger
    logger = setup_logger(config)

    # Environment(s) and agent, built per algorithm.
    if config.algo == "sac":
        env = make_env(config.env_id, config.seed, config.pixel_obs)
        algo = SAC(config, env, logger)
    else:  # "sac_multitask" (validated above)
        env = make_multitask_env(config.env_id, config.seed, config.pixel_obs)
        # NOTE(review): assumes env ids encode the split as "train"/"test";
        # if "train" is absent, evaluation silently reuses the training id.
        eval_env_id = config.env_id.replace("train", "test")
        eval_seed = config.seed + 1000  # distinct seed for the eval env
        eval_env = make_multitask_env(eval_env_id, eval_seed, config.pixel_obs)
        algo = MultitaskSAC(config, env, eval_env, logger)

    algo.train()


if __name__ == "__main__":
    main()
