from algorithms.mbpo import MBPO
from environments import make_env
from setup import AttrDict, parse_arguments, set_seed, set_device, setup_logger


def get_config():
    """Assemble the default MBPO experiment configuration.

    The defaults are listed in three ordered groups (general experiment
    settings, SAC settings, MBPO settings). They are copied onto a fresh
    ``AttrDict`` in that order, and the populated object is handed to
    ``parse_arguments``.

    Returns:
        Whatever ``parse_arguments`` returns for the populated config.
    """
    general_defaults = {
        "algo": "mbpo",
        "env_id": "HalfCheetah-v2",
        "expr_name": "default",
        "seed": 0,
        "use_gpu": True,
        "pixel_obs": False,
    }

    # SAC
    sac_defaults = {
        "num_steps": 1_000_000,
        "start_step": 10_000,
        "lr": 3e-4,
        "gamma": 0.99,
        "tau": 0.005,
        "alpha": 0.2,
        "automatic_entropy_tuning": False,
        "target_entropy": "auto",
        "replay_size": 1_000_000,
        "batch_size": 256,
        "hidden_size": 256,
        "updates_per_step": 1,
        "target_update_freq": 1,
        "eval_freq": 10,
        "num_eval_episodes": 10,
    }

    # MBPO
    mbpo_defaults = {
        "num_epochs": 1000,
        "epoch_length": 1000,
        "min_buffer_size": 1000,
        "init_exploration_steps": 5000,
        "ensemble_size": 7,
        "model_hidden_size": 200,
        "model_lr": 1e-3,
        "model_wd": 5e-5,
        "model_retain_epochs": 1,
        "model_train_freq": 250,
        "model_train_epochs": 1,
        "model_batch_size": 256,
        "rollout_batch_size": 100_000,
        "rollout_min_epoch": 20,
        "rollout_max_epoch": 150,
        "rollout_min_length": 1,
        "rollout_max_length": 15,
        "real_ratio": 0.05,
        "policy_train_freq": 1,
        "num_train_repeats": 20,
        "max_train_repeats_per_step": 5,
        "normalize": True,
    }

    config = AttrDict()
    # setattr goes through the same __setattr__ as ``config.name = value``,
    # and dict iteration keeps the insertion order written above.
    for section in (general_defaults, sac_defaults, mbpo_defaults):
        for name, default in section.items():
            setattr(config, name, default)
    return parse_arguments(config)


def main():
    """Run one MBPO training experiment from the command-line configuration.

    The steps run in this order: build the config, seed, select the device,
    create the logger, create the environment, then construct the MBPO
    agent and call its ``train`` method.
    """
    config = get_config()
    set_seed(config.seed)
    set_device(config.use_gpu)

    # Logger is created before the environment, matching the original order.
    logger = setup_logger(config)

    # Environment
    env = make_env(config.env_id, config.seed, config.pixel_obs)

    # Agent: build and train in one step; the instance is not reused afterwards.
    MBPO(config, env, logger).train()


if __name__ == "__main__":
    main()
