import torch
import torch.nn as nn
from timm.models.layers import DropPath, trunc_normal_
import numpy as np
import os
import torch.nn.functional as F
from .build import MODELS
from utils.logger import *
from pytorch3d.ops import sample_farthest_points, knn_points

from extensions.chamfer_dist import ChamferDistanceL1, ChamferDistanceL2, ChamferFunction

def index_points(points, idx):
    """
    Gather points by per-batch indices.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] (extra trailing index dims are kept)
    Return:
        new_points: indexed points data, [B, S, C]
    """
    batch = points.shape[0]
    trailing_dims = len(idx.shape) - 1

    # One batch id per entry of idx: start from shape [B, 1, ..., 1] and tile
    # it along every non-batch axis of idx.
    batch_ids = torch.arange(batch, dtype=torch.long).to(points.device)
    batch_ids = batch_ids.view([idx.shape[0]] + [1] * trailing_dims)
    batch_ids = batch_ids.repeat([1] + list(idx.shape[1:]))

    return points[batch_ids, idx, :]

class Encoder(nn.Module):   ## Embedding module
    """Embed every point group into a single `encoder_channel`-wide vector.

    Two stacks of 1x1 convolutions are applied per point; a max over the
    points of a group is taken after each stack.
    """

    def __init__(self, encoder_channel):
        super().__init__()
        self.encoder_channel = encoder_channel

        # Per-point lift: 3 -> 128 -> 256 channels.
        point_stage = [
            nn.Conv1d(3, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 256, 1),
        ]
        self.first_conv = nn.Sequential(*point_stage)

        # Operates on [pooled ; per-point] features (256 + 256 = 512 channels).
        fused_stage = [
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, self.encoder_channel, 1),
        ]
        self.second_conv = nn.Sequential(*fused_stage)

    def forward(self, point_groups):
        '''
            point_groups : B G N 3
            -----------------
            feature_global : B G C   (C = encoder_channel)
        '''
        batch, groups, pts, _ = point_groups.shape

        # Fold the group axis into the batch axis so Conv1d sees (BG, 3, N).
        flat_groups = point_groups.reshape(batch * groups, pts, 3)
        per_point = self.first_conv(flat_groups.transpose(2, 1))        # BG 256 N

        # Max over the points of a group, broadcast back to every point.
        pooled = per_point.max(dim=2, keepdim=True)[0]                  # BG 256 1
        fused = torch.cat([pooled.expand(-1, -1, pts), per_point], dim=1)  # BG 512 N

        fused = self.second_conv(fused)                                 # BG C N
        group_feature = fused.max(dim=2, keepdim=False)[0]              # BG C
        return group_feature.reshape(batch, groups, self.encoder_channel)


class Group(nn.Module):  # FPS + KNN
    """Split a point cloud into `num_group` patches of `group_size` points.

    Patch centers come from `sample_farthest_points`; each patch holds the
    `group_size` nearest neighbours of its center as returned by `knn_points`.
    """

    def __init__(self, num_group, group_size):
        super().__init__()
        self.num_group = num_group
        self.group_size = group_size

    def forward(self, xyz):
        '''
            input: B N 3
            ---------------------------
            output: B G M 3   (neighbourhoods, expressed relative to their center)
            center : B G 3
        '''
        batch_size, num_points, _ = xyz.shape

        center, _ = sample_farthest_points(xyz, K=self.num_group)  # B G 3
        _, idx, _ = knn_points(center, xyz, K=self.group_size, return_nn=False)  # B G M

        assert idx.size(1) == self.num_group
        assert idx.size(2) == self.group_size

        # Turn per-sample indices into indices of the flattened (B*N, 3) cloud.
        batch_offset = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
        flat_idx = (idx + batch_offset).view(-1)

        flat_xyz = xyz.view(batch_size * num_points, -1)
        gathered = flat_xyz[flat_idx, :]
        neighborhood = gathered.view(batch_size, self.num_group, self.group_size, 3).contiguous()

        # normalize: make every patch relative to its own center
        return neighborhood - center.unsqueeze(2), center

