from typing import *
import gym.wrappers
import gym
import exp_utils as PQ
from loguru import logger
import torch
import numpy as np

from safe import *
from safe.envs import make_env
from rl_utils.runner import merge_episode_stats

# Module-wide torch device: CUDA when available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def evaluate(step, runner, policy, tag, *, n_eval_samples):
    """Run ``policy`` through ``runner`` and report per-statistic mean/std/count.

    Every merged episode statistic is written to ``PQ.writer`` as
    ``{tag}/{key}/mean``, ``.../std`` and ``.../n``. Statistics are also logged
    at INFO level, except ``episode.unsafe``, which is logged as a WARNING and
    only when at least one episode was unsafe (non-zero sum).

    Args:
        step: global step used for logging and TensorBoard scalars.
        runner: object with ``reset()`` and ``run(policy, n)`` returning episode infos.
        policy: policy handed to ``runner.run``.
        tag: prefix for log lines and scalar names.
        n_eval_samples: sample budget passed to ``runner.run``.
    """
    runner.reset()
    stats = merge_episode_stats(runner.run(policy, n_eval_samples))

    for name, samples in stats.items():
        arr = np.array(samples)
        mu = np.mean(arr)
        sigma = np.std(arr)
        count = len(arr)

        is_unsafe_key = name == 'episode.unsafe'
        # Unsafe counts are only worth shouting about when something was actually unsafe.
        if not is_unsafe_key or arr.sum() > 0:
            emit = logger.warning if is_unsafe_key else logger.info
            emit(f'# {step}, tag = {tag}, {name} = {mu:.6f} ± {sigma:.6f} over {count} episodes.')

        for suffix, scalar in (('mean', mu), ('std', sigma), ('n', count)):
            PQ.writer.add_scalar(f'{tag}/{name}/{suffix}', scalar, global_step=step)


