from copy import deepcopy
import os
import numpy as np

import hydra
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm

from buffer import rollin_rollout
from control.policies import MaxFollowingPolicy
from training import train_v
from recorder import Recorder

import torch

# fmt: off
import logging
# Module-wide INFO logging. This is configured before the composuite/d3rlpy
# imports below; NOTE(review): the ordering (and the fmt markers guarding it)
# looks deliberate, presumably so basicConfig runs before those libraries touch
# logging — confirm before reordering.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# fmt: on

import composuite
import d3rlpy

# Keyword arguments forwarded to ``composuite.make`` when the environment is
# built in ``main``. ``env_horizon`` is the per-episode step limit the env is
# created with.
GLOBAL_SUBTASK_KWARGS = dict(
    has_renderer=False,
    has_offscreen_renderer=False,
    reward_shaping=True,
    use_camera_obs=False,
    use_task_id_obs=True,
    env_horizon=500,
)

# Place networks on the GPU whenever torch can see one, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def seed_all(seed: int):
    """Seed Python's, NumPy's and PyTorch's (CPU and every CUDA device) RNGs.

    Also puts cuDNN in deterministic mode and disables its autotuner
    (``benchmark``), trading speed for run-to-run reproducibility.

    Args:
        seed: Seed applied to every generator.
    """
    # Local import: the module header does not import ``random``.
    import random

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Previously missing: the stdlib generator was left unseeded despite the
    # function's name/log message claiming "all" generators are seeded.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    logger.info(f"Seeded all random number generators with seed {seed}")


def create_training_components(
    num_policies,
    obs_shape,
    save_path,
    buffer_cfg,
    buffer_size,
    vf_cfg,
    optimizer_cfg,
):
    """Build one replay buffer, value network and optimizer per policy.

    Every object comes from ``hydra.utils.instantiate``:

    * buffer ``i`` gets ``size=buffer_size``, the save location
      ``f"{save_path}/buffer_{i}.npz"`` and ``observation_dim=obs_shape[0]``;
    * network ``i`` gets ``input_dim=obs_shape[0]`` and is moved to ``DEVICE``;
    * optimizer ``i`` is created over network ``i``'s parameters.

    Returns:
        ``(buffers, networks, optimizers)`` — three parallel lists of length
        ``num_policies``.
    """
    buffers, networks, optimizers = [], [], []
    for idx in range(num_policies):
        buffers.append(
            hydra.utils.instantiate(
                buffer_cfg,
                size=buffer_size,
                save_path=f"{save_path}/buffer_{idx}.npz",
                observation_dim=obs_shape[0],
            )
        )
        logger.info(f"Created buffer for policy {idx}")

        value_net = hydra.utils.instantiate(vf_cfg, input_dim=obs_shape[0]).to(DEVICE)
        networks.append(value_net)
        logger.info(f"Created value function for policy {idx}")

        optimizers.append(
            hydra.utils.instantiate(optimizer_cfg, params=value_net.parameters())
        )
        logger.info(f"Created optimizer for policy {idx}")

    return buffers, networks, optimizers