class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Linear.

    The same dropout module is applied after the activation and after the
    second linear layer. `hidden_features` and `out_features` fall back to
    `in_features` when left unset (or falsy).
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Attention(nn.Module):
    """Multi-head self-attention over a (B, N, C) token sequence."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads
        # An explicit qk_scale overrides the default 1/sqrt(head_dim) scaling.
        self.scale = qk_scale or per_head ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads

        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        packed = packed.permute(2, 0, 3, 1, 4)
        # Index explicitly rather than tuple-unpacking the tensor (TorchScript friendly).
        q, k, v = packed[0], packed[1], packed[2]

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Merge the heads back: (B, heads, N, head_dim) -> (B, N, C)
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    Both residual branches share one drop-path module, which is `DropPath`
    when `drop_path > 0` and an identity otherwise.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)

        # Stochastic depth is only instantiated when it would actually drop something.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

    def forward(self, x):
        attended = self.attn(self.norm1(x))
        x = x + self.drop_path(attended)
        transformed = self.mlp(self.norm2(x))
        return x + self.drop_path(transformed)


class TransformerEncoder(nn.Module):
    """Stack of `depth` transformer blocks; `pos` is re-added before every block."""

    def __init__(self, embed_dim=768, depth=4, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):
        super().__init__()

        def rate_for(layer):
            # A list supplies one drop-path rate per layer; a scalar is shared by all.
            if isinstance(drop_path_rate, list):
                return drop_path_rate[layer]
            return drop_path_rate

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=rate_for(layer),
            )
            for layer in range(depth)
        ])

    def forward(self, x, pos):
        for layer in self.blocks:
            x = layer(x + pos)
        return x


class TransformerDecoder(nn.Module):
    """Stack of transformer blocks followed by a norm layer and an identity head.

    Linear weights are Xavier-uniform initialised with zero bias; LayerNorm is
    reset to weight 1 and bias 0.
    """

    def __init__(self, embed_dim=384, depth=4, num_heads=6, mlp_ratio=4., qkv_bias=False, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm):
        super().__init__()

        def rate_for(layer):
            # A list supplies one drop-path rate per layer; a scalar is shared by all.
            if isinstance(drop_path_rate, list):
                return drop_path_rate[layer]
            return drop_path_rate

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=rate_for(layer),
            )
            for layer in range(depth)
        ])
        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser passed to `nn.Module.apply`."""
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, pos):
        for layer in self.blocks:
            x = layer(x + pos)

        # Every token is normalised and returned; the head is an identity.
        return self.head(self.norm(x))

