"""
Copyright 2022 Sahand Rezaei-Shoshtari. All Rights Reserved.

Implementation of DDPG
https://arxiv.org/abs/1509.02971

Code is based on:
https://github.com/sfujim/TD3/blob/master/OurDDPG.py
"""

import hydra
import copy
import numpy as np
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.core import DeterministicActor, DDPGCritic
import utils.utils as utils
import torchvision.models as models

torch.set_printoptions(threshold=100000)  # Only summarize printed tensors with more than 100000 elements


class Encoder(nn.Module):
    """Joint state-action encoder with a SimSiam-style prediction head.

    ``encoder`` maps the concatenated ``[state, action]`` vector to a vector
    of the same width. ``predictor`` is a second MLP with the same input and
    output width that is only used by :meth:`decode`.
    """

    def __init__(self, state_dim, action_dim, hidden_dim):
        """Build the encoder and predictor MLPs.

        Args:
            state_dim: width of the state vector.
            action_dim: width of the action vector.
            hidden_dim: width of the single hidden layer of both MLPs.
        """
        super().__init__()
        joint_dim = state_dim + action_dim

        self.encoder = nn.Sequential(
            nn.Linear(joint_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, joint_dim),
        )
        # Runs before the predictor is created, so only ``encoder`` goes
        # through utils.weight_init; the predictor keeps its default init.
        self.apply(utils.weight_init)

        self.predictor = nn.Sequential(
            nn.Linear(joint_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),  # hidden layer
            nn.Linear(hidden_dim, joint_dim),  # output layer
        )

    def _embed(self, state, action):
        """Concatenate state and action along dim 1 and run the encoder MLP."""
        return self.encoder(torch.cat([state, action], 1))

    def forward(self, state, action):
        """Encode a (state, action) batch and split the result back in two.

        Args:
            state: 2-D tensor, batch along dim 0.
            action: 2-D tensor with the same batch size.

        Returns:
            ``(abstract_state, abstract_action)``: the first ``state.size(1)``
            columns of the encoder output and the remaining columns.
        """
        latent = self._embed(state, action)
        state_width = state.size(1)
        return latent[:, :state_width], latent[:, state_width:]

    def decode(self, state, action, state_1, action_1):
        """Produce SimSiam predictions and targets for two (state, action) views.

        Args:
            state, action: first view.
            state_1, action_1: second view.

        Returns:
            ``(p1, p2, z1, z2)`` where ``z*`` are the encoder outputs of each
            view, detached from the graph (stop-gradient targets), and ``p*``
            are the predictor outputs computed from the non-detached ``z*``.
            See Sec. 3 of https://arxiv.org/abs/2011.10566 for the notation.
        """
        z1 = self._embed(state, action)  # NxC
        z2 = self._embed(state_1, action_1)  # NxC
        p1 = self.predictor(z1)  # NxC
        p2 = self.predictor(z2)  # NxC
        return p1, p2, z1.detach(), z2.detach()


