import torch
import torch.nn as nn
from pytorch3d.ops import knn_points

from models.transformer.pe_transformer import PETransformerLayer
from models.transformer.rpe_transformer import RPETransformerLayer
from models.transformer.vanilla_transformer import TransformerLayer
from models.transformer.positional_encoding import GeometricStructureEmbedding
from models.transformer.linear_transformer import LinearTransformerLayer

from models.cast.spot_attention import Upsampling, Downsampling


class LinearTransformer(nn.Module):
    """Coarse-to-fine ref/src feature interaction without coarse self-attention.

    Each of the ``cfg.blocks`` blocks:
      1. runs ``TransformerLayer`` on the coarse features, with each side as
         input and the other side as memory (cross-attention);
      2. fuses the coarse result into the fine features with ``Upsampling``
         (nearest coarse point per fine point);
      3. fuses fine features back into the coarse level with ``Downsampling``
         (``down_k`` nearest fine points per coarse point) for the next block;
      4. applies ``LinearTransformerLayer`` to the (own, other) fine features.

    Config fields read: ``down_k``, ``blocks``, ``dual_normalization``,
    ``input_dim_c``, ``input_dim_f``, ``hidden_dim``, ``num_heads``,
    ``dropout``, ``activation_fn``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.down_k = cfg.down_k
        self.blocks = cfg.blocks
        self.dual_normalization = cfg.dual_normalization

        # Project coarse (in_proj1) and fine (in_proj2) inputs to hidden_dim.
        self.in_proj1 = nn.Linear(cfg.input_dim_c, cfg.hidden_dim)
        self.in_proj2 = nn.Linear(cfg.input_dim_f, cfg.hidden_dim)

        self.upsampling = nn.ModuleList()
        self.downsampling = nn.ModuleList()
        self.cross_attentions = nn.ModuleList()
        self.linear_attentions = nn.ModuleList()

        # Modules are created block by block (interleaved) on purpose: this
        # fixes the order in which parameters are randomly initialised.
        for _ in range(self.blocks):
            self.upsampling.append(Upsampling(cfg.hidden_dim, cfg.hidden_dim))
            self.downsampling.append(Downsampling(cfg.hidden_dim, cfg.hidden_dim))
            self.cross_attentions.append(TransformerLayer(
                cfg.hidden_dim, cfg.num_heads, cfg.dropout, cfg.activation_fn
            ))
            self.linear_attentions.append(LinearTransformerLayer(
                cfg.hidden_dim, cfg.num_heads, cfg.dropout, cfg.activation_fn
            ))

    def matching_scores(self, input_states: torch.Tensor, memory_states: torch.Tensor):
        """Dot-product similarity between two feature sets.

        Args:
            input_states: (M, C) or (B, M, C).
            memory_states: (N, C) or (B, N, C), same rank as ``input_states``.

        Returns:
            (M, N) or (B, M, N) scores. When ``self.dual_normalization`` is
            set, the raw scores are replaced by the element-wise product of a
            softmax over the last dim and a softmax over the second-to-last dim.
        """
        if input_states.ndim == 2:
            matching_scores = torch.einsum('mc,nc->mn', input_states, memory_states)
        else:
            matching_scores = torch.einsum('bmc,bnc->bmn', input_states, memory_states)
        if self.dual_normalization:
            ref_matching_scores = torch.softmax(matching_scores, dim=-1)
            src_matching_scores = torch.softmax(matching_scores, dim=-2)
            matching_scores = ref_matching_scores * src_matching_scores
        return matching_scores

    def forward(self, ref_points, src_points, ref_feats, src_feats,
                ref_points_c, src_points_c, ref_feats_c, src_feats_c):
        """Jointly refine ref/src fine features using the coarse level.

        Args:
            ref_points (Tensor): (B, N, 3)
            src_points (Tensor): (B, M, 3)
            ref_feats (Tensor): (B, N, C)
            src_feats (Tensor): (B, M, C)
            ref_points_c (Tensor): (B, N', 3)
            src_points_c (Tensor): (B, M', 3)
            ref_feats_c (Tensor): (B, N', C')
            src_feats_c (Tensor): (B, M', C')

        Returns:
            ref_feats: torch.Tensor, presumably (B, N, hidden_dim) — assuming
                LinearTransformerLayer preserves its input width.
            src_feats: torch.Tensor, presumably (B, M, hidden_dim).
            With ``blocks == 0`` these are just the ``in_proj2`` projections.
        """
        with torch.no_grad():
            # Index of the nearest coarse point for every fine point, used by
            # the nearest up-sampling fusion.
            ref_idx_up = knn_points(ref_points, ref_points_c)[1]  # (B, N, 1)
            src_idx_up = knn_points(src_points, src_points_c)[1]  # (B, M, 1)

        # Indices and coordinates of the down_k nearest fine points for every
        # coarse point, used by the knn interpolation in down-sampling fusion.
        _, ref_idx_down, ref_xyz_down = knn_points(ref_points_c, ref_points, K=self.down_k, return_nn=True)
        _, src_idx_down, src_xyz_down = knn_points(src_points_c, src_points, K=self.down_k, return_nn=True)

        ref_feats_c = self.in_proj1(ref_feats_c)
        src_feats_c = self.in_proj1(src_feats_c)
        new_ref_feats = self.in_proj2(ref_feats)
        new_src_feats = self.in_proj2(src_feats)

        for i in range(self.blocks):
            new_ref_feats_c, _ = self.cross_attentions[i](ref_feats_c, src_feats_c)
            new_src_feats_c, _ = self.cross_attentions[i](src_feats_c, ref_feats_c)

            ref_feats = self.upsampling[i](new_ref_feats, new_ref_feats_c, ref_idx_up)
            src_feats = self.upsampling[i](new_src_feats, new_src_feats_c, src_idx_up)

            # Down-sampled coarse features are only consumed by the next
            # block's cross-attention; after the last block they would be
            # discarded, so skip that wasted computation.
            if i + 1 < self.blocks:
                ref_feats_c = self.downsampling[i](new_ref_feats_c, new_ref_feats, ref_points_c, ref_xyz_down, ref_idx_down)
                src_feats_c = self.downsampling[i](new_src_feats_c, new_src_feats, src_points_c, src_xyz_down, src_idx_down)

            new_ref_feats = self.linear_attentions[i](ref_feats, src_feats)
            new_src_feats = self.linear_attentions[i](src_feats, ref_feats)

        return new_ref_feats, new_src_feats



class LinearGeoTransformer(nn.Module):
    """Coarse-to-fine ref/src feature interaction with coarse self-attention.

    Same pipeline as ``LinearTransformer``, but each block first runs a
    self-attention layer on each side's coarse features:
      * ``RPETransformerLayer`` fed with ``GeometricStructureEmbedding`` of the
        coarse points, when both ``cfg.sigma_d`` and ``cfg.sigma_a`` exist;
      * otherwise ``PETransformerLayer`` fed with the raw coarse coordinates.

    Config fields read: ``down_k``, ``blocks``, ``dual_normalization``,
    ``input_dim_c``, ``input_dim_f``, ``hidden_dim``, ``num_heads``,
    ``dropout``, ``activation_fn`` and, when present, ``sigma_d``,
    ``sigma_a``, ``angle_k``, ``reduction_a``. ``cfg`` must support both
    attribute access and ``keys()`` (e.g. an EasyDict-style config — confirm).
    """

    def __init__(self, cfg):
        super().__init__()
        self.down_k = cfg.down_k
        self.blocks = cfg.blocks
        self.dual_normalization = cfg.dual_normalization

        # Project coarse (in_proj1) and fine (in_proj2) inputs to hidden_dim.
        self.in_proj1 = nn.Linear(cfg.input_dim_c, cfg.hidden_dim)
        self.in_proj2 = nn.Linear(cfg.input_dim_f, cfg.hidden_dim)

        # Relative geometric embedding is used only if both sigmas are configured.
        self.geometric_structure_embedding = "sigma_d" in cfg.keys() and "sigma_a" in cfg.keys()
        if self.geometric_structure_embedding:
            self.embed = GeometricStructureEmbedding(
                cfg.hidden_dim, cfg.sigma_d, cfg.sigma_a, cfg.angle_k, cfg.reduction_a
            )

        self.upsampling = nn.ModuleList()
        self.downsampling = nn.ModuleList()
        self.self_attentions = nn.ModuleList()
        self.cross_attentions = nn.ModuleList()
        self.linear_attentions = nn.ModuleList()

        # Modules are created block by block (interleaved) on purpose: this
        # fixes the order in which parameters are randomly initialised.
        for _ in range(self.blocks):
            self.upsampling.append(Upsampling(cfg.hidden_dim, cfg.hidden_dim))
            self.downsampling.append(Downsampling(cfg.hidden_dim, cfg.hidden_dim))
            if self.geometric_structure_embedding:
                self.self_attentions.append(RPETransformerLayer(
                    cfg.hidden_dim, cfg.num_heads, cfg.dropout, cfg.activation_fn
                ))
            else:
                self.self_attentions.append(PETransformerLayer(
                    cfg.hidden_dim, cfg.num_heads, cfg.dropout, cfg.activation_fn
                ))
            self.cross_attentions.append(TransformerLayer(
                cfg.hidden_dim, cfg.num_heads, cfg.dropout, cfg.activation_fn
            ))
            self.linear_attentions.append(LinearTransformerLayer(
                cfg.hidden_dim, cfg.num_heads, cfg.dropout, cfg.activation_fn
            ))

    def matching_scores(self, input_states: torch.Tensor, memory_states: torch.Tensor):
        """Dot-product similarity between two feature sets.

        Args:
            input_states: (M, C) or (B, M, C).
            memory_states: (N, C) or (B, N, C), same rank as ``input_states``.

        Returns:
            (M, N) or (B, M, N) scores. When ``self.dual_normalization`` is
            set, the raw scores are replaced by the element-wise product of a
            softmax over the last dim and a softmax over the second-to-last dim.
        """
        if input_states.ndim == 2:
            matching_scores = torch.einsum('mc,nc->mn', input_states, memory_states)
        else:
            matching_scores = torch.einsum('bmc,bnc->bmn', input_states, memory_states)
        if self.dual_normalization:
            ref_matching_scores = torch.softmax(matching_scores, dim=-1)
            src_matching_scores = torch.softmax(matching_scores, dim=-2)
            matching_scores = ref_matching_scores * src_matching_scores
        return matching_scores

    def forward(self, ref_points, src_points, ref_feats, src_feats,
                ref_points_c, src_points_c, ref_feats_c, src_feats_c):
        """Jointly refine ref/src fine features using the coarse level.

        Args:
            ref_points (Tensor): (B, N, 3)
            src_points (Tensor): (B, M, 3)
            ref_feats (Tensor): (B, N, C)
            src_feats (Tensor): (B, M, C)
            ref_points_c (Tensor): (B, N', 3)
            src_points_c (Tensor): (B, M', 3)
            ref_feats_c (Tensor): (B, N', C')
            src_feats_c (Tensor): (B, M', C')

        Returns:
            ref_feats: torch.Tensor, presumably (B, N, hidden_dim) — assuming
                LinearTransformerLayer preserves its input width.
            src_feats: torch.Tensor, presumably (B, M, hidden_dim).
            With ``blocks == 0`` these are just the ``in_proj2`` projections.
        """
        with torch.no_grad():
            # Index of the nearest coarse point for every fine point, used by
            # the nearest up-sampling fusion.
            ref_idx_up = knn_points(ref_points, ref_points_c)[1]  # (B, N, 1)
            src_idx_up = knn_points(src_points, src_points_c)[1]  # (B, M, 1)

        # Indices and coordinates of the down_k nearest fine points for every
        # coarse point, used by the knn interpolation in down-sampling fusion.
        _, ref_idx_down, ref_xyz_down = knn_points(ref_points_c, ref_points, K=self.down_k, return_nn=True)
        _, src_idx_down, src_xyz_down = knn_points(src_points_c, src_points, K=self.down_k, return_nn=True)

        ref_feats_c = self.in_proj1(ref_feats_c)
        src_feats_c = self.in_proj1(src_feats_c)

        new_ref_feats = self.in_proj2(ref_feats)
        new_src_feats = self.in_proj2(src_feats)

        # The coarse point sets do not change across blocks, so the geometric
        # embeddings are computed once.
        if self.geometric_structure_embedding:
            ref_embeddings = self.embed(ref_points_c)
            src_embeddings = self.embed(src_points_c)

        for i in range(self.blocks):
            if self.geometric_structure_embedding:
                ref_feats_c, _ = self.self_attentions[i](ref_feats_c, ref_feats_c, ref_embeddings)
                src_feats_c, _ = self.self_attentions[i](src_feats_c, src_feats_c, src_embeddings)
            else:
                ref_feats_c, _ = self.self_attentions[i](ref_feats_c, ref_feats_c, ref_points_c, ref_points_c)
                src_feats_c, _ = self.self_attentions[i](src_feats_c, src_feats_c, src_points_c, src_points_c)
            new_ref_feats_c, _ = self.cross_attentions[i](ref_feats_c, src_feats_c)
            new_src_feats_c, _ = self.cross_attentions[i](src_feats_c, ref_feats_c)

            ref_feats = self.upsampling[i](new_ref_feats, new_ref_feats_c, ref_idx_up)
            src_feats = self.upsampling[i](new_src_feats, new_src_feats_c, src_idx_up)

            # Down-sampled coarse features are only consumed by the next
            # block's attention layers; after the last block they would be
            # discarded, so skip that wasted computation.
            if i + 1 < self.blocks:
                ref_feats_c = self.downsampling[i](new_ref_feats_c, new_ref_feats, ref_points_c, ref_xyz_down, ref_idx_down)
                src_feats_c = self.downsampling[i](new_src_feats_c, new_src_feats, src_points_c, src_xyz_down, src_idx_down)

            new_ref_feats = self.linear_attentions[i](ref_feats, src_feats)
            new_src_feats = self.linear_attentions[i](src_feats, ref_feats)

        return new_ref_feats, new_src_feats



class LinearTransformerS2(nn.Module):
    """Single-level stack of linear-attention layers over a ref/src pair.

    Fine features are projected to ``cfg.hidden_dim`` by one shared linear
    layer, then passed through ``cfg.blocks`` ``LinearTransformerLayer``
    modules. Each module is applied to (ref, src) and to (src, ref).

    Config fields read: ``blocks``, ``dual_normalization``, ``input_dim_f``,
    ``hidden_dim``, ``num_heads``, ``dropout``, ``activation_fn``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.blocks = cfg.blocks
        self.dual_normalization = cfg.dual_normalization

        self.in_proj = nn.Linear(cfg.input_dim_f, cfg.hidden_dim)
        self.linear_attentions = nn.ModuleList([
            LinearTransformerLayer(cfg.hidden_dim, cfg.num_heads, cfg.dropout, cfg.activation_fn)
            for _ in range(self.blocks)
        ])

    def matching_scores(self, input_states: torch.Tensor, memory_states: torch.Tensor):
        """Dot-product similarity between two feature sets.

        Args:
            input_states: (M, C) or (B, M, C).
            memory_states: (N, C) or (B, N, C), same rank as ``input_states``.

        Returns:
            (M, N) or (B, M, N) scores. With ``self.dual_normalization`` the
            result is softmax(scores, dim=-1) * softmax(scores, dim=-2).
        """
        equation = 'mc,nc->mn' if input_states.ndim == 2 else 'bmc,bnc->bmn'
        scores = torch.einsum(equation, input_states, memory_states)
        if not self.dual_normalization:
            return scores
        return torch.softmax(scores, dim=-1) * torch.softmax(scores, dim=-2)

    def forward(self, ref_points, src_points, ref_feats, src_feats):
        """Refine ref/src features with the stacked linear-attention layers.

        Args:
            ref_points (Tensor): (B, N, 3) — accepted but not used here.
            src_points (Tensor): (B, M, 3) — accepted but not used here.
            ref_feats (Tensor): (B, N, C)
            src_feats (Tensor): (B, M, C)

        Returns:
            ref_feats: torch.Tensor, presumably (B, N, hidden_dim).
            src_feats: torch.Tensor, presumably (B, M, hidden_dim).
        """
        ref_out, src_out = self.in_proj(ref_feats), self.in_proj(src_feats)
        for block_idx in range(self.blocks):
            layer = self.linear_attentions[block_idx]
            # Both updates read the features from the previous step: the
            # right-hand side is fully evaluated before either name is rebound.
            ref_out, src_out = layer(ref_out, src_out), layer(src_out, ref_out)
        return ref_out, src_out