class MaskTransformer(nn.Module):
    def __init__(self, config, **kwargs):
        super().__init__()
        self.config = config
        # define the transformer argparse
        self.mask_ratio = config.transformer_config.mask_ratio 
        self.trans_dim = config.transformer_config.trans_dim
        self.depth = config.transformer_config.depth 
        self.drop_path_rate = config.transformer_config.drop_path_rate
        self.num_heads = config.transformer_config.num_heads 
        print_log(f'[args] {config.transformer_config}', logger = 'Transformer')
        # embedding
        self.encoder_dims =  config.transformer_config.encoder_dims
        self.encoder = Encoder(encoder_channel = self.encoder_dims)

        self.mask_type = config.transformer_config.mask_type

        dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)]
        self.blocks = TransformerEncoder(
            embed_dim = self.trans_dim,
            depth = self.depth,
            drop_path_rate = dpr,
            num_heads = self.num_heads,
        )

        self.norm = nn.LayerNorm(self.trans_dim)
        self.apply(self._init_weights)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
        trunc_normal_(self.mask_token, std=.02)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv1d):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def _mask_center_rand(self, center, noaug = False):
        '''
            center : B G 3
            --------------
            mask : B G (bool)
        '''
        B, G, _ = center.shape
        # skip the mask
        if noaug or self.mask_ratio == 0:
            return torch.zeros(center.shape[:2]).bool()

        self.num_mask = int(self.mask_ratio * G)

        overall_mask = np.zeros([B, G])
        for i in range(B):
            mask = np.hstack([
                np.zeros(G-self.num_mask),
                np.ones(self.num_mask),
            ])
            np.random.shuffle(mask)
            overall_mask[i, :] = mask
        overall_mask = torch.from_numpy(overall_mask).to(torch.bool)

        return overall_mask.to(center.device) # B G

    def forward(self, neighborhood, center, pos, shifted_feature=None, get_encoder=False):
        # generate mask
        if self.training:
            bool_masked_pos = self._mask_center_rand(center)  # B 4G
        else:
            vis = torch.zeros((center.shape[0], (center.shape[1] // 4) * 3))
            mask = torch.ones((center.shape[0], center.shape[1] // 4))
            bool_masked_pos = torch.cat([vis, mask], dim=1).to(torch.bool)  # B 4G

        group_input_tokens = self.encoder(neighborhood)  #  B 4G C
        if get_encoder:
            return group_input_tokens

        batch_size, seq_len, C = group_input_tokens.size()
        feature_len = int(seq_len/4)
        
        # test-time feature shifting
        if shifted_feature is not None:
            group_feature_front = group_input_tokens[:, : feature_len, :]
            group_feature_back = group_input_tokens[:, 2*feature_len: , :]
            group_input_tokens = torch.cat([group_feature_front, shifted_feature, group_feature_back], dim=1)

        mask_token = self.mask_token.expand(batch_size, seq_len, -1)  # B 4G C
        m = bool_masked_pos.unsqueeze(-1).type_as(mask_token).reshape(batch_size, seq_len, 1)  # B 4G 1
        x = group_input_tokens * (1 - m) + mask_token * m  # B 4G C

        pos = pos.to(x.device)

        # transformer
        x = self.blocks(x, pos)
        x = self.norm(x)

        return x, bool_masked_pos, group_input_tokens


@MODELS.register_module()
class PCoTTA(nn.Module):
    """Masked auto-encoder over four concatenated point clouds.

    `forward` patches (pc1, target1, pc2, target2), concatenates their tokens
    along the sequence axis, runs `MaskTransformer` and `TransformerDecoder`,
    and reconstructs the masked patches with a 1x1 convolution head.
    """

    def __init__(self, config):
        super().__init__()
        print_log(f'[PCoTTA] ', logger ='PCoTTA')
        self.config = config
        self.trans_dim = config.transformer_config.trans_dim
        self.MAE_encoder = MaskTransformer(config)
        self.group_size = config.group_size
        self.num_group = config.num_group
        self.drop_path_rate = config.transformer_config.drop_path_rate

        self.decoder_depth = config.transformer_config.decoder_depth
        self.decoder_num_heads = config.transformer_config.decoder_num_heads
        # Drop-path rate grows linearly from 0 to drop_path_rate across decoder layers.
        dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.decoder_depth)]
        self.MAE_decoder = TransformerDecoder(
            embed_dim=self.trans_dim,
            depth=self.decoder_depth,
            drop_path_rate=dpr,
            num_heads=self.decoder_num_heads,
        )

        print_log(f'[PCoTTA] divide point cloud into G{self.num_group} x S{self.group_size} points ...', logger ='PCoTTA')
        self.group_divider = Group(num_group = self.num_group, group_size = self.group_size)

        # prediction head: one token -> group_size xyz triples
        self.increase_dim = nn.Sequential(
            nn.Conv1d(self.trans_dim, 3*self.group_size, 1)
        )
        # Fixed (non-learned) encoding for the 4 * num_group token sequence.
        # Kept as a plain attribute, so it is moved to the right device on use.
        self.pos_sincos = self.get_positional_encoding(4 * self.num_group, self.trans_dim)

        self.loss = config.loss
        # loss
        self.build_loss_func(self.loss)

    def cd_loss(self, xyz1, xyz2):
        """Per-sample Chamfer term: mean(dist1) + mean(dist2) over the last axis."""
        dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
        dist1 = dist1.cuda()
        dist2 = dist2.cuda()
        return torch.mean(dist1, dim=-1) + torch.mean(dist2, dim=-1)

    def build_loss_func(self, loss_type):
        """Set `self.loss_func` from `loss_type` ('cdl1' or 'cdl2').

        Raises:
            NotImplementedError: for any other loss type.
        """
        if loss_type == "cdl1":
            self.loss_func = ChamferDistanceL1().cuda()
        elif loss_type =='cdl2':
            self.loss_func = ChamferDistanceL2().cuda()
        else:
            raise NotImplementedError(f'Unsupported loss type: {loss_type!r}')

    def combine_(self, pc1, pc2, pc3, pc4):
        """Concatenate four tensors along dim 1 (the group/token axis)."""
        return torch.cat([pc1, pc2, pc3, pc4], dim=1)

    def get_positional_encoding(self, max_seq_len, embed_dim):
        """Build a (max_seq_len, embed_dim) sin/cos positional table.

        Entry [pos, i] is sin(pos / 10000 ** (2 * i / embed_dim)) for even i
        and cos(...) of the same argument for odd i. The whole table is
        computed with tensor ops instead of one tensor per element.
        """
        # Divisors are evaluated as Python floats first and only then cast to
        # the default tensor dtype, as the element-wise version did.
        divisors = torch.tensor([10000 ** (2 * i / embed_dim) for i in range(embed_dim)])
        positions = torch.arange(max_seq_len, dtype=divisors.dtype).unsqueeze(1)
        # `python_number / tensor` is evaluated by torch as reciprocal * number;
        # use the same form so the values match the element-wise version.
        angles = positions * divisors.reciprocal().unsqueeze(0)

        positional_encoding = torch.zeros(max_seq_len, embed_dim)
        positional_encoding[:, 0::2] = torch.sin(angles[:, 0::2])
        positional_encoding[:, 1::2] = torch.cos(angles[:, 1::2])
        return positional_encoding

    def get_patch(self, pc1, pc2, target1, target2):
        """Patch both inputs and their targets.

        Centers are picked on pc1/pc2 by farthest point sampling. Each target
        reuses the sampled indices of its input, so target centers are the
        target points at the same positions in the cloud. Neighbourhoods are
        the `group_size` nearest points and are not re-centered here.

        Returns:
            (pc1_center, pc1_neighborhood, pc2_center, pc2_neighborhood,
             target1_center, target1_neighborhood, target2_center, target2_neighborhood)
        """
        pc1_center, pc1_center_idx = sample_farthest_points(pc1, K=self.num_group)
        pc2_center, pc2_center_idx = sample_farthest_points(pc2, K=self.num_group)

        _, pc1_neighborhood_idx, _ = knn_points(pc1_center, pc1, K=self.group_size, return_nn=False)
        _, pc2_neighborhood_idx, _ = knn_points(pc2_center, pc2, K=self.group_size, return_nn=False)
        pc1_neighborhood = index_points(pc1, pc1_neighborhood_idx)
        pc2_neighborhood = index_points(pc2, pc2_neighborhood_idx)

        target1_center = index_points(target1, pc1_center_idx)
        target2_center = index_points(target2, pc2_center_idx)
        _, target1_neighborhood_idx, _ = knn_points(target1_center, target1, K=self.group_size, return_nn=False)
        _, target2_neighborhood_idx, _ = knn_points(target2_center, target2, K=self.group_size, return_nn=False)
        target1_neighborhood = index_points(target1, target1_neighborhood_idx)
        target2_neighborhood = index_points(target2, target2_neighborhood_idx)

        return pc1_center, pc1_neighborhood, pc2_center, pc2_neighborhood, target1_center, target1_neighborhood, target2_center, target2_neighborhood

    def forward(self, pc1, pc2, target1, target2, shifted_feature=None, get_encoder=False, **kwargs):
        """Reconstruct the masked patches of the concatenated sequence.

        Args:
            pc1, pc2: input point clouds (B, N, 3).
            target1, target2: their targets, indexed with the centers of pc1/pc2.
            shifted_feature: optional tokens forwarded to the encoder
                (test-time feature shifting).
            get_encoder: when True, return the raw encoder tokens only.

        Returns:
            The encoder tokens if `get_encoder`, otherwise
            `(latent_feature, rebuild_points, loss, loss_pc)`.
        """
        (pc1_center, pc1_neighborhood, pc2_center, pc2_neighborhood,
         target1_center, target1_neighborhood,
         target2_center, target2_neighborhood) = self.get_patch(pc1, pc2, target1, target2)

        # Sequence layout: [pc1 | target1 | pc2 | target2], num_group tokens each.
        center = self.combine_(pc1_center, target1_center, pc2_center, target2_center)
        neighborhood = self.combine_(pc1_neighborhood, target1_neighborhood, pc2_neighborhood, target2_neighborhood)

        if get_encoder:
            return self.MAE_encoder(neighborhood, center, self.pos_sincos, shifted_feature, get_encoder)

        x, mask, latent_feature = self.MAE_encoder(neighborhood, center, self.pos_sincos, shifted_feature, get_encoder)

        B, _, C = x.shape

        pos = self.pos_sincos.to(x.device)
        x = self.MAE_decoder(x, pos)

        # Keep only the decoded tokens at masked positions.
        x_rec = x[mask].reshape(B, -1, C)
        B, M, C = x_rec.shape

        rebuild_points = self.increase_dim(x_rec.transpose(1, 2)).transpose(1, 2).reshape(B * M, -1, 3)  # B*M group_size 3
        gt_points = neighborhood[mask].reshape(B * M, -1, 3)

        # Patch-level loss over all B*M masked patches.
        loss = self.loss_func(rebuild_points, gt_points)

        gt_points = gt_points.reshape(B, M * self.group_size, 3)  # B M*groupsize 3
        rebuild_points = rebuild_points.reshape(B, M * self.group_size, 3)  # B M*groupsize 3

        # Per-sample loss over the merged masked region.
        loss_pc = self.cd_loss(rebuild_points, gt_points)

        return latent_feature, rebuild_points, loss, loss_pc


@MODELS.register_module()
class GSSM(nn.Module):
    """Shift input features towards mixtures of source and learnable prototypes.

    Source prototypes are loaded from `experiments/test/feature_dict.pth` and
    arranged as (source_num, task_num, G, C). The learnable prototypes start
    from the mean over source domains, repeated `target_num` times.
    """

    # Task name -> index on the task axis of the prototype banks.
    # Unrecognised names fall back to index 0.
    _TASK_IDS = {'reconstruction': 0, 'denoising': 1, 'registration': 2}

    def __init__(self, config):
        super().__init__()
        print_log(f'[Gaussian Splatted Shfting Module] ', logger ='GSSM')
        self.config = config
        self.task_num = 3
        self.source_num = 2
        self.target_num = 2
        self.feature_dim = 384
        self.prototype_path = os.path.join('experiments/test', 'feature_dict.pth')
        self.sources_prototypes_all = self.get_sources_prototypes(self.prototype_path)
        self.sources_prototypes_mean = self.sources_prototypes_all.mean(dim=0, keepdim=True)
        self.learnable_prototypes_all = nn.Parameter(self.sources_prototypes_mean.detach().clone().repeat(self.target_num, 1, 1, 1), requires_grad=True)

        # Scores every (prototype, group) pair with one channel and normalises
        # the scores across the prototype axis (dim=-2 of a B,1,S*T,G map).
        self.gaussian_atten = nn.Sequential(nn.LeakyReLU(negative_slope=0.2),
                                    nn.Conv2d(self.feature_dim, 1, kernel_size=1, bias=False),
                                    nn.Softmax(dim=-2)) # domains level

        self.loss = nn.CrossEntropyLoss()

    def get_sources_prototypes(self, path):
        """Load the stored prototypes and reshape them to (source_num, task_num, G, C).

        NOTE(review): torch.load unpickles the file; only point this at a
        trusted checkpoint.
        """
        sources_prototypes_dict = torch.load(path)
        sources_prototypes = list(sources_prototypes_dict['prototype'])
        sources_prototypes = torch.stack(sources_prototypes)  # [2*3 -> domain*task, patch num, C]
        _, G, C = sources_prototypes.size()
        return sources_prototypes.view(self.source_num, self.task_num, G, C)

    def feature_normailize(self, x):
        """Min-max normalise `x` along its last axis.

        NOTE(review): a row whose values are all equal divides 0 by 0 and
        yields NaN.
        """
        min_vals, _ = x.min(dim=-1, keepdim=True)
        max_vals, _ = x.max(dim=-1, keepdim=True)
        return (x - min_vals) / (max_vals - min_vals)

    def get_similarity(self, feature_1, feautre_2, normalize=True):
        """Similarity between `feature_1` (B, G, C) and N prototypes (B, N, G, C).

        For every prototype the per-group dot product (cosine similarity when
        `normalize`) is reduced to `mean + max` over the groups.

        Returns:
            Tensor of shape (B, N).
        """
        B, G, C = feature_1.size()
        N = feautre_2.size(1)

        feature_1 = feature_1.unsqueeze(1)            # B, 1, G, C
        feature_2 = feautre_2.reshape(B, N, G, C)

        if normalize:
            feature_1 = F.normalize(feature_1, dim=-1)
            feature_2 = F.normalize(feature_2, dim=-1)

        # Only matching groups are compared, so a row-wise dot product replaces
        # the diagonal of the full GxG similarity matrix.
        simi = (feature_1 * feature_2).sum(dim=-1)    # B, N, G

        return simi.mean(dim=-1) + simi.max(dim=-1)[0]

    def gaussian_kernel_2d(self, x, y, mu_x=0.0, mu_y=0.0, sigma_x=1.0, sigma_y=1.0):
        """Axis-aligned 2-D Gaussian density evaluated at (x, y)."""
        exponent = -0.5 * ((x - mu_x)**2 / sigma_x**2 + (y - mu_y)**2 / sigma_y**2)
        return torch.exp(exponent) / (2 * torch.pi * sigma_x * sigma_y)

    def gaussian_splatted_weight(self, x, sources_prototypes, learnable_prototypes):
        """Gaussian weight of every (source, learnable) prototype pair.

        The mean and standard deviation of each similarity are taken over the
        batch axis, so a batch of one sample has an undefined (NaN) deviation.

        Returns:
            gaussian: (B, S*T) weights.
            simi_S: (B, S, 1) similarities to the source prototypes.
            simi_T: (B, 1, T) similarities to the learnable prototypes.
        """
        B = x.shape[0]

        simi_S = self.get_similarity(x, sources_prototypes)[:, :, None]
        simi_T = self.get_similarity(x, learnable_prototypes)[:, None, :]

        mean_S = simi_S.mean(dim=0)
        std_S = simi_S.std(dim=0)
        mean_T = simi_T.mean(dim=0)
        # sigma of the target axis must be a standard deviation, like std_S
        # (previously computed with .mean by mistake).
        std_T = simi_T.std(dim=0)

        gaussian = self.gaussian_kernel_2d(simi_S, simi_T, mean_S, mean_T, std_S, std_T).view(B, -1)  # B, S*T

        return gaussian, simi_S, simi_T

    def get_mixup_weight(self, simi_S, simi_T, alpha=1.0):
        """Share of the source prototype in the mix: alpha * S / (S + T)."""
        return (simi_S / (simi_S + simi_T)) * alpha

    def prototypes_mixup(self, sources_prototypes, learnable_prototypes, simi_S, simi_T):
        """Blend every source prototype with every learnable prototype.

        Returns:
            Tensor of shape (B, S*T, G, C).
        """
        mixup_weight = self.get_mixup_weight(simi_S, simi_T)
        mixup_weight = mixup_weight[:, :, :, None, None]  # B, S, T, 1, 1
        sources_prototypes = sources_prototypes[:, :, None]      # B, S, 1, G, C
        learnable_prototypes = learnable_prototypes[:, None, :]  # B, 1, T, G, C
        mixed_prototypes = mixup_weight * sources_prototypes + (1-mixup_weight) * learnable_prototypes
        B, S, T, G, C = mixed_prototypes.size()
        return mixed_prototypes.view(B, -1, G, C)

    def repulsion_loss(self, x, learnable_prototypes, simi, t=0.07):
        """Contrastive loss over the learnable prototypes.

        For each sample the most similar learnable prototype is the positive
        and all others are negatives; logits are per-group dot products divided
        by the temperature `t`.

        Args:
            x: features (B, G, C).
            learnable_prototypes: (B, N, G, C).
            simi: similarities to the learnable prototypes, B*N values.
            t: softmax temperature.
        """
        B, G, C = x.shape
        N = learnable_prototypes.size(1)
        # Explicit reshape keeps the batch axis even when B == 1.
        simi = simi.reshape(B, -1)

        # nearest learnable prototypes
        nearest_idx = simi.max(dim=-1)[1]
        device = nearest_idx.device
        batch_idx = torch.arange(B, device=device)
        # positive pair
        pos_prototypes = learnable_prototypes[batch_idx, nearest_idx, :].view(B, 1, G, C)
        # negative pairs: everything except the nearest prototype
        mask = torch.ones(B, N, dtype=torch.bool, device=device)
        mask[batch_idx, nearest_idx] = False
        remaining_idx = torch.nonzero(mask, as_tuple=False)
        neg_prototypes = learnable_prototypes[remaining_idx[:, 0], remaining_idx[:, 1], :].view(B, N-1, G, C)

        x = x.reshape(B*G, 1, C)
        pos_prototypes = pos_prototypes.permute(0, 2, 3, 1).reshape(B*G, C, 1)
        neg_prototypes = neg_prototypes.permute(0, 2, 3, 1).reshape(B*G, C, N-1)

        # get logits
        l_pos = torch.matmul(x, pos_prototypes).squeeze(1)
        l_neg = torch.matmul(x, neg_prototypes).squeeze(1)

        logits = torch.cat([l_pos, l_neg], dim=1)
        # The positive sits in column 0; labels follow the logits' device.
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
        # apply temperature
        logits = logits / t
        return self.loss(logits, labels)

    def forward(self, x, task, **kwargs):
        """Shift `x` towards the mixed prototypes of each sample's task.

        Args:
            x: features (B, G, C).
            task: per-sample task names; 'reconstruction', 'denoising' and
                'registration' are recognised, anything else maps to index 0.

        Returns:
            (shifted x of shape (B, G, C), repulsion loss, simi_S).
        """
        batch = x.size(0)
        task_id = torch.tensor(
            [self._TASK_IDS.get(task[i], 0) for i in range(batch)],
            dtype=torch.long, device=x.device)

        # The source bank is a plain tensor attribute, so align its device with x.
        sources_prototypes = self.sources_prototypes_all.to(x.device)[:, task_id].permute(1, 0, 2, 3)  # N, T, G, C -> B, N, G, C
        learnable_prototypes = self.learnable_prototypes_all[:, task_id].permute(1, 0, 2, 3)

        gaussian_weight, simi_S, simi_T = self.gaussian_splatted_weight(x, sources_prototypes, learnable_prototypes)
        mixed_prototypes = self.prototypes_mixup(sources_prototypes, learnable_prototypes, simi_S, simi_T)

        # gaussian-based graph attention
        gaussian_weight = gaussian_weight[:, :, None, None]
        gaussian_weight_max = gaussian_weight.max() + 0.1
        gaussian_prototypes = (gaussian_weight_max-gaussian_weight) * mixed_prototypes  # Gaussian is inverse function
        weight_atten = self.gaussian_atten(gaussian_prototypes.permute(0, 3, 1, 2))
        weight_atten = weight_atten.permute(0, 2, 3, 1)  # B, S*T, G, 1

        # shifting: x broadcasts over the S*T prototype axis, then the
        # per-prototype results are averaged.
        shifted = (1-weight_atten) * x.unsqueeze(1) + weight_atten * mixed_prototypes
        x = shifted.mean(dim=1, keepdim=False)

        repul_loss = self.repulsion_loss(x, learnable_prototypes, simi_T)

        return x, repul_loss, simi_S