class DDPGAgent:
    def __init__(self, obs_shape, action_shape, device, lr, feature_dim,
                 hidden_dim, linear_approx, critic_target_tau, num_expl_steps,
                 update_every_steps, stddev_schedule, stddev_clip,
                 clipped_noise, aug_ratio):
        """Create the networks, their target copies and the optimizers.

        Args:
            obs_shape: observation size, forwarded unchanged to the actor,
                critic and encoder constructors.
            action_shape: sequence whose first entry is the action dimension.
            device: torch device the networks are moved to.
            lr: Adam learning rate for the actor and the critic.
            feature_dim: accepted but not used by this constructor.
            hidden_dim: hidden width of actor, critic and encoder.
            linear_approx: forwarded to the actor and critic constructors.
            critic_target_tau: rate used for the soft target updates.
            num_expl_steps: steps during which ``act`` samples uniform actions.
            update_every_steps: period (in steps) of actor/target updates.
            stddev_schedule: schedule spec evaluated by ``utils.schedule``.
            stddev_clip: clamp bound for the target-action noise.
            clipped_noise: whether target actions receive clipped noise.
            aug_ratio: selects which part of a batch is rotated in
                ``get_aug_data``.
        """
        action_dim = action_shape[0]

        # plain configuration
        self.device = device
        self.lr = lr
        self.hidden_dim = hidden_dim
        self.action_dim = action_dim
        self.aug_ratio = aug_ratio
        self.critic_target_tau = critic_target_tau
        self.update_every_steps = update_every_steps
        self.num_expl_steps = num_expl_steps
        self.stddev_schedule = stddev_schedule
        self.stddev_clip = stddev_clip
        self.clipped_noise = clipped_noise
        self.k = 0

        # online networks, each followed by its frozen-at-creation target copy
        self.actor = DeterministicActor(
            obs_shape, action_dim, hidden_dim, linear_approx
        ).to(device)
        self.actor_target = copy.deepcopy(self.actor)

        self.critic = DDPGCritic(
            obs_shape, action_dim, hidden_dim, linear_approx
        ).to(device)
        self.critic_target = copy.deepcopy(self.critic)

        self.encoder = Encoder(obs_shape, action_dim, hidden_dim).to(device)

        # optimizers; the encoder uses its own fixed learning rate
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr)
        self.encoder_optimizer = torch.optim.Adam(self.encoder.parameters(), lr=1e-2)

        # everything starts in training mode, targets included
        self.train()
        for target in (self.actor_target, self.critic_target):
            target.train()

    def train(self, training=True):
        self.training = training
        self.actor.train(training)
        self.critic.train(training)
        self.encoder.train(training)

    def act(self, obs, step, eval_mode):
        obs = torch.as_tensor(obs, device=self.device)
        stddev = utils.schedule(self.stddev_schedule, step)
        action = self.actor(obs.float().unsqueeze(0))
        if eval_mode:
            action = action.cpu().numpy()[0]
        else:
            action = action.cpu().numpy()[0] + np.random.normal(0, stddev, size=self.action_dim)
            if step < self.num_expl_steps:
                action = np.random.uniform(-1.0, 1.0, size=self.action_dim)
        return action.astype(np.float32)

    def observe(self, obs, action):
        obs = torch.as_tensor(obs, device=self.device).float().unsqueeze(0)
        action = torch.as_tensor(action, device=self.device).float().unsqueeze(0)

        q = self.critic(obs, action)

        return {
            'state': obs.cpu().numpy()[0],
            'value': q.cpu().numpy()[0]
        }

    def rotation_matirx(self, alphas):
        c, s = np.cos(alphas), np.sin(alphas)
        bz = len(c)
        R = torch.zeros((bz, 3, 3))
        R[:, 0, 0] = c
        R[:, 0, 1] = -s
        R[:, 1, 0] = s
        R[:, 1, 1] = c
        R[:, 2, 2] = 1

        return R.cuda()

    def roto_translate_vector_3d(self, vector_3d, R, trans):
        # R = self.rotation_matirx(alphas)
        transformed_vector_3d = torch.bmm(R, vector_3d)
        transformed_vector_3d = torch.permute(transformed_vector_3d, (0, 2, 1))
        # return torch.t(R@(torch.t(vector_3d))) + trans
        return transformed_vector_3d + trans

    def roto_translate(self, obs, action, next_obs, alphas):

        bz = len(obs)
        R = self.rotation_matirx(alphas)

        touch = obs[:, -2:]
        obs = obs[:, :-2]

        obs = obs.view(bz, -1, 3)
        obs = torch.permute(obs, (0, 2, 1))
        rotated_obs = self.roto_translate_vector_3d(obs, R, 0)
        rotated_obs = rotated_obs.reshape(bz, -1)
        rotated_obs = torch.cat((rotated_obs, touch), dim=1)

        next_touch = next_obs[:, -2:]
        next_obs = next_obs[:, :-2]

        next_obs = next_obs.view(bz, -1, 3)
        next_obs = torch.permute(next_obs, (0, 2, 1))
        rotated_next_obs = self.roto_translate_vector_3d(next_obs, R, 0)
        rotated_next_obs = rotated_next_obs.reshape(bz, -1)
        rotated_next_obs = torch.cat((rotated_next_obs, next_touch), dim=1)

        rotated_action = torch.clone(action)

        return rotated_obs, rotated_action, rotated_next_obs

    def get_aug_data(self, obs, action, next_obs):
        bz = len(obs)
        alphas = torch.rand(bz) * 360

        mask = torch.zeros(bz)
        if self.aug_ratio == 1:
            mask[:] = 1
        elif self.aug_ratio == 2:  # 50
            mask[:bz // 2] = 1
        elif self.aug_ratio == 3:  # 75
            mask[bz // 4:] = 1
        elif self.aug_ratio == 4:  # 25
            mask[:bz // 4] = 1
        alphas *= mask

        alphas = torch.deg2rad(alphas)

        rotated_obs, rotated_action, rotated_next_obs = self.roto_translate(obs, action, next_obs, alphas)

        return rotated_obs, rotated_action, rotated_next_obs

    def update_critic(self, obs, action, reward, discount, next_obs, step):
        metrics = dict()
        obs, action = self.encoder(obs, action)

        with torch.no_grad():
            if self.clipped_noise:
                # Select action according to policy and add clipped noise
                stddev = utils.schedule(self.stddev_schedule, step)
                noise = (torch.randn_like(action) * stddev).clamp(-self.stddev_clip, self.stddev_clip)

                next_action = (self.actor_target(next_obs) + noise).clamp(-1.0, 1.0)
            else:
                next_action = self.actor_target(next_obs)

            next_obs, next_action = self.encoder(next_obs, next_action)
            # Compute the target Q value
            target_Q = self.critic_target(next_obs, next_action)
            target_Q = reward + discount * target_Q

        # Get current Q estimates
        current_Q = self.critic(obs, action)

        # Compute critic loss
        critic_loss = F.mse_loss(current_Q, target_Q)

        metrics['critic_target_q'] = target_Q.mean().item()
        metrics['critic_q'] = current_Q.mean().item()
        metrics['critic_loss'] = critic_loss.item()


        self.encoder_optimizer.zero_grad(set_to_none=True)
        self.critic_optimizer.zero_grad(set_to_none=True)
        critic_loss.backward()
        self.critic_optimizer.step()
        self.encoder_optimizer.step()


        return metrics

    def update_actor(self, eq_state, step):
        metrics = dict()

        # Compute actor loss
        #        with torch.no_grad():
        #            eq_state = self.encoder(eq_state)
        eq_state, new_action = self.encoder(eq_state, self.actor(eq_state))
        actor_loss = -self.critic(eq_state, new_action).mean()

        # Optimize the actor
        self.encoder_optimizer.zero_grad(set_to_none=True)
        self.actor_optimizer.zero_grad(set_to_none=True)
        actor_loss.backward()
        self.actor_optimizer.step()
        self.encoder_optimizer.step()

        metrics['actor_loss'] = actor_loss.item()

        return metrics

    def update_sim(self, obs, action, ori_obs, ori_action):
        metrics = dict()

        criterion = nn.CosineSimilarity(dim=1)
        # compute output and loss
        p1, p2, z1, z2 = self.encoder.decode(obs, action, ori_obs, ori_action)
        SimSiam_loss = -(criterion(p1, z2).mean() + criterion(p2, z1).mean()) * 0.5

        # compute gradient and do SGD step
        self.encoder_optimizer.zero_grad(set_to_none=True)
        SimSiam_loss.backward()
        self.encoder_optimizer.step()

        metrics['sim_loss'] = SimSiam_loss.item()
        return metrics

    def update(self, replay_iter, step):
        """Run one full training iteration on the next replay batch.

        The batch's ``eq_state`` / ``next_eq_state`` entries are augmented by
        ``get_aug_data``. The similarity loss and the critic are then updated
        on every call, while the actor and both target networks are updated
        only when ``step`` is a multiple of ``self.update_every_steps``.

        Args:
            replay_iter: iterator yielding replay batches.
            step: current training step.

        Returns:
            dict of logged metrics collected from the individual updates.
        """
        fields = utils.to_torch(next(replay_iter), self.device)
        obs, action, reward, discount, next_obs, _, eq_state, next_eq_state = fields

        eq_state = eq_state.float()
        next_eq_state = next_eq_state.float()
        augmented = self.get_aug_data(eq_state, action, next_eq_state)
        aug_eq_state, aug_action, aug_next_eq_state = (t.float() for t in augmented)

        metrics = {'batch_reward': reward.mean().item()}

        # auxiliary similarity loss between the augmented and original batch
        metrics.update(self.update_sim(aug_eq_state, aug_action, eq_state, action))

        # update critic
        metrics.update(self.update_critic(
            aug_eq_state, aug_action, reward, discount, aug_next_eq_state, step))

        # update actor (delayed) and the target networks
        if step % self.update_every_steps == 0:
            metrics.update(self.update_actor(aug_eq_state.detach(), step))
            tau = self.critic_target_tau
            utils.soft_update_params(self.critic, self.critic_target, tau)
            utils.soft_update_params(self.actor, self.actor_target, tau)

        return metrics

    def save(self, model_dir, step):
        model_save_dir = Path(f'{model_dir}/step_{str(step).zfill(8)}')
        model_save_dir.mkdir(exist_ok=True, parents=True)

        torch.save(self.actor.state_dict(), f'{model_save_dir}/actor.pt')
        torch.save(self.critic.state_dict(), f'{model_save_dir}/critic.pt')

    def load(self, model_dir, step):
        print(f"Loading the model from {model_dir}, step: {step}")
        model_load_dir = Path(f'{model_dir}/step_{str(step).zfill(8)}')

        self.actor.load_state_dict(
            torch.load(f'{model_load_dir}/actor.pt', map_location=self.device)
        )
        self.critic.load_state_dict(
            torch.load(f'{model_load_dir}/critic.pt', map_location=self.device)
        )
