"""
Use SAC to learn an agent that adaptively designs intervention experiments
(gene knockouts / knockdowns) on SERGIO-simulated gene regulatory networks.
"""

import argparse
import os
import pickle as pkl

import numpy as np
import torch
import wandb
from dowel import logger
from garage.experiment import deterministic
from garage.torch import set_gpu_mode
from torch import nn

from src import set_rng_seed, wrap_experiment
from src.algos import SAC
from src.envs import (
    GymEnv,
    NormalizedCausalEnv,
    AdaptiveIntervDesignEnvLikelihoodFree,
    AdaptiveIntervDesignEnvEvalLikelihoodFree,
)
from src.envs.adaptive_design_env import LOWER, TERMINAL, UPPER
from src.experiment import Trainer
from src.models.causal_experiment_model import GRNSergioModel
from src.policies import AdaptiveTransformerTanhGaussianPolicy
from src.q_functions.adaptive_mlp_q_function import AdaptiveMLPQFunctionDoCausal
from src.replay_buffer import PathBuffer
from src.sampler.local_sampler import LocalSampler
from src.sampler.vector_worker import VectorWorker
import avici
from functools import partial


def main(
    n_parallel=1,
    budget=1,
    n_rl_itr=1,
    seed=0,
    eval_save_dir=None,
    log_dir=None,
    snapshot_mode="gap",
    snapshot_gap=500,
    discount=1.0,
    alpha=None,
    d=2,
    log_info=None,
    tau=5e-3,
    pi_lr=3e-4,
    qf_lr=3e-4,
    buffer_capacity=int(1e6),
    ens_size=2,
    M=2,
    G=1,
    minibatch_size=4096,
    data_seed=1,
    num_initial_obs=0,
    batch_size=1,
    use_wandb=False,
    num_attn_layers=2,
    num_attn_layers_q_func=2,
    norm_rewards=False,
    is_single_target=True,
    intervention_type="kout",
    shared_encoder=False,
    graph_degree=1,
    noise_config_type="10x-chromium-mini",
):
    if log_info is None:
        log_info = []

    @wrap_experiment(
        log_dir=log_dir, snapshot_mode=snapshot_mode, snapshot_gap=snapshot_gap
    )
    def sac_source(
        ctxt=None,
        n_parallel=1,
        budget=1,
        n_rl_itr=1,
        eval_save_dir=None,
        seed=0,
        discount=1.0,
        alpha=None,
        d=2,
        tau=5e-3,
        pi_lr=3e-4,
        qf_lr=3e-4,
        buffer_capacity=int(1e6),
        ens_size=2,
        M=2,
        G=1,
        minibatch_size=4096,
        num_initial_obs=0,
        batch_size=1,
        use_wandb=False,
        num_attn_layers=2,
        num_attn_layers_q_func=2,
        norm_rewards=False,
        is_single_target=True,
        intervention_type="do",
        shared_encoder=False,
        graph_degree=1,
        noise_config_type="10x-chromium-mini",
    ):
        trainer = Trainer(snapshot_config=ctxt, wandb=use_wandb)
        replay_buffer = PathBuffer(capacity_in_transitions=buffer_capacity)
        if os.path.exists(os.path.join(log_dir, "params.pkl")):
            sampler = partial(
                LocalSampler,
                max_episode_length=budget,
                worker_class=VectorWorker,
                worker_args={
                    "num_init_obs": num_initial_obs,
                    "batch_size": batch_size,
                },
            )
            trainer.restore(
                from_dir=log_dir,
                from_epoch="last",
                replay_buffer=replay_buffer,
                sampler=sampler,
                alpha=alpha,
            )
            trainer.resume(n_epochs=n_rl_itr, batch_size=n_parallel * budget)
        else:
            if log_info:
                logger.log(str(log_info))
            if torch.cuda.is_available():
                set_gpu_mode(True, 0)
                device = torch.device("cuda")
                # torch.set_default_tensor_type("torch.cuda.FloatTensor")
                logger.log("GPU available")
            else:
                device = torch.device("cpu")
                set_gpu_mode(False)
                logger.log("no GPU detected")
            # if there is a saved agent to load

            logger.log("creating new policy")
            layer_size = 128
            # is_cube = round(minibatch_size ** (1/3)) ** 3 == minibatch_size

            model = GRNSergioModel(
                n_parallel=n_parallel,
                d=d,
                graph_prior="erdos_renyi",
                intervention_type=intervention_type,
                graph_args={"degree": graph_degree},
                noise_config_type=noise_config_type,
            )
            eval_model = GRNSergioModel(
                n_parallel=100,
                d=d,
                graph_prior="erdos_renyi",
                intervention_type=intervention_type,
                graph_args={"degree": graph_degree},
                noise_config_type=noise_config_type,
            )
            eval_ood_models = {}
            for ood_envs in [
                "graph",
                "intervType",
                "d+2",
                "d+5",
                "d+10",
                "noise_config",
            ]:
                if ood_envs == "graph":
                    eval_ood_models[ood_envs] = GRNSergioModel(
                        n_parallel=100,
                        d=d,
                        graph_prior="scale_free",
                        intervention_type=intervention_type,
                        graph_args={"degree": graph_degree},
                    )
                elif ood_envs == "d+2":
                    eval_ood_models[ood_envs] = GRNSergioModel(
                        n_parallel=100,
                        d=d + 2,
                        graph_prior="erdos_renyi",
                        intervention_type=intervention_type,
                        graph_args={"degree": graph_degree},
                    )
                elif ood_envs == "intervType":
                    eval_ood_models[ood_envs] = GRNSergioModel(
                        n_parallel=100,
                        d=d,
                        graph_prior="erdos_renyi",
                        intervention_type=(
                            "kdown" if intervention_type == "kout" else "kout"
                        ),
                        graph_args={"degree": graph_degree},
                        noise_config_type=noise_config_type,
                    )
                elif ood_envs == "d+5":
                    eval_ood_models[ood_envs] = GRNSergioModel(
                        n_parallel=100,
                        d=d + 5,
                        graph_prior="erdos_renyi",
                        intervention_type=intervention_type,
                        graph_args={"degree": graph_degree},
                    )
                elif ood_envs == "d+10":
                    eval_ood_models[ood_envs] = GRNSergioModel(
                        n_parallel=100,
                        d=d + 10,
                        graph_prior="erdos_renyi",
                        intervention_type=intervention_type,
                        graph_args={"degree": graph_degree + 1},
                    )
                elif ood_envs == "noise_config":
                    eval_ood_models[ood_envs] = GRNSergioModel(
                        n_parallel=100,
                        d=d,
                        graph_prior="erdos_renyi",
                        intervention_type=intervention_type,
                        graph_args={"degree": graph_degree},
                        noise_config_type=(
                            "drop-seq"
                            if noise_config_type == "10x-chromium-mini"
                            else "10x-chromium-mini"
                        ),
                    )

            eval_env_type = AdaptiveIntervDesignEnvEvalLikelihoodFree
            env_type = AdaptiveIntervDesignEnvLikelihoodFree
            reward_model = avici.load_pretrained(
                download="neurips-grn", expects_counts=True
            )
            kwargs = {}
            kwargs["zero_bias"] = True
            kwargs["batch_size"] = batch_size
            kwargs["num_initial_obs"] = num_initial_obs
            kwargs["reward_model"] = reward_model

            def make_eval_env(
                eval_model,
                budget,
                save_path=eval_save_dir,
            ):
                env = GymEnv(
                    NormalizedCausalEnv(
                        eval_env_type(
                            eval_model,
                            budget,
                            data_seed=data_seed,
                            **kwargs,
                        ),
                        normalize_obs=True,
                        is_count_data=True,
                    )
                )
                return env

            def make_ood_eval_env(
                eval_models,
                budget,
                save_path=eval_save_dir,
            ):
                envs = {
                    eval_model_keys: GymEnv(
                        NormalizedCausalEnv(
                            eval_env_type(
                                eval_models[eval_model_keys],
                                budget,
                                save_path=None,
                                data_seed=data_seed,
                                **kwargs,
                            ),
                            normalize_obs=True,
                            is_count_data=True,
                        )
                    )
                    for eval_model_keys in eval_models.keys()
                }
                return envs

            def make_env(model, budget):
                env = GymEnv(
                    NormalizedCausalEnv(
                        env_type(
                            model,
                            budget,
                            **kwargs,
                        ),
                        normalize_obs=True,
                        is_count_data=True,
                        normalize_reward=norm_rewards,
                    )
                )
                return env

            def make_policy():
                return AdaptiveTransformerTanhGaussianPolicy(
                    env_spec=env.spec,
                    n_attention_heads=8,
                    n_attention_layers=num_attn_layers,
                    dropout=0.1,
                    widening_factor=4,
                    pooling="max",
                    embedding_dim=32,
                    emitter_sizes=[layer_size, layer_size],
                    emitter_nonlinearity=nn.ReLU,
                    emitter_output_nonlinearity=None,
                    init_std=np.sqrt(1 / 3),
                    min_std=np.exp(-20.0),
                    max_std=np.exp(0.0),
                    batch_size=batch_size,
                    device=device,
                    is_single_target=is_single_target,
                    no_value=True,
                ).to(device)

            def make_q_func():
                return AdaptiveMLPQFunctionDoCausal(
                    env_spec=env.spec,
                    encoding_dim=32,
                    batch_size=batch_size,
                    encoder_widening_factor=2,
                    encoder_dropout=0.0,
                    encoder_n_layers=num_attn_layers_q_func,
                    encoder_num_heads=8,
                    emitter_sizes=[layer_size, layer_size],
                    emitter_nonlinearity=nn.ReLU,
                    emitter_output_nonlinearity=None,
                    is_single_target=is_single_target,
                    no_value=True,
                ).to(device)

            eval_env = make_eval_env(
                eval_model,
                budget,
                save_path=eval_save_dir,
            )
            eval_ood_envs = make_ood_eval_env(
                eval_ood_models,
                budget,
                save_path=eval_save_dir,
            )
            deterministic.set_seed(seed)
            set_rng_seed(seed)
            env = make_env(model, budget)
            policy = make_policy()
            qfs = [make_q_func() for _ in range(ens_size)]
            if shared_encoder:
                for qf in qfs:
                    qf._encoder = policy._encoder
            sampler = LocalSampler(
                agents=policy,
                envs=env,
                max_episode_length=budget,
                worker_class=VectorWorker,
                worker_args={
                    "num_init_obs": num_initial_obs,
                    "batch_size": batch_size,
                },
            )

            sac = SAC(
                env_spec=env.spec,
                policy=policy,
                qfs=qfs,
                replay_buffer=replay_buffer,
                sampler=sampler,
                max_episode_length_eval=budget,
                gradient_steps_per_itr=64,
                min_buffer_size=int(1e4),
                target_update_tau=tau,
                policy_lr=pi_lr,
                qf_lr=qf_lr,
                discount=discount,
                discount_delta=0.0,
                fixed_alpha=alpha,
                buffer_batch_size=minibatch_size,
                reward_scale=1.0,
                M=M,
                G=G,
                ent_anneal_rate=1 / 1.4e4,
                eval_env=eval_env,
                eval_ood_envs=eval_ood_envs,
                device=device,
                save_dir=eval_save_dir,
            )

        sac.to()
        trainer.setup(algo=sac, env=env)
        trainer.train(n_epochs=n_rl_itr, batch_size=n_parallel * budget)

    sac_source(
        n_parallel=n_parallel,
        budget=budget,
        n_rl_itr=n_rl_itr,
        seed=seed,
        eval_save_dir=eval_save_dir,
        discount=discount,
        alpha=alpha,
        d=d,
        tau=tau,
        pi_lr=pi_lr,
        qf_lr=qf_lr,
        buffer_capacity=buffer_capacity,
        ens_size=ens_size,
        M=M,
        G=G,
        minibatch_size=minibatch_size,
        num_initial_obs=num_initial_obs,
        batch_size=batch_size,
        use_wandb=use_wandb,
        num_attn_layers=num_attn_layers,
        num_attn_layers_q_func=num_attn_layers_q_func,
        norm_rewards=norm_rewards,
        is_single_target=is_single_target,
        intervention_type=intervention_type,
        shared_encoder=shared_encoder,
        graph_degree=graph_degree,
        noise_config_type=noise_config_type,
    )