def create_d3rlpy_policies(env, base_path, num_samples, algo, model_nr, env_configs):
    """Load one pretrained d3rlpy policy per entry of ``env_configs``.

    Entry ``i`` is formed from ``env_configs["robot"][i]``, ``["obj"][i]``,
    ``["obstacle"][i]`` and ``["objective"][i]``. Its checkpoint is read from
    ``base_path/<algo>/nth_<model_nr>/<num_samples>/``
    ``<robot>_<obj>_<obstacle>_<objective>/IQL/model_300000.pt``.

    Args:
        env: Environment passed to each model's ``build_with_env``.
        base_path: Root directory of the checkpoints.
        num_samples: Path component (presumably the offline dataset size).
        algo: One of ``"iql"``, ``"cql"`` or ``"bc"``.
        model_nr: Used in the ``nth_<model_nr>`` path component.
        env_configs: Mapping with list-valued keys ``robot``, ``obj``,
            ``obstacle`` and ``objective``. One policy is created per
            ``robot`` entry.

    Returns:
        List of callables, each mapping a single (unbatched) observation to the
        corresponding model's action for it.

    Raises:
        ValueError: If ``algo`` is unsupported. This is checked before any path
            is built or checkpoint is loaded.
    """
    algo_class_names = {"iql": "IQL", "cql": "CQL", "bc": "BC"}
    # Validate up front: previously an unknown algo was only detected inside the
    # load loop, and silently accepted when ``env_configs`` was empty.
    if algo not in algo_class_names:
        raise ValueError(f"Algo {algo} not supported")
    algo_cls = getattr(d3rlpy.algos, algo_class_names[algo])

    robots = env_configs["robot"]
    paths = []
    for i, robot in enumerate(robots):
        obj = env_configs["obj"][i]
        obstacle = env_configs["obstacle"][i]
        objective = env_configs["objective"][i]
        # NOTE(review): the trailing "IQL" directory is hard-coded for every
        # algo; presumably that is the checkpoint layout on disk — confirm.
        path = os.path.join(
            base_path,
            str(algo),
            f"nth_{model_nr}",
            str(num_samples),
            f"{robot}_{obj}_{obstacle}_{objective}",
            "IQL",
        )
        logger.info(
            f"Added model with robot {robot}, obj {obj}, obstacle {obstacle}, objective {objective}"
        )
        paths.append(os.path.join(path, "model_300000.pt"))

    policies = []
    for path in paths:
        model = algo_cls()
        model.build_with_env(env)
        model.load_model(path)
        # Bind ``model`` as a default argument so each closure keeps its own
        # model instead of late-binding to the last one. The model is freshly
        # constructed every iteration, so no defensive copy is required.
        policies.append(
            lambda x, model=model: model.predict(np.expand_dims(x, 0))[0]
        )

    assert len(policies) == len(robots)
    return policies


# Number of ``step`` calls since the last ``reset`` of an env patched by
# ``modified_reset``/``modified_step``. It is a single module-wide counter
# shared by every patched env, so interleaving two patched envs mixes counts.
GLOBAL_STEP_COUNTER: int = 0


def modified_reset(gym_env):
    """Monkey-patch ``gym_env.reset`` so that it returns only the observation.

    The replacement zeroes the module-level ``GLOBAL_STEP_COUNTER``, forwards
    all arguments to the original ``reset``, unpacks its two return values
    and hands back just the first (the observation), discarding the second.
    """
    wrapped_reset = gym_env.reset

    def _patched_reset(*args, **kwargs):
        global GLOBAL_STEP_COUNTER
        GLOBAL_STEP_COUNTER = 0

        observation, _discarded = wrapped_reset(*args, **kwargs)
        return observation

    gym_env.reset = _patched_reset


def modified_step(gym_env, horizon=None):
    """Monkey-patch ``gym_env.step`` to return ``(obs, reward, done, info)``.

    The wrapped ``step`` must return ``(obs, reward, terminated, truncated,
    info)``. Every call increments the module-level ``GLOBAL_STEP_COUNTER``
    (zeroed by ``modified_reset``). ``done`` is the env's termination flag.
    ``info["TimeLimit.truncated"]`` is set to ``True`` when the env itself
    reports truncation, or when the counter reaches a multiple of
    ``horizon``. Otherwise ``info`` is returned as the env produced it.

    Args:
        gym_env: Environment whose ``step`` is replaced in place.
        horizon: Step count at which an episode is flagged as truncated.
            Defaults to ``GLOBAL_SUBTASK_KWARGS["env_horizon"]`` (500), the
            horizon ``main`` builds the environment with.

    Raises:
        ValueError: If ``horizon`` is smaller than 1.
    """
    if horizon is None:
        horizon = GLOBAL_SUBTASK_KWARGS["env_horizon"]
    if horizon < 1:
        raise ValueError(f"horizon must be >= 1, got {horizon}")

    original_step = gym_env.step

    def step_wrapper(*args, **kwargs):
        global GLOBAL_STEP_COUNTER
        GLOBAL_STEP_COUNTER += 1

        obs, rew, done, truncated, info = original_step(*args, **kwargs)

        # Previously the env's own ``truncated`` flag was dropped, so a
        # truncation not aligned with the step counter was never reported.
        if truncated or GLOBAL_STEP_COUNTER % horizon == 0:
            info["TimeLimit.truncated"] = True

        return obs, rew, done, info

    gym_env.step = step_wrapper


