import pytorch_lightning as pl
import torch
import pickle
from loguru import logger

from safe import *
from safe.envs import make_env
import rl_utils


class DetMLPPolicy(rl_utils.MLP, rl_utils.DetNetPolicy):
    """Deterministic policy: combines the ``rl_utils.MLP`` network with ``rl_utils.DetNetPolicy`` behaviour."""


class MLPQFn(rl_utils.MLP, rl_utils.NetQFn):
    """Q-function: combines the ``rl_utils.MLP`` network with ``rl_utils.NetQFn`` behaviour."""


class TanhGaussianMLPPolicy(rl_utils.policy.TanhGaussianPolicy, rl_utils.MLP, rl_utils.NetPolicy):
    """Stochastic policy: ``TanhGaussianPolicy`` head on top of an ``rl_utils.MLP`` network."""


class SafeCtrl(pl.LightningModule):
    """Lightning module bundling the environment, replay buffers, policy and dynamics ensemble.

    All components are built in ``__init__`` from the global ``FLAGS`` config.
    ``forward`` is currently a no-op placeholder.
    """

    def __init__(self):
        super().__init__()
        self.env = make_env()
        # Initial observation from a fresh reset, stored as a float32 tensor.
        # NOTE(review): assumes env.reset() returns only the observation
        # (pre-0.26 gym API) — confirm against make_env().
        self.s0 = torch.tensor(self.env.reset(), dtype=torch.float32)
        self.dim_state = self.env.observation_space.shape[0]
        self.dim_action = self.env.action_space.shape[0]
        self.normalizer = Normalizer(self.dim_state, clip=1000)

        # Buffer of real-environment transitions: restored from a checkpoint
        # when FLAGS.ckpt.buf is set, otherwise created empty.
        if FLAGS.ckpt.buf:
            # SECURITY: pickle.load can execute arbitrary code; only point
            # FLAGS.ckpt.buf at checkpoints you produced / trust.
            with open(FLAGS.ckpt.buf, 'rb') as f:
                self.buf_real = pickle.load(f)
            logger.warning(f"load model buffer from {FLAGS.ckpt.buf}")
        else:
            self.buf_real = rl_utils.TorchReplayBuffer(self.env, max_buf_size=1_000_000)

        self.buf_dev = rl_utils.TorchReplayBuffer(self.env, max_buf_size=10_000)
        # The last layer has 2 * dim_action outputs — presumably a mean and a
        # (log-)std per action dimension for the tanh-Gaussian head.
        self.policy = TanhGaussianMLPPolicy([self.dim_state, 64, 64, self.dim_action * 2])
        self.mean_policy = rl_utils.policy.MeanPolicy(self.policy)

        # Both factories take the ensemble-member index so they can be called
        # uniformly below (the DomainModel factory previously took no argument
        # and raised TypeError when called with `i`).
        if FLAGS.env.id == 'MySafexp-PointGoal1-v1':
            logger.warning("use DomainModel")

            def make_model(i):
                # The hand-specified domain model does not depend on the member index.
                return DomainModel(self.env, self.env.hazards_pos, self.env.vases_pos, self.env.goal_pos)
        else:
            def make_model(i):
                # Output size is dim_state * 2 — presumably mean and variance of the next state.
                return StableDynamics(
                    TransitionModel(
                        self.dim_state,
                        self.normalizer,  # was the undefined name `normalizer` (NameError)
                        [self.dim_state + self.dim_action, 256, 256, 256, 256, self.dim_state * 2],
                    ),
                    self.dim_state, 0.01, buf=self.buf_real, buf_dev=None, name=f'model_{i}')
        self.ensemble = EnsembleModel([make_model(i) for i in range(FLAGS.model.n_ensemble)])

        # Keep horizon and runners on the instance; they were previously
        # locals discarded at the end of __init__.
        self.horizon = self.env.spec.max_episode_steps
        make_stats = [lambda: ExtractLastInfo('episode.unsafe'), lambda: EpisodeReturn()]
        # NOTE(review): `device` is not defined in this file; presumably it is
        # exported by `from safe import *` — confirm, or switch to self.device.
        self.runners = {
            'explore': RunnerX(make_env, 1, make_stats, device=device),
            'evaluate': RunnerX(make_env, 1, make_stats, device=device),
            'test': RunnerX(make_env, 1, make_stats, device=device),
        }

    def forward(self):
        """Placeholder; no forward computation is implemented yet."""
        pass
