import torch
import numpy as np
import tinycudann as tcnn
from lidarnerf.activation import trunc_exp
from .renderer import NeRFRenderer
from lidarnerf.dataset import nus_dataset_barf
import pickle
import sys
import os
from lidarnerf.planes import Planes4D



class NeRFNetwork(NeRFRenderer):
    """Hash-grid NeRF with camera/LiDAR colour heads and per-frame pose refinement.

    Geometry is encoded with a tiny-cuda-nn multiresolution hash grid whose
    levels are blended in coarse-to-fine during training (see ``density``).
    The grid output is decoded by ``sigma_net`` into one density value plus
    ``geo_feat_dim`` geometry features. Two heads consume those features:

    * ``color_net``: spherical-harmonics direction encoding, producing
      ``out_color_dim`` channels.
    * ``lidar_color_net``: frequency direction encoding, producing
      ``out_lidar_color_dim`` channels. NOTE(review): these are presumably
      LiDAR intensity / ray-drop; confirm against the renderer.

    Each of the ``opt.dataloader_size`` frames also owns a learnable se(3)
    correction (3 rotation + 3 translation components, see ``get_pose``).
    The correction is applied on top of an optional fixed random
    perturbation of the input poses.
    """

    # Number of hash-grid levels. ``density`` anneals the levels in one by
    # one, so the encoder config and the schedule must agree on this value.
    N_HASH_LEVELS = 20

    def __init__(
        self,
        opt=None,
        device=torch.device("cuda:1"),
        encoding="HashGrid",
        desired_resolution=2048,
        log2_hashmap_size=19,
        encoding_dir="SphericalHarmonics",
        n_features_per_level=2,
        num_layers=2,
        hidden_dim=64,
        geo_feat_dim=15,
        num_layers_color=3,
        hidden_dim_color=64,
        out_color_dim=3,
        out_lidar_color_dim=2,
        bound=1,
        size=36,
        min_resolution=32,
        n_features_per_level_plane=0,
        n_levels_plane=0,
        **kwargs,
    ):
        """Build the encoders, MLP heads and pose-refinement state.

        Args:
            opt: Options namespace. It is required despite the ``None``
                default. Read here: ``rot``, ``trans``, ``noise_rot``,
                ``noise_trans``, ``dataloader_size`` and ``scale``. Other
                methods also read ``no_gt_pose``, ``device`` and
                ``dataloader``.
            device: Device for the pose embeddings and pose-noise tensors.
            desired_resolution, log2_hashmap_size, n_features_per_level:
                Hash-grid settings.
            num_layers, hidden_dim, geo_feat_dim: Density-MLP settings.
            num_layers_color, hidden_dim_color: Settings shared by both
                colour heads.
            out_color_dim, out_lidar_color_dim: Output widths of the camera
                and LiDAR heads.
            bound: Scene half-extent, forwarded to ``NeRFRenderer``.
            encoding, encoding_dir, size, min_resolution,
            n_features_per_level_plane, n_levels_plane: Accepted for
                interface compatibility but unused here. The last three are
                only referenced by the commented-out ``Planes4D`` encoder.
            **kwargs: Forwarded to ``NeRFRenderer``.

        Raises:
            ValueError: If ``opt`` is not given.
        """
        super().__init__(bound, **kwargs)
        if opt is None:
            raise ValueError("NeRFNetwork requires an `opt` options object")

        self.opt = opt
        self.rot = opt.rot
        self.trans = opt.trans
        self.noise_rot = opt.noise_rot
        self.noise_trans = opt.noise_trans
        self.device = device
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.geo_feat_dim = geo_feat_dim
        self.desired_resolution = desired_resolution
        self.log2_hashmap_size = log2_hashmap_size
        self.out_color_dim = out_color_dim
        self.out_lidar_color_dim = out_lidar_color_dim
        self.n_features_per_level = n_features_per_level

        n_poses = self.opt.dataloader_size
        # Per-frame bookkeeping lists. Nothing in this class writes to them;
        # presumably the trainer fills them (rotation/translation error logs).
        self.rre = [[] for _ in range(n_poses)]
        self.rte = [[] for _ in range(n_poses)]
        self.rre_when_train_pose = [[] for _ in range(n_poses)]
        self.rte_when_train_pose = [[] for _ in range(n_poses)]
        self.rre_when_graph_optim = [[] for _ in range(n_poses)]
        self.rte_when_graph_optim = [[] for _ in range(n_poses)]
        # Drives the coarse-to-fine schedule in density(). It is a frozen
        # Parameter so that it is registered on the module and saved with it.
        self.progress = torch.nn.Parameter(torch.tensor(0.0), requires_grad=False)
        self.loss_record = [[] for _ in range(n_poses)]

        # Learnable per-frame se(3) corrections, initialised to zero
        # (an identity correction).
        self.se3_refine_trans = torch.nn.Embedding(n_poses, 3).to(self.device)
        torch.nn.init.zeros_(self.se3_refine_trans.weight)
        self.se3_refine_rot = torch.nn.Embedding(n_poses, 3).to(self.device)
        torch.nn.init.zeros_(self.se3_refine_rot.weight)

        # Fixed synthetic pose noise. `randn(...) * 0` is kept instead of
        # `torch.zeros` so that the global RNG stream, and hence any seeded
        # noise drawn below, stays identical to earlier runs.
        se3_noise_rot = torch.randn(n_poses, 3, device=self.device) * 0
        se3_noise_trans = torch.randn(n_poses, 3, device=self.device) * 0
        if self.noise_rot:
            se3_noise_rot = torch.randn(n_poses, 3, device=self.device) * 0.15
        if self.noise_trans:
            se3_noise_trans = (
                torch.randn(n_poses, 3, device=self.device) * 2 * self.opt.scale
            )
        # Layout [rot | trans], matching the refinement table in get_pose().
        se3_noise = torch.cat([se3_noise_rot, se3_noise_trans], dim=-1)
        self.pose_noise = [self.lie.se3_to_SE3(s, self.device) for s in se3_noise]

        # NOTE(review): this formula spreads 16 levels from resolution 16 up
        # to desired_resolution * bound. The encoder below, however, uses
        # N_HASH_LEVELS levels starting at base_resolution=20, so its finest
        # level ends well above desired_resolution. Kept as-is to preserve
        # trained behaviour; confirm this is intended.
        per_level_scale = np.exp2(
            np.log2(self.desired_resolution * bound / 16) / (16 - 1)
        )
        # if needed
        # self.planes_encoder = Planes4D(
        #     grid_dimensions=2,
        #     input_dim=3,
        #     output_dim=n_features_per_level_plane,
        #     resolution=[min_resolution] * 3,
        #     multiscale_res=[2**(n) for n in range(n_levels_plane)],
        # )

        self.encoder = tcnn.Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "HashGrid",
                "n_levels": self.N_HASH_LEVELS,
                "n_features_per_level": self.n_features_per_level,
                "log2_hashmap_size": self.log2_hashmap_size,
                "base_resolution": 20,
                "per_level_scale": per_level_scale,
            },
        )

        # Outputs 1 raw density value followed by geo_feat_dim features.
        self.sigma_net = tcnn.Network(
            n_input_dims=self.encoder.n_output_dims,
            n_output_dims=1 + self.geo_feat_dim,
            network_config={
                "otype": "FullyFusedMLP",
                "activation": "ReLU",
                "output_activation": "None",
                "n_neurons": hidden_dim,
                "n_hidden_layers": num_layers - 1,
            },
        )

        self.num_layers_color = num_layers_color
        self.hidden_dim_color = hidden_dim_color

        self.encoder_dir = tcnn.Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "SphericalHarmonics",
                "degree": 4,
            },
        )

        # tcnn's Frequency encoding is configured with "n_frequencies". The
        # previous "degree" key was ignored, so tcnn's default of 12 applied;
        # this spells that value out explicitly.
        self.encoder_lidar_dir = tcnn.Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "Frequency",
                "n_frequencies": 12,
            },
        )

        self.in_dim_color = self.encoder_dir.n_output_dims + self.geo_feat_dim

        self.color_net = tcnn.Network(
            n_input_dims=self.in_dim_color,
            n_output_dims=self.out_color_dim,
            network_config={
                "otype": "FullyFusedMLP",
                "activation": "ReLU",
                "output_activation": "None",
                "n_neurons": hidden_dim_color,
                "n_hidden_layers": num_layers_color - 1,
            },
        )

        self.in_dim_lidar_color = (
            self.encoder_lidar_dir.n_output_dims + self.geo_feat_dim
        )
        self.lidar_color_net = tcnn.Network(
            n_input_dims=self.in_dim_lidar_color,
            n_output_dims=self.out_lidar_color_dim,
            network_config={
                "otype": "FullyFusedMLP",
                "activation": "ReLU",
                "output_activation": "None",
                "n_neurons": hidden_dim_color,
                "n_hidden_layers": num_layers_color - 1,
            },
        )

    def save_pose(self, idx, pose):
        """Placeholder hook; it stores nothing and always returns 0."""
        return 0

    def get_pose(self, idx, pose):
        """Return the pose of frame ``idx`` with noise and refinement applied.

        The input ``pose`` is first right-multiplied by the frame's fixed
        noise transform. If rotation and/or translation refinement is enabled,
        the learned se(3) correction is composed with it via
        ``self.lie.compose_pair(pose_refine, pose)``; the disabled part of the
        correction is zero. When both parts are enabled and
        ``opt.no_gt_pose`` is set, the input pose is replaced by identity
        before composing.

        Side effect: the full per-frame correction table, laid out
        [rot | trans], is cached in ``self.se3_refine``.

        Args:
            idx: Frame index into the refinement and noise tables.
            pose: Input pose. NOTE(review): presumably a (..., 4, 4) tensor;
                confirm against callers.

        Returns:
            The noisy pose when no refinement is enabled, otherwise the
            refined pose.
        """
        n_poses = self.opt.dataloader_size
        rot = (
            self.se3_refine_rot.weight
            if self.rot
            else torch.zeros([n_poses, 3], device=self.device)
        )
        trans = (
            self.se3_refine_trans.weight
            if self.trans
            else torch.zeros([n_poses, 3], device=self.device)
        )
        self.se3_refine = torch.cat([rot, trans], dim=-1)

        pose = pose @ self.pose_noise[idx]
        if not (self.rot or self.trans):
            # Nothing is being optimised: return the noisy input pose.
            return pose

        if self.rot and self.trans and self.opt.no_gt_pose:
            # Without ground-truth poses, the input pose is discarded and
            # replaced by identity.
            pose = torch.eye(4).unsqueeze(0).to(self.opt.device)

        pose_refine = self.lie.se3_to_SE3(self.se3_refine[idx], self.device)
        return self.lie.compose_pair(pose_refine, pose)

    def forward(self, x, d):
        """Unused; rendering goes through ``density`` and ``color``."""
        pass

    def load_data(self):
        """Attach the nuScenes BARF dataset class (not an instance) as ``self.dataset``."""
        self.dataset = nus_dataset_barf.BaseDataset

    def density(self, x):
        """Query density and geometry features at world-space points.

        Hash-grid levels are blended in coarse-to-fine (BARF-style). The
        schedule position ``alpha`` grows linearly with ``self.progress``
        between ``start`` and ``end``. Level ``k`` is weighted by
        ``(1 - cos(pi * clamp(alpha - k, 0, 1))) / 2``, so each level fades in
        smoothly once the previous one is fully on.

        Args:
            x: Points in [-bound, bound]. NOTE(review): presumably shape
                (N, 3); confirm.

        Returns:
            A dict with ``"sigma"`` (``trunc_exp`` of the first MLP output)
            and ``"geo_feat"`` (the remaining outputs).

        Raises:
            ValueError: If any point lies outside [-bound, bound].
        """
        start, end = 0.0, 0.8
        n_levels = self.N_HASH_LEVELS
        alpha = ((self.progress.data - start) / (end - start) * n_levels).clamp_(
            min=0, max=n_levels
        )
        k = torch.arange(n_levels, dtype=torch.float32, device=self.device)
        weight = (1 - (alpha - k).clamp_(min=0, max=1).mul_(np.pi).cos_()) / 2
        # Each level contributes n_features_per_level consecutive output
        # features, so repeat its weight that many times.
        weight = weight.repeat_interleave(self.n_features_per_level)

        x = (x + self.bound) / (2 * self.bound)  # to [0, 1]
        # The hash grid expects [0, 1] inputs. Fail loudly: the previous
        # sys.exit() ended the process silently with status 0.
        # (torch.any + bool() forces a device sync on every call.)
        if torch.any(x > 1) or torch.any(x < 0):
            raise ValueError(
                "NeRFNetwork.density: input points fall outside [-bound, bound]"
            )
        x = self.encoder(x)
        x = weight * x

        h = self.sigma_net(x)
        sigma = trunc_exp(h[..., 0])
        geo_feat = h[..., 1:]
        return {
            "sigma": sigma,
            "geo_feat": geo_feat,
        }

    def color(self, x, d, cal_lidar_color=False, mask=None, geo_feat=None, **kwargs):
        """Predict colour (camera) or LiDAR channels from directions and geometry.

        Args:
            x: Sample positions. Only their dtype and device are used here.
            d: View directions in [-1, 1] (remapped to [0, 1] before encoding).
            cal_lidar_color: If True, use the frequency encoder and
                ``lidar_color_net``; otherwise use the SH encoder and
                ``color_net``.
            mask: Optional boolean mask over the first dimension. Only masked
                rows are evaluated; the other rows are left at zero.
            geo_feat: Geometry features from ``density``.
            **kwargs: Ignored.

        Returns:
            Sigmoid-activated outputs with ``out_lidar_color_dim`` or
            ``out_color_dim`` channels, depending on ``cal_lidar_color``.
        """
        # Output width depends on which head runs; the masked buffer must
        # match it, or the `rgbs[mask] = h` assignment fails.
        out_dim = self.out_lidar_color_dim if cal_lidar_color else self.out_color_dim

        if mask is not None:
            rgbs = torch.zeros(
                mask.shape[0], out_dim, dtype=x.dtype, device=x.device
            )
            # in case of empty mask
            if not mask.any():
                return rgbs
            d = d[mask]
            geo_feat = geo_feat[mask]

        # Both tcnn direction encoders are fed inputs remapped to [0, 1].
        d = (d + 1) / 2
        if cal_lidar_color:
            h = torch.cat([self.encoder_lidar_dir(d), geo_feat], dim=-1)
            h = self.lidar_color_net(h)
        else:
            h = torch.cat([self.encoder_dir(d), geo_feat], dim=-1)
            h = self.color_net(h)

        # sigmoid activation for rgb
        h = torch.sigmoid(h)

        if mask is None:
            return h
        rgbs[mask] = h.to(rgbs.dtype)  # fp16 --> fp32
        return rgbs

    def _pose_lr_multiplier(self):
        """Return the extra LR factor shared by both pose-refinement param groups."""
        if self.opt.no_gt_pose:
            return 3
        if self.opt.dataloader == "kitti360":
            return 0.1
        return 1

    def get_params_pose_trans(self, lr):
        """Return the optimizer param group for translation refinement (0.1 * lr, scaled)."""
        return [
            {
                "params": self.se3_refine_trans.parameters(),
                "lr": 0.1 * lr * self._pose_lr_multiplier(),
            }
        ]

    def get_params_pose_rot(self, lr):
        """Return the optimizer param group for rotation refinement (0.5 * lr, scaled)."""
        return [
            {
                "params": self.se3_refine_rot.parameters(),
                "lr": 0.5 * lr * self._pose_lr_multiplier(),
            }
        ]

    def get_params(self, lr):
        """Return optimizer param groups for all network (non-pose) modules at ``lr``.

        The background modules are added when ``self.bg_radius > 0``.
        NOTE(review): ``bg_radius``, ``encoder_bg`` and ``bg_net`` are not
        defined in this class; presumably they come from ``NeRFRenderer``.
        """
        params = [
            {"params": self.encoder.parameters(), "lr": lr},
            {"params": self.sigma_net.parameters(), "lr": lr},
            {"params": self.encoder_dir.parameters(), "lr": lr},
            {"params": self.encoder_lidar_dir.parameters(), "lr": lr},
            {"params": self.color_net.parameters(), "lr": lr},
            {"params": self.lidar_color_net.parameters(), "lr": lr},
        ]
        if self.bg_radius > 0:
            params.append({"params": self.encoder_bg.parameters(), "lr": lr})
            params.append({"params": self.bg_net.parameters(), "lr": lr})

        return params
