# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.models.vision_transformer import PatchEmbed, Mlp, DropPath

from util.pos_embed import get_2d_sincos_pos_embed

class LayerScale(nn.Module):
    """Learnable per-channel scaling of a residual branch.

    Args:
        dim: size of the last (channel) dimension of the input.
        init_values: initial value of every entry of the scale vector.
        inplace: when True the input tensor itself is scaled and returned.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        """Multiply ``x`` by ``gamma`` (broadcast over the last dimension)."""
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
    
class Attention(nn.Module):
    """Multi-head self-attention that can also expose its intermediate tensors.

    Args:
        dim: channel width of the input tokens; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: whether the fused q/k/v projection has a bias.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        # 1 / sqrt(head_dim): the usual scaled dot-product factor
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, return_attn_map = False):
        """Apply self-attention to ``x`` of shape (B, N, C).

        Returns:
            The projected output (B, N, C).  When ``return_attn_map`` is True,
            returns ``(output, [qk_attn, x_ctxed])`` where ``qk_attn`` is a
            detached copy of the scaled q @ k^T logits (taken before softmax)
            and ``x_ctxed`` is the attention-weighted value tensor (B, N, C)
            before the output projection.
        """
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        # -> (3, B, heads, N, head_dim); unbind instead of tuple-unpacking a tensor
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

        logits = (q @ k.transpose(-2, -1)) * self.scale
        if return_attn_map:
            qk_attn = logits.clone().detach()
        weights = self.attn_drop(logits.softmax(dim=-1))

        x_ctxed = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        out = self.proj_drop(self.proj(x_ctxed))

        if not return_attn_map:
            return out
        return out, [qk_attn, x_ctxed]
    
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each on a residual path.

    Each branch is ``x + drop_path(layer_scale(f(norm(x))))``.  LayerScale is
    only instantiated when ``init_values`` is truthy and DropPath only when
    ``drop_path > 0``; otherwise the corresponding slot is an identity.
    """

    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            drop=0.,
            attn_drop=0.,
            init_values=None,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm
    ):
        super().__init__()

        # --- attention branch ---
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        if init_values:
            self.ls1 = LayerScale(dim, init_values=init_values)
        else:
            self.ls1 = nn.Identity()
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        if drop_path > 0.:
            self.drop_path1 = DropPath(drop_path)
        else:
            self.drop_path1 = nn.Identity()

        # --- MLP branch ---
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        if init_values:
            self.ls2 = LayerScale(dim, init_values=init_values)
        else:
            self.ls2 = nn.Identity()
        if drop_path > 0.:
            self.drop_path2 = DropPath(drop_path)
        else:
            self.drop_path2 = nn.Identity()

    def forward(self, x, return_attn_map = False):
        """Run the block.

        Returns ``x`` or, when ``return_attn_map`` is True, ``(x, qk_and_x)``
        where ``qk_and_x`` is the extra list produced by ``Attention.forward``
        (pre-softmax scaled q @ k^T logits and the pre-projection context).
        """
        normed = self.norm1(x)
        if return_attn_map:
            attn_out, qk_and_x = self.attn(normed, return_attn_map = True)
        else:
            attn_out = self.attn(normed)

        x = x + self.drop_path1(self.ls1(attn_out))
        mlp_out = self.mlp(self.norm2(x))
        x = x + self.drop_path2(self.ls2(mlp_out))

        if not return_attn_map:
            return x
        return x, qk_and_x
    
class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        """Build the encoder and decoder stacks and initialize their weights.

        Args:
            img_size, patch_size, in_chans, embed_dim: forwarded to ``PatchEmbed``.
            depth, num_heads: number and head count of the encoder blocks.
            decoder_embed_dim, decoder_depth, decoder_num_heads: decoder width,
                number of blocks and head count.
            mlp_ratio: hidden/width ratio of every block's MLP.
            norm_layer: normalization layer factory used throughout.
            norm_pix_loss: stored and read by ``forward_loss`` to normalize targets.
        """
        super().__init__()

        # ------------------------------ encoder ------------------------------
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        n_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # frozen; filled with a sin-cos table in initialize_weights()
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim), requires_grad=False)

        encoder_layers = [
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(depth)
        ]
        self.blocks = nn.ModuleList(encoder_layers)
        self.norm = norm_layer(embed_dim)

        # ------------------------------ decoder ------------------------------
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        # frozen; filled with a sin-cos table in initialize_weights()
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, decoder_embed_dim), requires_grad=False)

        decoder_layers = [
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(decoder_depth)
        ]
        self.decoder_blocks = nn.ModuleList(decoder_layers)
        # Example configuration (ViT-Base encoder):
        #   MaskedAutoencoderViT(
        #       patch_size=16, embed_dim=768, depth=12, num_heads=12,
        #       decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        #       mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
        self.decoder_norm = norm_layer(decoder_embed_dim)
        # maps each decoder token back to the pixels of one patch
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()
        
    def initialize_weights(self):
        """Fill the fixed positional tables and initialize all learnable weights."""
        grid_side = int(self.patch_embed.num_patches**.5)

        # encoder positional table (2D sin-cos, including the cls slot)
        enc_table = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_side, cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(enc_table).float().unsqueeze(0))

        # decoder positional table, same grid but decoder width
        dec_table = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], grid_side, cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(dec_table).float().unsqueeze(0))

        # treat the patch projection as a flat linear map (instead of nn.Conv2d) for Xavier init
        proj_weight = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(proj_weight.view([proj_weight.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        for token in (self.cls_token, self.mask_token):
            torch.nn.init.normal_(token, std=.02)

        # nn.Linear and nn.LayerNorm submodules
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]
        
        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs
    
    def return_posembed(self):
        return self.pos_embed, self.decoder_pos_embed
    
    def random_masking(self, x, mask_ratio, mask_in = None, cluster_size = None, hint_ratio = None, hint_portion = 0.5, 
                       hint_prob=False, uniform_prob=False, prob_mask=False):
        """
        Perform per-sample masking by building a per-sample token ordering
        (``ids_shuffle``) and keeping its first ``int(L * (1 - mask_ratio))`` entries.
        x: [N, L, D], sequence

        How the ordering is built:
          * ``mask_in is None``: argsort of uniform random noise (random masking).
          * ``prob_mask``: ``mask_in`` is read as per-token scores [N, L]; all L
            indices are drawn with ``torch.multinomial`` weighted by the shifted
            scores and the draw order is reversed, so tokens drawn first end up
            at the masked end.
          * ``hint_prob``: ``mask_in`` is read as per-token scores; tokens are
            ranked by descending score, ``int(hint_ratio * L)`` "hint" tokens are
            drawn from the top ``int(hint_portion * L)`` and moved to the kept end.
          * otherwise: ``mask_in`` is used directly as the ordering; when
            ``hint_ratio`` is given, ``int(hint_ratio * cluster_size)`` tokens are
            picked uniformly from the last ``int(hint_portion * L)`` entries of
            the ordering and moved to the kept end.

        Returns ``(x_masked, mask, ids_restore)``: the kept tokens
        [N, len_keep, D], a binary mask [N, L] (0 is keep, 1 is remove) in the
        original token order, and the argsort of the ordering.

        NOTE(review): the incoming ``cluster_size`` argument is never read; it is
        only overwritten locally.  The ``hint_prob`` branch requires
        ``hint_ratio`` to be set (``int(None * L)`` raises) and sizes the hint
        set relative to L, while the other branch sizes it relative to
        ``cluster_size`` -- confirm whether this difference is intended.
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))
        
        if mask_in is None:
            noise = torch.rand(N, L, device=x.device)  # noise in [0, 1], sort noise for each sample
            ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        else:
            if prob_mask:
                attn_score = mask_in
                # shift each row so every score is >= 0, then add a tiny offset so
                # every token has a strictly positive sampling weight
                min_val = attn_score.min(dim=-1)[0]
                min_val = torch.where(min_val<0, min_val, 0)
                attn_score = attn_score - min_val.unsqueeze(-1)
                attn_score = attn_score + 1e-10
                
                # draw all indices (a score-weighted permutation), then reverse it
                ids_shuffle = torch.multinomial(attn_score.float(), attn_score.shape[1])
                ids_shuffle = torch.flip(ids_shuffle, [1])
            else:
                if hint_prob:
                    attn_score = mask_in
                    ids_shuffle = torch.argsort(
                        attn_score, dim=-1, descending=False
                    ) 
                    cluster_size = int(hint_portion*L) # how large a top-scoring portion of the tokens hints may be drawn from
                    ids_shuffle = torch.flip(ids_shuffle, [1])  # now ordered by descending score
                    hint_idx = ids_shuffle[:,:cluster_size]
                    hint_cand = torch.gather(attn_score, dim=1, index = hint_idx)  # scores of the candidates
                    
                    hint_num = int(hint_ratio * L)
                    # same non-negativity shift as above, applied to the candidate scores
                    min_val = hint_cand.min(dim=-1)[0] # new_B
                    min_val = torch.where(min_val<0, min_val, 0)
                    hint_cand = hint_cand - min_val.unsqueeze(-1)
                    hint_cand = hint_cand + 1e-10
                    
                    if uniform_prob:
                        # ignore the scores: every candidate is equally likely
                        hint_cand = torch.ones_like(hint_cand, device = hint_cand.device).float()
                    # positions (within the candidate window) in sampled order
                    rand_order = torch.multinomial(hint_cand, cluster_size)
                    picked_idx, else_idx = rand_order[:,:hint_num], rand_order[:,hint_num:]
                    
                    picked_tokens = torch.gather(ids_shuffle, dim=1, index = picked_idx)
                    else_tokens = torch.gather(ids_shuffle, dim=1, index = else_idx)
                    # picked hints go last here and therefore first after the flip (kept side)
                    ids_shuffle = torch.cat([else_tokens, ids_shuffle[:,cluster_size:], picked_tokens], dim=-1)
                    ids_shuffle = torch.flip(ids_shuffle, [1])
                else:
                    ids_shuffle = mask_in
                    if hint_ratio is not None:
                        cluster_size = int(hint_portion*L) # how large a portion of the tokens (from the masked end) hints may be drawn from
                        hint_num = int(hint_ratio * cluster_size)
                        # uniform random permutation of the candidate positions
                        rand_order = torch.rand(N, cluster_size, device=ids_shuffle.device).argsort(dim=1)
                        picked_idx, else_idx = rand_order[:,:hint_num], rand_order[:,hint_num:]
                        
                        # reverse so the candidates (tail of the ordering) sit in the first columns
                        ids_shuffle = torch.flip(ids_shuffle, [1])
                        picked_tokens = torch.gather(ids_shuffle, dim=1, index = picked_idx)
                        else_tokens = torch.gather(ids_shuffle, dim=1, index = else_idx)
                        # picked hints go last here and therefore first after the flip (kept side)
                        ids_shuffle = torch.cat([else_tokens, ids_shuffle[:,cluster_size:], picked_tokens], dim=-1)
                        ids_shuffle = torch.flip(ids_shuffle, [1])
        
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        # keep the first subset
        ids_shuffle = ids_shuffle.type(torch.int64)  # gather requires int64 indices
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio, return_attn_map = False, mask_in = None, cluster_size = None, 
                        hint_ratio = None, hint_portion = 0.5, hint_prob=False, uniform_prob=False, prob_mask=False):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio, mask_in, cluster_size = cluster_size, hint_ratio = hint_ratio, 
                                                   hint_portion = hint_portion, hint_prob=hint_prob, uniform_prob=uniform_prob, prob_mask=prob_mask)
        # x, mask, ids_restore = self.random_masking(x, mask_ratio, mask_in, cluster_size = cluster_size, hint_token_num = hint_token_num)
        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        # apply Transformer blocks
        if return_attn_map:
            attn_map = []
            for idx, blk in enumerate(self.blocks):
                x, attn = blk(x, True)
                attn_map.append(attn[1])
        else:
            for blk in self.blocks:
                x = blk(x, False)
                
        x = self.norm(x)

        if return_attn_map:
            return x, mask, ids_restore, attn_map
        return x, mask, ids_restore
    
    def forward_encoder_inference(self, x, target_attn, mask_ratio = 0.75, ref_cluster = 'large', return_score=False, 
                                  get_mask_stat = False, get_feat = False, force_flip=False):
        with torch.no_grad():
            x = self.patch_embed(x)
            x = x + self.pos_embed[:, 1:, :]
            cls_token = self.cls_token + self.pos_embed[:, :1, :]
            cls_tokens = cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)
            
            for i, blk in enumerate(self.blocks):
                if i == target_attn:
                    x, attn = blk(x, True)
                    _, attn = attn
                    break
                else:
                    x, _ = blk(x, True)
            
            if target_attn > 11.0:
                print('Generate mask from decoder')
                x = self.norm(x)
                x = self.decoder_embed(x)
                x = x + self.decoder_pos_embed
                for i, blk in enumerate(self.decoder_blocks):
                    x, attn = blk(x, True)
                _, attn = attn
                    
            attn = attn[:,1:,:] # B 1+N D -> B N D
            if get_feat:
                return attn
            
            a_mat = self.get_affinity_mat(attn)
            if get_mask_stat:
                return self.get_graph_cut_stat(a_mat, attn)
            new_ids_shuffle, ref_cluster_size = self.graph_cut(a_mat, 
                            attn, ref_cluster = ref_cluster, return_score=return_score, force_flip=force_flip)
            return new_ids_shuffle, ref_cluster_size
                    
    def return_attn_score_KLD(self, imgs, mask_ratio = 0.6):
        # Encoder output, [blk_num B N+1 C]
        x = self.patch_embed(imgs)
        x = x + self.pos_embed[:, 1:, :]
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        enc_qk = []
        enc_contextualized = []
        for blk in self.blocks:
            x, attn = blk(x, True) # attn = [qk_attn, contextualized x (=before mlp)][1]
            enc_qk.append(attn[0].mean(1)[:,1:,1:]) # B H 1+196 1+196 -> B 196 196
            enc_contextualized.append(attn[1][:,1:,:]) # B 1+196 786 -> B 196 786
        enc_contextualized = torch.stack(enc_contextualized, dim=0) # blk B 196 786
        enc_qk = torch.stack(enc_qk, dim=0) # blk B 196 196
        
        # Decoder output, [blk_num B N+1 C]
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio, return_attn_map=False)
        dec_qk, dec_contextualized = self.forward_decoder(latent, ids_restore, return_attn_map=True)  
        dec_contextualized = torch.stack(dec_contextualized, dim=0) # blk B 1+196 786
        dec_contextualized = dec_contextualized[:,1:,:] # blk B 1+196 786 -> blk B 196 786
        dec_qk = torch.stack(dec_qk, dim=0) # blk B H 1+196 1+196
        dec_qk = dec_qk.mean(2)[:,:,1:,1:] # blk B H 1+196 1+196 -> blk B 196 196

        # Construct affinity matrix, [blk_num B N C] -> [blk_num B N N]
        enc_affinity = self.get_affinity_mat(enc_contextualized)
        dec_affinity = self.get_affinity_mat(dec_contextualized)
        
        # Affinity matrix to attention map, [blk_num B N N] -> [blk_num B N]
        enc_attn = torch.sum(enc_affinity, dim=-1).squeeze() # blk_num B 196
        dec_attn = torch.sum(dec_affinity, dim=-1).squeeze() # blk_num B 196
        enc_qk = enc_qk.sum(-1) # blk_num B 196
        dec_qk = dec_qk.sum(-1) # blk_num B 196
        
        return [enc_attn, dec_attn, enc_qk, dec_qk]
    
    def fourier_return(self, imgs):
        x = self.patch_embed(imgs)
        x = x + self.pos_embed[:, 1:, :]
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        attn_map = []
        attn_map.append(x)
        for blk in self.blocks:
            x, attn = blk(x, True)
            attn_map.append(attn[1])
            attn_map.append(x)
            
        x = self.norm(x)
        x = self.decoder_embed(x)
        x = x + self.decoder_pos_embed
        for blk in self.decoder_blocks:
            x, attn = blk(x, True)
            attn_map.append(attn[1])
            attn_map.append(x)
        return attn_map
        
    def return_attn_score(self, imgs, mask_ratio=0.6, from_where = 'inmae', rnd_cls = None, return_mask=False, do_norm=False):
        assert from_where in ['encoder', 'decoder', 'inmae'], 'Check from_where argument.'
        if from_where in ['encoder', 'decoder']:
            x = self.patch_embed(imgs)
            x = x + self.pos_embed[:, 1:, :]
            if rnd_cls is not None:
                if rnd_cls == 'rnd':
                    print('Random normal CLS')
                    new_cls = torch.rand(1,1,768)
                elif rnd_cls == 'rnd_same_dist':
                    print('Random normal CLS w/ same dist')
                    mean, std = self.cls_token.squeeze().mean().item(), self.cls_token.squeeze().std().item()
                    new_cls = torch.normal(mean, std, size=(1, 1, 768))
                    print(mean, std)
                elif rnd_cls == 'one':
                    print('Ones CLS')
                    new_cls = torch.ones(1,1,768) * self.cls_token.squeeze().mean().item()
                elif rnd_cls == 'zero':
                    print('Zeros CLS')
                    new_cls = torch.zeros(1,1,768)
                cls_token = new_cls + self.pos_embed[:, :1, :]
            else:
                # print('Learned CLS')
                cls_token = self.cls_token + self.pos_embed[:, :1, :]
            cls_tokens = cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

            attn_map = []
            for blk in self.blocks:
                x, attn = blk(x, True)
                attn_map.append([x] + attn) # [final x, qk_attn, contextualized x (=before mlp)]
            
            if from_where == 'encoder':
                return attn_map
            
            x = self.norm(x)
            x = self.decoder_embed(x)
            x = x + self.decoder_pos_embed
            for blk in self.decoder_blocks:
                x, attn = blk(x, True)
                attn_map.append([x] + attn) # [final x, qk_attn, contextualized x (=before mlp)]
            return attn_map
        
        elif from_where == 'inmae':
            latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio, return_attn_map=False)
            attn_map = self.forward_decoder(latent, ids_restore, return_attn_map=True)  # [N, L, p*p*3]
            if return_mask: return mask, attn_map
            return attn_map
        
    def forward_decoder(self, x, ids_restore, return_attn_map = False, is_informed_mask  = False):
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        if return_attn_map:
            assert is_informed_mask is False, 'return_attn_map is set to True. Check it.'
            attn_map = []
            for blk in self.decoder_blocks:
                x, attn = blk(x, True)
                attn_map.append([x] + attn)
            return attn_map
        else:
            for blk in self.decoder_blocks:
                x = blk(x, False)
            
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]
        return x

    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove, 
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss
    
    #######################################################
    def l1norm(self, prev, cur):
        # B 196
        B, _ = cur.shape
        if prev is None:
            val = 196 * torch.ones(B, device = cur.device)
            return val
        return torch.sum(abs(cur-prev), dim=-1).squeeze() # B
    
    def ids_shuffle_to_mask(self, ids_shuffle, mask_ratio):
        if ids_shuffle is None:
            return None
        N, L = ids_shuffle.shape  # batch, length
        len_keep = int(L * (1 - mask_ratio))
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        
        mask = torch.ones([N, L], device=ids_shuffle.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)
        return mask # B L
    
    def get_graph_cut_stat(self, a_mat, ctxd_x, tau = 0.2):
        """Report, for four reference-cluster criteria, whether the bipartition would be flipped.

        The affinity matrix is thresholded (entries <= ``tau`` become a small
        epsilon), the graph Laplacian ``D - A`` is formed, and tokens are split
        by comparing the eigenvector at index 1 of ``torch.linalg.eigh`` with
        its mean.  For each of 'large', 'small', 'eigen' and 'complexity' a
        per-sample boolean ``need_flip`` is computed.

        a_mat: B N N affinity matrix
        ctxd_x: B N D (only its shape is read)

        Returns the four ``need_flip`` vectors stacked along dim 0, in the order
        listed above.

        NOTE(review): 'small' uses ``>= N//2`` here but ``> N//2`` in
        ``graph_cut``; ``need_flip.squeeze()`` yields a 0-dim tensor for a batch
        of one; the 'complexity' means divide by zero when a cluster has fewer
        than two members.
        """
        eps = 1e-7
        B, N, _ = ctxd_x.shape
        
        a_mat = torch.where(a_mat <= tau, eps, a_mat) # B N N
        d_i = torch.sum(a_mat, axis=-1) # B N  (node degrees)
        D = torch.diag_embed(d_i) # B N N
        L = D - a_mat # B N N  (unnormalized graph Laplacian)
        
        _, eigenvectors = torch.linalg.eigh(L)
        second_vec = eigenvectors[:, :, 1] # B N
        avg = torch.sum(second_vec, dim=-1) / second_vec.shape[1] # B
        bipartition = second_vec > avg.unsqueeze(1) # B N > B 1
        
        stat = []
        for ref_cluster in ['large', 'small', 'eigen', 'complexity']:
            if ref_cluster == 'large':
                # flip when the True side holds fewer than half of the tokens
                need_flip = bipartition.long().sum(dim=-1) < N//2
            elif ref_cluster == 'small':
                # flip when the True side holds at least half of the tokens
                need_flip = bipartition.long().sum(dim=-1) >= N//2
            elif ref_cluster == 'eigen':
                # flip when the token with the largest |eigenvector entry| is on the False side
                max_element = torch.argmax(torch.abs(second_vec), dim=-1)
                max_element_cluster = torch.gather(bipartition, 1, max_element.unsqueeze(-1))
                need_flip = max_element_cluster != 1 # B 1
                need_flip = need_flip.squeeze(1)
            elif ref_cluster == 'complexity':
                num_cluster_1 = bipartition.sum(dim=-1)
                num_cluster_2 = (~bipartition).sum(dim=-1)
                
                # mean off-diagonal affinity among the tokens of the True side
                cluster_1 = a_mat * bipartition.unsqueeze(-1) * bipartition.unsqueeze(-2)
                cluster_1_diag = torch.diagonal(cluster_1, dim1=-2, dim2=-1).sum(-1)
                sum_cluster_1 = (cluster_1.sum(-1).sum(-1) - cluster_1_diag) / 2 # B
                mean_cluster_1 = sum_cluster_1 / (num_cluster_1*(num_cluster_1 -1) / 2) # B
                
                # same statistic for the False side
                cluster_2 = a_mat * ~bipartition.unsqueeze(-1) * ~bipartition.unsqueeze(-2)
                cluster_2_diag = torch.diagonal(cluster_2, dim1=-2, dim2=-1).sum(-1)
                sum_cluster_2 = (cluster_2.sum(-1).sum(-1) - cluster_2_diag) / 2 # B
                mean_cluster_2 = sum_cluster_2 / (num_cluster_2*(num_cluster_2 -1) / 2) # B
            
                need_flip = mean_cluster_1 > mean_cluster_2
            stat.append(need_flip.squeeze())
        return torch.stack(stat, dim=0)
    
    def graph_cut(self, a_mat, ctxd_x, tau = 0.2, ref_cluster = 'large', return_score=False, force_flip=False):
        """Bipartition the tokens spectrally and rank them by similarity to one side.

        Steps: threshold the affinity matrix (entries <= ``tau`` become a small
        epsilon), build the Laplacian ``D - A``, split tokens by comparing the
        eigenvector at index 1 of ``torch.linalg.eigh`` with its mean, choose
        which side is the reference cluster according to ``ref_cluster``, take
        the mean feature of that cluster and score every token by its dot
        product with that mean.

        a_mat: B N N affinity matrix
        ctxd_x: B N D token features
        ref_cluster: 'large', 'small', 'eigen', 'complexity_high' or 'complexity_low'
        return_score: return the raw scores instead of their ascending argsort
        force_flip: inverts the flip decision, honoured only for 'small'

        Returns ``(ids_shuffle or attn_score, ref_cluster_patchnum.squeeze())``;
        ``ids_shuffle`` orders tokens by ascending score.

        NOTE(review): the name ``ref_cluster`` is rebound from the option string
        to the masked feature tensor near the end of the function.
        """
        assert ref_cluster in ['small', 'large', 'eigen', 'complexity_high', 'complexity_low'], 'Incorrect ref_cluster'
        eps = 1e-7
        B, N, _ = ctxd_x.shape
        
        a_mat = torch.where(a_mat <= tau, eps, a_mat) # B N N
        d_i = torch.sum(a_mat, axis=-1) # B N  (node degrees)
        D = torch.diag_embed(d_i) # B N N
        L = D - a_mat # B N N  (unnormalized graph Laplacian)
        
        _, eigenvectors = torch.linalg.eigh(L)
        second_vec = eigenvectors[:, :, 1] # B N
        avg = torch.sum(second_vec, dim=-1) / second_vec.shape[1] # B
        bipartition = second_vec > avg.unsqueeze(1) # B N > B 1
        
        # need_flip marks samples whose True/False sides must be swapped so that
        # True denotes the reference cluster
        if ref_cluster == 'large':
            need_flip = bipartition.long().sum(dim=-1) < N//2
        elif ref_cluster == 'small':
            need_flip = bipartition.long().sum(dim=-1) > N//2
            if force_flip: need_flip = ~need_flip
        elif ref_cluster == 'eigen':
            # reference side = the side holding the token with the largest |eigenvector entry|
            max_element = torch.argmax(torch.abs(second_vec), dim=-1)
            max_element_cluster = torch.gather(bipartition, 1, max_element.unsqueeze(-1))
            need_flip = max_element_cluster != 1 # B 1
            need_flip = need_flip.squeeze(1)
        elif 'complexity' in ref_cluster:
            num_cluster_1 = bipartition.sum(dim=-1)
            num_cluster_2 = (~bipartition).sum(dim=-1)
            
            # samples where one side has at most one token get a fixed replacement
            # partition with exactly two True tokens (positions N//2 and N//2 + 1)
            invalid_rows = (num_cluster_1 <= 1) | (num_cluster_2 <= 1)
            corrected_row = torch.zeros_like(bipartition[0], dtype=torch.bool)
            N = bipartition.shape[1]
            corrected_row[N//2], corrected_row[N//2 + 1] = True, True
            bipartition[invalid_rows] = corrected_row
            
            num_cluster_1 = bipartition.sum(dim=-1)
            num_cluster_2 = (~bipartition).sum(dim=-1)
            
            # per-sample variance of the non-zero affinities inside the True side
            cluster_1 = a_mat * bipartition.unsqueeze(-1) * bipartition.unsqueeze(-2)
            mask1= cluster_1 != 0
            cluster1_var = torch.stack([t[m].var() for t, m in zip(cluster_1, mask1)])
            
            # same statistic for the False side
            cluster_2 = a_mat * ~bipartition.unsqueeze(-1) * ~bipartition.unsqueeze(-2)
            mask2= cluster_2 != 0
            cluster2_var = torch.stack([t[m].var() for t, m in zip(cluster_2, mask2)])
            
            if 'high' in ref_cluster:
                need_flip = cluster1_var > cluster2_var
            else:
                need_flip = cluster1_var <= cluster2_var
            
        # apply the flip: keep the partition where need_flip is False, invert it otherwise
        bipartition = bipartition * (~need_flip).unsqueeze(1) +\
            torch.logical_not(bipartition) * need_flip.unsqueeze(1)
        
        if ref_cluster != 'large':
            # guard against an empty reference cluster: such samples get a
            # single reference token at position N//2
            zero_cluster = torch.sum(bipartition, dim=1) < 1
            if torch.sum(zero_cluster) > 0:
                zero_cluster = zero_cluster.reshape(B,1)
                z = torch.zeros_like(bipartition).long()
                z[:,N//2] += 1
                bipartition = bipartition * torch.logical_not(zero_cluster) + \
                    z * zero_cluster

        # mean feature over the tokens of the reference cluster
        ref_cluster = ctxd_x * bipartition.unsqueeze(-1) # [B N D] * [B N 1]
        ref_cluster_patchnum = torch.sum(bipartition.long(), dim=-1) # B
        ref_patch = ref_cluster.sum(dim=1)/ref_cluster_patchnum.unsqueeze(-1) # [B D] / [B 1]
    
        # score of each token = dot product with the reference mean
        attn_score = ref_patch.unsqueeze(1) @ ctxd_x.transpose(-2,-1) # [B 1 D] @ [B D N] = B N
        attn_score = attn_score.squeeze().reshape(B, N) # B N
        if return_score:
            return attn_score, ref_cluster_patchnum.squeeze()
        ids_shuffle = torch.argsort(
            attn_score, dim=1, descending=False
        ) 
        
        return ids_shuffle, ref_cluster_patchnum.squeeze()
    
    
    def get_mask(self, to_be_updated, a_mat, ctxd_x, prev_ids_shuffle, mask_ratio, cluster_size, ref_cluster='large'):
        # to_be_updated : B 1 (boolean)
        # a_mat : B N N
        # ctxd_x : B N D
        # prev_ids_shuffle : B N
        B, N, D = ctxd_x.shape
        to_be_updated = to_be_updated.squeeze().nonzero().squeeze() # B
        if to_be_updated.sum() == 0:
            return prev_ids_shuffle, cluster_size
        
        update_sample = a_mat[to_be_updated].reshape(-1, N, N) # new_B N N
        del a_mat
        
        update_ctxd_x = ctxd_x[to_be_updated].reshape(-1, N, D) # new_B N D
        del ctxd_x
        
        new_ids_shuffle, ref_cluster_patchnum = self.graph_cut(update_sample, update_ctxd_x, mask_ratio = mask_ratio, 
                                                               ref_cluster=ref_cluster)
        
        new_ids_shuffle = new_ids_shuffle.to(prev_ids_shuffle.dtype)
        ref_cluster_patchnum = ref_cluster_patchnum.to(cluster_size.dtype)
            
        prev_ids_shuffle[to_be_updated] = new_ids_shuffle
        cluster_size[to_be_updated] = ref_cluster_patchnum
        torch.cuda.empty_cache()
        return prev_ids_shuffle, cluster_size
        
    def get_affinity_mat(self, ctxd_x): # B N D
        attn = F.normalize(ctxd_x, p=2, dim=-1)
        attn = attn @ attn.transpose(-2,-1) # B N N
        return attn
        
    def preprocess_attn(self, a_mat):
        attn = torch.sum(a_mat, dim=-1).squeeze() # B 196
        ids_shuffle = torch.argsort(attn, dim=1)
        binary = self.ids_shuffle_to_mask(ids_shuffle, 0.5)
        return ids_shuffle, binary
    
    def get_ids_shuffle(self, attn):
        # attn: B 196
        return torch.argsort(attn, dim=1)  
    #######################################################
    
    def forward(self, imgs, mask_ratio=0.75, informed_mask = False, mask_in = None, is_informed_mask_started = False, 
                hint_ratio = None, hint_portion=0.5, hint_prob=False, uniform_prob=False, prob_mask = False):
        if informed_mask and is_informed_mask_started:
            if mask_in.shape[1] > 196:
                mask_in, ref_cluster_size = mask_in[:,:196], mask_in[:,196]
            else:
                mask_in = mask_in
                ref_cluster_size = None
            # mask_in = mask_in
        else:
            mask_in, ref_cluster_size = None, None
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio, mask_in = mask_in, cluster_size = ref_cluster_size, 
                                                         hint_ratio = hint_ratio, hint_portion=hint_portion, hint_prob=hint_prob, 
                                                         uniform_prob=uniform_prob, prob_mask=prob_mask)
        # latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio, mask_in = mask_in, cluster_size = ref_cluster_size, 
        #                                                  hint_token_num =hint_token_num)
        pred = self.forward_decoder(latent, ids_restore, is_informed_mask = informed_mask)  # [N, L, p*p*3]
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask


def mae_vit_base_patch16_dec512d8b(**kwargs):
    """Patch 16, encoder 768 wide / 12 blocks / 12 heads; decoder 512 wide / 8 blocks / 16 heads."""
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)

def mae_vit_base_patch16_dec128d8b(**kwargs):
    """Patch 16, encoder 768 wide / 12 blocks / 12 heads; decoder 128 wide / 8 blocks / 16 heads."""
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=128, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)


def mae_vit_large_patch16_dec512d8b(**kwargs):
    """Patch 16, encoder 1024 wide / 24 blocks / 16 heads; decoder 512 wide / 8 blocks / 16 heads."""
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)


def mae_vit_huge_patch14_dec512d8b(**kwargs):
    """Patch 14, encoder 1280 wide / 32 blocks / 16 heads; decoder 512 wide / 8 blocks / 16 heads."""
    return MaskedAutoencoderViT(
        patch_size=14,
        embed_dim=1280, depth=32, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)


# Short aliases for the recommended architectures; each name is bound to the
# factory function on its right-hand side.
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_base_patch16_128 = mae_vit_base_patch16_dec128d8b  # decoder: 128 dim, 8 blocks