# Object/objective pairings, one entry per ``cfg.load_idx``. Each entry is
# ``(source_pairs, target_pairs)`` of ``(obj, objective)`` tuples. Pretrained
# policies are loaded for ``source_pairs``, and the environment is built from
# ``target_pairs[cfg.env_idx]``. Each target list swaps the objectives of the
# two source pairs.
_PAIR_TO_ENTRIES = [
    (
        [("Hollowbox", "Trashcan"), ("Plate", "Push")],
        [("Hollowbox", "Push"), ("Plate", "Trashcan")],
    ),
    (
        [("Plate", "Trashcan"), ("Box", "Shelf")],
        [("Plate", "Shelf"), ("Box", "Trashcan")],
    ),
    (
        [("Dumbbell", "Shelf"), ("Hollowbox", "Push")],
        [("Dumbbell", "Push"), ("Hollowbox", "Shelf")],
    ),
    (
        [("Box", "PickPlace"), ("Dumbbell", "Push")],
        [("Box", "Push"), ("Dumbbell", "PickPlace")],
    ),
    (
        [("Box", "Push"), ("Hollowbox", "PickPlace")],
        [("Box", "PickPlace"), ("Hollowbox", "Push")],
    ),
    (
        [("Plate", "Shelf"), ("Dumbbell", "Trashcan")],
        [("Plate", "Trashcan"), ("Dumbbell", "Shelf")],
    ),
    (
        [("Dumbbell", "PickPlace"), ("Box", "Trashcan")],
        [("Dumbbell", "Trashcan"), ("Box", "PickPlace")],
    ),
    (
        [("Hollowbox", "Shelf"), ("Plate", "PickPlace")],
        [("Hollowbox", "PickPlace"), ("Plate", "Shelf")],
    ),
]


def _record_mfp_eval(rec, step, returns, policy_usage, success):
    """Record max-following-policy evaluation metrics at ``step``.

    Records mean/std/min/max of ``returns``, the success rate and the
    per-policy usage fractions. Returns ``(mean, std)`` of ``returns``.
    """
    reward_mean = np.mean(returns)
    reward_std = np.std(returns)
    rec.record(step, "eval/reward_mean", reward_mean)
    rec.record(step, "eval/reward_std", reward_std)
    rec.record(step, "eval/reward_min", np.min(returns))
    rec.record(step, "eval/reward_max", np.max(returns))
    rec.record(step, "eval/success", success)
    for i, entry in enumerate(policy_usage):
        rec.record(step, f"policy_{i}/usage", entry)
    return reward_mean, reward_std