class Debugger:
    """Periodic diagnostics, plotting, video rollouts and checkpointing for a training run.

    Each diagnostic is a ``_do_<name>(t)`` method; :meth:`evaluate` runs the ones
    whose keyword flag is truthy. Results go to ``PQ.writer`` scalars, the
    logger, and files under ``PQ.log_dir``.
    """

    # Both hold whatever ``make_env()`` returns (Monitor wrapping is disabled in
    # ``init_video_maker``). Typed as ``Any`` so the class body does not evaluate
    # ``gym.wrappers.Monitor``, which recent gym releases have removed.
    video_env_det: Any
    video_env_rand: Any
    # Observation lists produced by ``_do_video``; they only exist after it has run.
    det_traj: List
    rand_traj: List

    def __init__(self, env, policy, mean_policy, L, model, runner, horizon, s0, s_opt, s_opt_grad, s_opt_sample, L_opt,
                 fns, buffer_out, crabs, qfn, FLAGS):
        """Store the training components that the ``_do_*`` diagnostics read.

        Argument roles below are taken from how this class uses them:
            env: environment passed to ``plot_real``.
            policy: evaluated under tag 'policy', rendered as the 'rand' video, checkpointed.
            mean_policy: evaluated under tag 'mean_policy', rendered as the 'det' video,
                and rolled out in the model by ``_do_virt_safe``.
            L: barrier network exposing ``parameters()``, ``net`` and ``state_dict()``.
            model: container whose ``models[0]`` is rolled out; checkpointed.
            runner: episode runner with ``reset()`` / ``run(policy, n)``.
            horizon: sample budget for each ``runner.run`` evaluation.
            s0: state fed to ``L.net`` and ``model_rollout``.
            s_opt, s_opt_grad, s_opt_sample: state optimisers with ``evaluate(step=...)``.
            L_opt: optimiser evaluated on ``s_opt.s``.
            fns: dict of callables; 'L', 'hardD' and 'softD' are plotted.
            buffer_out: buffer whose ``index`` is logged as its size.
            crabs: provides ``env_barrier_fn``, ``L`` and ``state_dict()``.
            qfn: Q-function, checkpointed.
            FLAGS: experiment config (``n_eval_iters``, ``env.id``, ``opt_s``, ``ckpt``).
        """
        self.env = env
        self.policy = policy
        self.model = model
        self.mean_policy = mean_policy
        self.runner = runner
        self.horizon = horizon
        self.L = L
        self.s0 = s0
        self.s_opt = s_opt
        self.s_opt_sample = s_opt_sample
        self.s_opt_grad = s_opt_grad
        self.L_opt = L_opt
        self.fns = fns
        self.n_policy_updates = 0
        self.n_barrier_updates = 0
        self.buffer_out = buffer_out
        self.status = {}
        self.crabs = crabs
        self.FLAGS = FLAGS
        self.last_expl = None
        self.qfn = qfn

        self.init_video_maker()

    def init_video_maker(self):
        """Create ``<log_dir>/videos`` and set up the envs used by ``_do_video``.

        Recording through a Monitor wrapper is currently disabled, so both
        attributes reference the *same* plain env returned by ``make_env()``.
        """
        video_path = PQ.log_dir / 'videos'
        # exist_ok: re-creating a Debugger for an existing log dir must not crash.
        video_path.mkdir(exist_ok=True)
        env = make_env()
        self.video_env_det = env
        self.video_env_rand = env

    def update(self, step_type):
        """Count one update of ``step_type`` ('policy' or 'barrier') for ``_do_updates``.

        Raises:
            ValueError: if ``step_type`` is neither 'policy' nor 'barrier'.
        """
        if step_type == 'policy':
            self.n_policy_updates += 1
        elif step_type == 'barrier':
            self.n_barrier_updates += 1
        else:
            # Explicit raise instead of ``assert 0``, which ``python -O`` strips.
            raise ValueError(f"unknown step_type: {step_type!r}")

    def _do_policy(self, t):
        """Evaluate ``self.policy`` with the module-level ``evaluate`` under tag 'policy'."""
        evaluate(t, self.runner, self.policy, 'policy', n_eval_samples=self.horizon)

    def _do_mean_policy(self, t):
        """Evaluate ``self.mean_policy`` with the module-level ``evaluate`` under tag 'mean_policy'."""
        evaluate(t, self.runner, self.mean_policy, 'mean_policy', n_eval_samples=self.horizon)

    @staticmethod
    def _grad_norm(module):
        """Return the L2 norm over every non-None ``.grad`` of ``module``'s parameters (0. if none)."""
        grads = [p.grad for p in module.parameters() if p.grad is not None]
        if not grads:
            return 0.
        return torch.nn.utils.parameters_to_vector(grads).norm().item()

    def _do_L_grad(self, t):
        """Log the gradient norm of ``L`` as 'L/grad_norm' and the value ``L.net(s0)``."""
        # Consider every parameter: checking only the first one reports 0 whenever that
        # parameter lacks a grad, and fails outright on a parameter-less module.
        L_grad_norm = self._grad_norm(self.L)

        PQ.writer.add_scalar('L/grad_norm', L_grad_norm, global_step=t)
        g_s0 = self.L.net(self.s0).item()
        logger.debug(f"g(s0) = {g_s0:.6f}, L grad norm = {L_grad_norm:.6f}, ")

    def _do_updates(self, t):
        """Log the update counters divided by ``FLAGS.n_eval_iters``, then reset both counters.

        This is a per-iteration update rate if it runs once every ``n_eval_iters`` iterations.
        """
        p_policy_update = self.n_policy_updates / self.FLAGS.n_eval_iters
        p_barrier_update = self.n_barrier_updates / self.FLAGS.n_eval_iters
        PQ.writer.add_scalar('policy/p_updates', p_policy_update, global_step=t)
        PQ.writer.add_scalar('barrier/p_updates', p_barrier_update, global_step=t)
        self.n_policy_updates = 0
        self.n_barrier_updates = 0
        logger.debug(f"Pr(policy updates) = {p_policy_update:.3f}, Pr(barrier updates) = {p_barrier_update:.3f}")

    def _do_s_grad(self, t):
        """Re-initialise ``s_opt_grad``, run 10001 steps (evaluating every 1000), log 'grad_opt/optimal'.

        The logged value comes from the last evaluation, taken at i == 10000 just before the final step.
        """
        self.s_opt_grad.reinit()
        for i in range(10001):
            if i % 1000 == 0:
                # i == 0 always evaluates, so grad_opt_info is bound after the loop.
                grad_opt_info = self.s_opt_grad.evaluate(step=t)
            self.s_opt_grad.step()

        PQ.writer.add_scalar('grad_opt/optimal', grad_opt_info['optimal'], global_step=t)

    def _do_s(self, t):
        """Evaluate the state optimisers and ``L_opt``, then log and purge 'opt_progress/<i>' meters."""
        self.s_opt_sample.evaluate(step=t)
        self.s_opt.evaluate(step=t)
        self.L_opt.evaluate(self.s_opt.s, step=t)

        for i in range(self.FLAGS.opt_s.n_steps):
            PQ.writer.add_scalar(f's/step_{i}', PQ.meters[f'opt_progress/{i}'].mean, global_step=t)
        PQ.meters.purge('opt_progress/')

    def _do_plot(self, t):
        """Save ``fig-{t}.png`` for the supported env ids, plus ``Lx-{t}.png`` if ``FLAGS.ckpt.x_vs_L``.

        The pendulum branches read ``self.det_traj``, which only exists after ``_do_video`` has run.
        """
        if self.FLAGS.env.id == 'MyPendulum-v0':
            clouds = {}
            if self.FLAGS.opt_s.method in ['grad', 'MALA', 'CEM', 'metropolis']:
                clouds['s'] = self.s_opt.s.cpu().detach().numpy()
            clouds['traj'] = np.array(self.det_traj)

            plot_fns = {key: self.fns[key] for key in ['L', 'hardD', 'softD']}
            plot_pendulum_set(plot_fns, device, clouds, PQ.log_dir / f'fig-{t}.png', f'# {t}')
        elif self.FLAGS.env.id == 'SafeReal-v1':
            # NOTE(review): nothing in this class sets ``self.U_pi``, so this branch raises
            # AttributeError as written — confirm what ``plot_real`` expects here.
            plot_real(self.env, self.L, self.U_pi, self.s_opt.s, PQ.log_dir / f'fig-{t}.png', device=device)
        elif self.FLAGS.env.id == 'SafeInvertedPendulum-v2':
            Vt = torch.eye(4, device=device, dtype=torch.float32)

            # Embed each 2-D grid point (a, b) as the 4-D state (a, b, 0, 0) before calling L.
            # Assumes plot_pendulum_set evaluates on a 201 x 201 grid — TODO confirm.
            plot_fns = {
                'L': lambda xs: self.L(xs.reshape(-1, 2).mm(Vt[:2]).reshape(201, 201, 4)),
            }
            traj = np.array(self.det_traj)
            clouds = {
                'traj': traj,
            }
            if self.last_expl is not None:
                clouds['expl'] = self.last_expl

            plot_pendulum_set(plot_fns, device, clouds, PQ.log_dir / f'fig-{t}.png', f'# {t}',
                              y_max=0.2, y_min=-0.2, x_max=1.0, x_min=-1.0, xlabel="pos", ylabel="angle",
                              decode=lambda x: (x[:, 0], x[:, 1]),)

        if self.FLAGS.ckpt.x_vs_L:
            plot_x_vs_L(self.FLAGS.ckpt.x_vs_L, self.L, filename=PQ.log_dir / f"Lx-{t}.png", title=f"# {t}")

    def _do_save(self, t):
        """Write ``ckpt-{t}.pt`` under ``PQ.log_dir`` with the state dicts of all trained components."""
        torch.save({
            'L': self.L.state_dict(),
            's': self.s_opt.state_dict(),
            'policy': self.policy.state_dict(),
            'qfn': self.qfn.state_dict(),
            'models': self.model.state_dict(),
            'safe_invariant': self.crabs.state_dict(),
        }, PQ.log_dir / f'ckpt-{t}.pt')

    def _do_expl(self, t):
        """Log the mean of the 'expl/backup' meter (as the backup probability), then reset it."""
        logger.debug(f"[expl]: backup prob = {PQ.meters['expl/backup'].mean:.6f}")
        PQ.meters['expl/backup'].reset()

    def evaluate(self, t, **kwargs):
        """Call ``self._do_<key>(t)`` for every keyword whose value is truthy, in keyword order."""
        keys = [key for key, value in kwargs.items() if value]
        for key in keys:
            getattr(self, f'_do_{key}')(t)

    @torch.no_grad()
    def _do_virt_safe(self, t):
        """Roll ``mean_policy`` out in ``model.models[0]`` for 10000 steps from ``s0`` and check safety.

        Sets ``status['policy_virt_safe']`` to whether ``crabs.L`` at the final state is <= 0.
        Logs at WARNING level when the env barrier reaches 1.0 anywhere on the rollout, else DEBUG.
        """
        states, _, _ = model_rollout(self.model.models[0], self.mean_policy, self.s0, 10000)
        b_safe = self.crabs.env_barrier_fn(states)
        last_L = self.crabs.L(states[-1]).item()
        max_safe = b_safe.max().item()
        self.status['policy_virt_safe'] = last_L <= 0

        if max_safe < 1.0:
            logger.debug(f"[virt safe] max safe = {max_safe:.6f}, last_L = {last_L:.6f}")
        else:
            logger.warning(f"[virt safe] max safe = {max_safe:.6f}, last_L = {last_L:.6f}")

    def _do_video(self, t):
        """Roll out ``mean_policy`` ('det') and ``policy`` ('rand'); keep their observation lists."""
        self.det_traj = render_video(t, self.video_env_det, self.mean_policy, 'det')
        self.rand_traj = render_video(t, self.video_env_rand, self.policy, 'rand')

    def _do_buf_out(self, t):
        """Log ``buffer_out.index`` as the buffer size."""
        logger.debug(f"[buf out] size = {self.buffer_out.index}")


def render_video(t, video_env, policy, tag):
    """Roll ``policy`` out for one episode in ``video_env`` and return the visited observations.

    The episode stops when the env reports ``done`` or when the observation's
    Euclidean norm exceeds 1e5 (a divergence guard). ``video_env.episode_id`` is
    set to ``t`` right after the reset — presumably so a recording wrapper can
    name its output; confirm against the env in use.

    Returns:
        list: the reset observation followed by every post-step observation.
    """
    obs = video_env.reset()
    trajectory = [obs]
    video_env.episode_id = t
    total_reward = 0.

    while True:
        obs, reward, done, info = video_env.step(policy.get_actions(obs))
        total_reward += reward
        trajectory.append(obs)
        if done or np.linalg.norm(obs) > 1e5:
            break

    logger.debug(f'[video {tag}] iter = {t}: return = {total_reward:.6f}, last info = {info}')
    return trajectory