if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn")
    parser = argparse.ArgumentParser()
    parser.add_argument("--n-parallel", default="10000", type=int)
    parser.add_argument("--budget", default="10", type=int)
    parser.add_argument("--n-rl-itr", default="10000", type=int)
    parser.add_argument(
        "--log-dir",
        default="sergio_sac/",
        type=str,
    )
    parser.add_argument("--snapshot-mode", default="gap_overwrite", type=str)
    parser.add_argument("--snapshot-gap", default=100, type=int)
    parser.add_argument("--discount", default="1", type=float)
    parser.add_argument("--alpha", default="-1", type=float)
    parser.add_argument("--num-initial-obs", default="50", type=int)
    parser.add_argument("--d", default="5", type=int)
    parser.add_argument("--tau", default="5e-3", type=float)
    parser.add_argument("--pi-lr", default="3e-4", type=float)
    parser.add_argument("--qf-lr", default="3e-4", type=float)
    parser.add_argument("--buffer-capacity", default="1e6", type=float)
    parser.add_argument("--ens-size", default="2", type=int)
    parser.add_argument("--M", default="2", type=int)
    parser.add_argument("--G", default="1", type=int)
    parser.add_argument("--minibatch-size", default="1024", type=int)
    parser.add_argument("--batch-size", default="1", type=int)
    parser.add_argument("--data-seed", default=1, type=int)
    parser.add_argument("--seed", default=1, type=int)
    parser.add_argument("--wandb", action="store_true")
    parser.add_argument("--anneal-lr", default=0, type=int)
    parser.add_argument("--num-attn-layers", default=1, type=int)
    parser.add_argument("--num-attn-layers-q-func", default=1, type=int)
    parser.add_argument("--norm-rewards", default=0, type=int)
    parser.add_argument("--is-single-target", default=0, type=int)
    parser.add_argument("--intervention-type", default="kout", type=str.lower)
    parser.add_argument("--shared-encoder", action="store_true")
    parser.add_argument("--graph-degree", default=1.0, type=float)
    parser.add_argument("--noise-config-type", default="10x-chromium-mini", type=str)

    args = parser.parse_args()
    # The 0/1 integer flags become booleans (any non-zero value is True).
    args.norm_rewards = bool(args.norm_rewards)
    args.is_single_target = bool(args.is_single_target)
    # NOTE(review): anneal_lr is only used in the run-directory name below;
    # it is not forwarded to main().
    args.anneal_lr = bool(args.anneal_lr)
    if args.wandb:
        wandb.init(
            project="caasl",
            entity="",
            config=args,
            dir="wandb/",
            group=f"sergio-sem-sac/d={args.d}",
            job_type="train",
        )
    # NOTE(review): seed, budget, n_parallel and noise_config_type are not part
    # of the run-directory name. Runs that differ only in those share a
    # directory, and main() resumes from any params.pkl it finds there.
    args.log_dir = os.path.join(
        args.log_dir,
        f"best_d={args.d}_nio={args.num_initial_obs}"
        f"_pi_lr={args.pi_lr}_qf_lr={args.qf_lr}_tau={args.tau}"
        f"_bc={args.buffer_capacity}_ens_size={args.ens_size}"
        f"_mb_size={args.minibatch_size}_ist={args.is_single_target}"
        f"_it={args.intervention_type}_deg={args.graph_degree}"
        f"_disc={args.discount}_alpha={args.alpha}_G={args.G}"
        f"_ann_lr={args.anneal_lr}",
    )
    # A negative --alpha means "no fixed alpha".
    alpha = args.alpha if args.alpha >= 0 else None
    buff_cap = int(args.buffer_capacity)
    log_info = f"input params: {vars(args)}"
    main(
        n_parallel=args.n_parallel,
        budget=args.budget,
        n_rl_itr=args.n_rl_itr,
        log_dir=args.log_dir,
        eval_save_dir=None,
        snapshot_mode=args.snapshot_mode,
        snapshot_gap=args.snapshot_gap,
        discount=args.discount,
        alpha=alpha,
        d=args.d,
        log_info=log_info,
        tau=args.tau,
        pi_lr=args.pi_lr,
        qf_lr=args.qf_lr,
        buffer_capacity=buff_cap,
        ens_size=args.ens_size,
        M=args.M,
        G=args.G,
        minibatch_size=args.minibatch_size,
        data_seed=args.data_seed,
        num_initial_obs=args.num_initial_obs,
        seed=args.seed,
        batch_size=args.batch_size,
        use_wandb=args.wandb,
        num_attn_layers=args.num_attn_layers,
        num_attn_layers_q_func=args.num_attn_layers_q_func,
        norm_rewards=args.norm_rewards,
        is_single_target=args.is_single_target,
        intervention_type=args.intervention_type,
        shared_encoder=args.shared_encoder,
        graph_degree=args.graph_degree,
        noise_config_type=args.noise_config_type,
    )