@hydra.main(config_path="configs", config_name="main.yaml")
def main(cfg: DictConfig):
    """Train per-policy value functions and evaluate the max-following policy.

    Steps:

    1. Build the environment for ``_PAIR_TO_ENTRIES[cfg.load_idx]``'s target
       pair ``cfg.env_idx``.
    2. Load the pretrained offline policies for that entry's source pairs and
       evaluate each of them.
    3. For ``cfg.training.num_rounds`` rounds: collect data for every policy
       with ``rollin_rollout``, then train every value function with
       ``train_v``.
    4. Evaluate the max-following policy every ``cfg.eval.interval`` rounds,
       and checkpoint the value functions every ``cfg.training.save_interval``
       rounds.
    """
    logger.info(f"Starting run with config {cfg}")
    logger.info(f"Device: {DEVICE}")
    seed_all(cfg.seed)

    # make save path
    os.makedirs(cfg.save_path, exist_ok=True)

    load_idx = cfg.load_idx
    env_idx = cfg.env_idx
    source_pairs, target_pairs = _PAIR_TO_ENTRIES[load_idx]
    obj, objective = target_pairs[env_idx]

    env = composuite.make(
        cfg.env_config.robot,
        obj,
        cfg.env_config.obstacle,
        objective,
        **GLOBAL_SUBTASK_KWARGS,
    )
    modified_reset(env)
    modified_step(env)

    logger.info(
        f"Created environment with robot {cfg.env_config.robot}, obj {obj}, obstacle {cfg.env_config.obstacle}, objective {objective}"
    )

    # Load the policies trained on the *source* pairs: override obj/objective.
    data_env_configs = OmegaConf.to_container(cfg.data_env_configs, resolve=True)
    data_env_configs["obj"] = [src_obj for src_obj, _ in source_pairs]
    data_env_configs["objective"] = [src_goal for _, src_goal in source_pairs]

    policies = create_d3rlpy_policies(
        env,
        cfg.data_env_configs.base_path,
        cfg.offline.num_samples,
        cfg.offline.algo,
        cfg.offline.model_nr,
        data_env_configs,
    )
    num_policies = len(policies)

    rec_cfg = OmegaConf.to_container(cfg, resolve=True)
    rec = Recorder(
        cfg.log_path,
        cfg.wandb.project,
        cfg.wandb.entity,
        config=rec_cfg,
        group=cfg.wandb.group,
    )
    logger.info(f"Created recorder at {cfg.log_path}")

    horizon = cfg.training.horizon

    # Baseline: each pretrained policy evaluated on its own.
    rets, sucs = evaluate_policies(policies, env, cfg.eval.num_trajs, horizon)
    for i, (ret, suc) in enumerate(zip(rets, sucs)):
        logger.info(f"Initial eval reward for policy {i}: {np.mean(ret):.4f}")
        rec.record(0, f"policy_{i}/eval_reward", np.mean(ret))
        rec.record(0, f"policy_{i}/eval_reward_std", np.std(ret))
        rec.record(0, f"policy_{i}/eval_reward_min", np.min(ret))
        rec.record(0, f"policy_{i}/eval_reward_max", np.max(ret))
        rec.record(0, f"policy_{i}/eval_success", suc)

    num_rounds = cfg.training.num_rounds
    buffers, value_functions, optimizers = create_training_components(
        num_policies,
        env.observation_space.shape,
        cfg.save_path,
        cfg.buffer,
        num_rounds * cfg.training.num_episodes_per_round * horizon,
        cfg.value_function,
        cfg.optimizer,
    )
    mfp = MaxFollowingPolicy(policies, value_functions)
    logger.info("Created max following policy, done with training components")

    # Prefill every buffer to at least one batch. The buffer returned by
    # ``rollin_rollout`` is stored back into ``buffers``: the old code only
    # rebound a loop-local name, which left ``buffers`` stale if a new object
    # was returned instead of the input being mutated in place.
    for i, policy in enumerate(policies):
        while len(buffers[i]) < cfg.buffer.batch_size:
            buffers[i], _, _ = rollin_rollout(
                env=env,
                policy=policy,
                buffer=buffers[i],
                horizon=horizon,
                num_episodes=1,
            )

    returns, policy_usage, success = evaluate_policy(
        env,
        mfp,
        cfg.eval.num_trajs,
        horizon,
    )
    _record_mfp_eval(rec, 0, returns, policy_usage, success)

    # ``rnd`` rather than ``round`` to avoid shadowing the builtin.
    for rnd in tqdm(range(1, num_rounds + 1), desc="Training round"):
        logger.info(f"Starting round {rnd}")
        rec.record(rnd, "counting/updates_done", rnd * cfg.training.num_updates)
        rec.record(
            rnd,
            "counting/episodes_collected",
            rnd * cfg.training.num_episodes_per_round,
        )
        is_eval_round = rnd % cfg.eval.interval == 0

        # Round-invariant rollout settings. ``init_horizon`` grows linearly
        # from horizon / num_rounds up to the full horizon. Round 1 uses
        # ``starting_episodes`` and no init policy.
        init_horizon = int((rnd / num_rounds) * horizon)
        num_episodes = (
            cfg.training.num_episodes_per_round
            if rnd > 1
            else cfg.training.starting_episodes
        )
        init_policy = mfp if rnd > 1 else None

        # collect samples for all policies
        for i, policy in enumerate(policies):
            previous_buffer_size = len(buffers[i])
            buffers[i], policy_ep_rew, all_ep_rew = rollin_rollout(
                env=env,
                policy=policy,
                buffer=buffers[i],
                horizon=horizon,
                num_episodes=num_episodes,
                init_policy=init_policy,
                init_horizon=init_horizon,
                init_index_type=cfg.training.init_index_type,
                name=str(i),
            )
            logger.info(
                f"Collected {len(buffers[i]) - previous_buffer_size} samples "
                + f"for policy {i}"
            )

            if is_eval_round:
                rec.record_dict(
                    rnd,
                    {
                        f"data_{i}/buffer_size": len(buffers[i]),
                        f"data_{i}/policy_ep_rew_mean": np.mean(policy_ep_rew),
                        f"data_{i}/policy_ep_rew_std": np.std(policy_ep_rew),
                        f"data_{i}/all_ep_rew": np.mean(all_ep_rew),
                        f"data_{i}/all_ep_rew_std": np.std(all_ep_rew),
                    },
                )

        # train all value functions
        for i, (buffer, net, opt) in enumerate(
            zip(buffers, value_functions, optimizers)
        ):
            _, losses, value_preds, value_targets = train_v(
                buffer,
                net,
                opt,
                cfg.training.num_updates,
                name=str(i),
            )

            if is_eval_round:
                rec.record_dict(
                    rnd,
                    {
                        f"vf_{i}/losses": np.mean(losses),
                        f"vf_{i}/value_preds": np.mean(value_preds),
                        f"vf_{i}/value_targets": np.mean(value_targets),
                    },
                )

        # evaluate the max following policy
        if is_eval_round:
            eval_returns, policy_usage, success = evaluate_policy(
                env,
                mfp,
                cfg.eval.num_trajs,
                horizon,
            )
            eval_reward_mean_after, eval_reward_std_after = _record_mfp_eval(
                rec, rnd, eval_returns, policy_usage, success
            )
            logger.info(
                f"Eval reward after training: {eval_reward_mean_after:.4f}"
                + f" with std {eval_reward_std_after:.4f}"
                + f" and policy usage {policy_usage}"
            )

        if rnd % cfg.training.save_interval == 0:
            # Checkpoint the full value-function modules for this round.
            for i, net in enumerate(value_functions):
                torch.save(net, os.path.join(cfg.save_path, f"vf_{i}_{rnd}.pt"))


def evaluate_policy(env, policy, num_eval_trajs, horizon):
    """Run ``num_eval_trajs`` evaluation episodes of ``policy`` in ``env``.

    Args:
        env: Environment with ``reset() -> obs`` and
            ``step(action) -> (obs, reward, done, info)``. This is the
            interface installed by ``modified_reset``/``modified_step``.
        policy: Callable taking a float32 observation and returning either an
            action, or a tuple whose first two items are the action and the
            index of the sub-policy that produced it. Steps with a non-tuple
            action are counted toward the last reported index (initially 0).
            If ``policy`` defines ``__len__``, that gives the number of
            sub-policies; otherwise one policy is assumed. ``eval()`` is
            called before the rollouts and ``train()`` afterwards (also on
            error) when those methods exist.
        num_eval_trajs: Number of episodes to run.
        horizon: Maximum number of steps per episode.

    Returns:
        ``(returns, usage, success_rate)``:

        * ``returns``: list of per-episode summed rewards.
        * ``usage``: array with the fraction of all steps attributed to each
          sub-policy. It is all zeros when no step was taken (previously this
          divided by zero).
        * ``success_rate``: mean over episodes of a 0/1 flag that is 1 when any
          step's reward was exactly 1.
    """
    num_policies = len(policy) if hasattr(policy, "__len__") else 1
    policy_usage = np.zeros(num_policies)

    returns = []
    successes = []
    total_steps = 0
    policy_used = 0

    if hasattr(policy, "eval"):
        policy.eval()
    # try/finally so an exception mid-rollout does not leave the policy stuck
    # in eval mode.
    try:
        for _ in tqdm(range(num_eval_trajs), desc="Evaluating policy", leave=False):
            observation = env.reset()
            episode_reward = 0
            success = 0
            for _ in range(horizon):
                action = policy(observation.astype(np.float32))
                if isinstance(action, tuple):
                    action, policy_used = action[0], action[1]

                policy_usage[policy_used] += 1

                observation, reward, done, _ = env.step(action)
                if reward == 1:
                    success = 1

                total_steps += 1
                episode_reward += reward
                if done:
                    break

            successes.append(success)
            returns.append(episode_reward)
    finally:
        if hasattr(policy, "train"):
            policy.train()

    usage = policy_usage / total_steps if total_steps else policy_usage
    return (
        returns,
        usage,
        np.mean(successes),
    )


def evaluate_policies(policies, env, num_eval_trajs, horizon):
    """Evaluate every policy separately with ``evaluate_policy``.

    The mean/std return, usage and success rate of each policy are logged.

    Returns:
        ``(returns, successes)``: per-policy lists of episode returns, and
        per-policy success rates, in the order of ``policies``.
    """
    all_returns, all_successes = [], []
    for idx, pol in enumerate(policies):
        episode_returns, usage, success_rate = evaluate_policy(
            env, pol, num_eval_trajs, horizon
        )
        mean_ret = np.mean(episode_returns)
        std_ret = np.std(episode_returns)
        logger.info(
            f"Eval reward for policy {idx}: {mean_ret:.4f} +- {std_ret:.4f}"
            + f" with policy usage {usage} and success rate {success_rate}"
        )
        all_returns.append(episode_returns)
        all_successes.append(success_rate)

    return all_returns, all_successes


if __name__ == "__main__":
    # ``main`` is wrapped by ``@hydra.main``; the decorator builds ``cfg`` from
    # configs/main.yaml (plus any CLI overrides), so it is called with no args.
    main()
