import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import torchvision
import time
import torchist
from utils import *

'''
RRDB (Residual-in-Residual Dense Block) based network components.

Code is modified from https://github.com/xinntao/ESRGAN/blob/master/RRDBNet_arch.py

'''
def make_layer(block, n_layers):
    """Stack ``n_layers`` fresh instances of ``block`` into an ``nn.Sequential``.

    ``block`` is called with no arguments once per layer, so every layer is a
    separate module with its own parameters.
    """
    return nn.Sequential(*(block() for _ in range(n_layers)))


class ResidualDenseBlock_4C(nn.Module):
    """Four-conv residual dense block with a 0.2-scaled residual connection.

    Every conv sees the channel-wise concatenation of the block input and all
    earlier conv outputs. The first three convs emit ``gc`` channels each and
    are followed by LeakyReLU(0.2). The last conv maps back to ``nf`` channels.

    Args:
        nf: channels of the block input and output.
        gc: growth channels produced by each intermediate conv.
        bias: whether the convs carry a bias term.
        kernel: conv kernel size (int or tuple); convs use ``padding='same'``.
        padding: accepted for interface compatibility; not used.
        stride: conv stride. PyTorch rejects ``padding='same'`` for strided
            convolutions, so only 1 is usable.
    """

    def __init__(self, nf=64, gc=32, bias=True, kernel=3, padding=1, stride=1):
        super(ResidualDenseBlock_4C, self).__init__()
        shared = dict(kernel_size=kernel, padding='same', stride=stride, bias=bias)
        # Attribute names are part of the state_dict keys; keep them stable.
        self.conv1 = nn.Conv2d(nf, gc, **shared)
        self.conv2 = nn.Conv2d(nf + gc, gc, **shared)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, **shared)
        self.conv4 = nn.Conv2d(nf + 3 * gc, nf, **shared)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # Grow the dense feature list: [x, x1, x2, x3].
        dense = [x]
        for conv in (self.conv1, self.conv2, self.conv3):
            dense.append(self.lrelu(conv(torch.cat(dense, 1))))
        fused = self.conv4(torch.cat(dense, 1))
        return fused * 0.2 + x


class RRDB(nn.Module):
    '''Residual in Residual Dense Block.

    Three ResidualDenseBlock_4C modules (growth channels ``nf // 2``) applied in
    sequence, wrapped in an outer residual connection scaled by 0.2.
    ``kernel``, ``padding`` and ``stride`` are forwarded to every inner block.
    '''

    def __init__(self, nf, kernel=3, padding=1, stride=1):
        super(RRDB, self).__init__()
        growth = nf // 2
        # Registered in order RDB1, RDB2, RDB3 (names are state_dict keys).
        for name in ('RDB1', 'RDB2', 'RDB3'):
            setattr(self, name, ResidualDenseBlock_4C(nf, growth, kernel=kernel, padding=padding, stride=stride))

    def forward(self, x):
        h = x
        for rdb in (self.RDB1, self.RDB2, self.RDB3):
            h = rdb(h)
        return h * 0.2 + x

class ConvBlock(nn.Module):
    """Single convolution => [BatchNorm] => ReLU.

    BatchNorm is inserted when ``config['use_bn']`` is truthy, or when
    ``motion`` is True and ``config['motion_use_bn']`` is truthy. The conv has
    no bias in either variant. ``kernel`` and ``dilation`` are honoured in both
    variants, and the padding keeps the spatial size for odd kernels. For the
    default ``kernel=3`` the padding equals the previous value: 1 with BN,
    ``dilation`` without.

    Args:
        config: dict; reads ``'use_bn'`` and ``'motion_use_bn'``.
        in_channels: input channels.
        out_channels: output channels.
        kernel: conv kernel size, int or ``(kh, kw)``.
        mid_channels: unused; kept for interface compatibility.
        bn: unused (BatchNorm is controlled by ``config``); kept for compatibility.
        motion: enables ``config['motion_use_bn']`` as an extra BN switch.
        dilation: conv dilation (int).
    """

    def __init__(self, config, in_channels, out_channels, kernel=3, mid_channels=None, bn=True, motion=False, dilation=1):
        super().__init__()
        self.config = config
        use_bn = config['use_bn'] or (config['motion_use_bn'] and motion)

        kh, kw = (kernel, kernel) if isinstance(kernel, int) else tuple(kernel)
        # Size-preserving padding for odd kernels, accounting for dilation.
        padding = (dilation * (kh // 2), dilation * (kw // 2))
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=(kh, kw),
                         padding=padding, dilation=dilation, bias=False)

        if use_bn:
            self.conv = nn.Sequential(conv, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True))
        else:
            self.conv = nn.Sequential(conv, nn.ReLU(inplace=True))

    def forward(self, x):
        return self.conv(x)

class ResBlock(nn.Module):
    """Two convolutions with a residual shortcut, optionally down/up-sampling.

    Modes, checked in this order:

    * ``downsample``: ``conv1`` is a stride-2 3x3 conv and the shortcut is a
      stride-2 1x1 conv (followed by BatchNorm when ``config['use_bn']``).
      When ``factor == 4`` a 2x2 max-pool after the residual sum halves the
      size again. Any other ``factor`` behaves like 2.
    * ``upsample``: ``conv1`` is a ConvTranspose2d with kernel/stride
      ``factor``. The shortcut is nearest upsampling by ``factor`` plus a 1x1
      conv (followed by BatchNorm when ``config['use_bn']`` or ``motion``).
    * otherwise: stride-1 3x3 ``conv1`` and an identity shortcut, which needs
      ``in_channels == out_channels``.

    ``conv2`` is always a 3x3 stride-1 conv. Each conv is followed by an
    optional BatchNorm (``config['use_bn']``) and ReLU. The sum with the
    shortcut goes through LeakyReLU with PyTorch's default slope. ``skip`` is
    unused.
    """

    def __init__(self, config, in_channels, out_channels, downsample=False, upsample=False, skip=False, factor=2, motion=False):
        super().__init__()
        self.upsample = upsample
        self.config = config
        self.maxpool = None
        use_bn = config['use_bn']

        # Module creation order (conv1, shortcut, maxpool, conv2, bn1, bn2)
        # is kept so state_dict keys and random init match.
        if downsample:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
            projection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2)
            self.shortcut = nn.Sequential(projection, nn.BatchNorm2d(out_channels)) if use_bn else projection
            if factor == 4:
                self.maxpool = nn.MaxPool2d(2)
        elif upsample:
            self.conv1 = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=factor, stride=factor)
            branch = [
                nn.Upsample(scale_factor=factor, mode='nearest'),
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1),
            ]
            if use_bn or motion:
                branch.append(nn.BatchNorm2d(out_channels))
            self.shortcut = nn.Sequential(*branch)
        else:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            self.shortcut = None

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if use_bn:
            self.bn1 = nn.BatchNorm2d(out_channels)
            self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, input):
        residual = input.clone() if self.shortcut is None else self.shortcut(input)
        use_bn = self.config['use_bn']

        h = self.conv1(input)
        if use_bn:
            h = self.bn1(h)
        h = F.relu(h)

        h = self.conv2(h)
        if use_bn:
            h = self.bn2(h)
        h = F.relu(h) + residual

        if self.maxpool is not None:
            h = self.maxpool(h)
        return F.leaky_relu(h)


class Up(nn.Module):
    """Upsample ``x1``, optionally concatenate a skip tensor ``x2``, then apply a ConvBlock.

    ``bilinear=True`` upsamples with ``nn.Upsample`` (channels unchanged).
    Otherwise a ConvTranspose2d maps ``in_channels -> out_channels``.

    Flag naming is inverted relative to intuition: ``skip=True`` builds the
    ConvBlock for the *no-concatenation* case. Its input width equals the
    upsampled width, so ``forward`` must then be called with ``x2=None``.
    ``skip=False`` doubles the ConvBlock input width to accept ``cat([x2, x1])``.
    """

    def __init__(self, config, in_channels, out_channels, bilinear=True, skip=True, scale=2, bn=True, motion=False):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=scale, mode='bilinear', align_corners=True)
            if skip:
                self.conv = ConvBlock(config, in_channels, out_channels, bn=bn)
            else:
                self.conv = ConvBlock(config, in_channels * 2, out_channels)
        else:
            # The transposed conv is created before the ConvBlock, as before,
            # so random initialisation order is unchanged.
            self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=scale, stride=scale)
            conv_in = out_channels if skip else out_channels * 2
            self.conv = ConvBlock(config, conv_in, out_channels, bn=bn, motion=motion)

    def forward(self, x1, x2=None):
        x1 = self.up(x1)
        if x2 is None:
            return self.conv(x1)

        # Tensors are (N, C, H, W). Pad the upsampled map so it matches x2 spatially.
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        dh = x2.size()[2] - x1.size()[2]
        dw = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
        return self.conv(torch.cat([x2, x1], dim=1))


class OutConv(nn.Module):
    """Pointwise (1x1) convolution projecting ``in_channels`` to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projection = self.conv
        return projection(x)


class RRDBEncoder(nn.Module):
    """Multi-scale encoder: stem conv, RRDB trunk, then a ladder of strided ResBlocks.

    Channel width starts at ``base_channel`` (times 4 when
    ``config['feat_shrink']``) and doubles at each downsampling step.
    ``forward`` returns one feature map per scale, finest first.
    ``len(downsample_scale)`` downsampling steps give
    ``len(downsample_scale) + 1`` maps. Leading entries beyond the last
    ``config['scale_in_use']`` are replaced by ``None``.
    """

    def __init__(self, config):
        super(RRDBEncoder, self).__init__()

        width = config['base_channel'] * 4 if config['feat_shrink'] else config['base_channel']
        config['bilinear']  # unused here; lookup kept so configs without the key fail as before
        scales = config['downsample_scale']
        self.scale_num = len(scales)
        n = self.scale_num
        in_ch = config['n_channel'] * (config['shuffle_scale'] ** 2)
        self.config = config

        # Creation order (inconv, trunk, pre-downsample, high, low, shrink) is
        # preserved so state_dict keys and random init match.
        self.inconv = nn.Conv2d(in_ch, width, 3, 1, 1)
        self.block_rrdb = nn.Sequential(
            *[RRDB(width, kernel=(1, 3)) for _ in range(config['rrdb_encoder_num'])]
        )

        self.pre_downsample_block = nn.ModuleList(
            ResBlock(config, width * 2 ** k, width * 2 ** (k + 1), downsample=True, factor=scales[k])
            for k in range(n - 2)
        )

        if n >= 2:
            self.downsample_high = ResBlock(config, width * 2 ** (n - 2), width * 2 ** (n - 1),
                                            downsample=True, factor=scales[-2])
        self.downsample_low = ResBlock(config, width * 2 ** (n - 1), width * 2 ** n,
                                       downsample=True, factor=scales[-1])
        if config['feat_shrink']:
            deepest = width * 2 ** n
            self.shrink_layer = nn.Sequential(
                nn.Conv2d(deepest, deepest // 4, 1, 1, 0),
                nn.LeakyReLU()
            )

    def forward(self, x, save_all=False):
        """Encode ``x``; ``save_all`` is accepted but unused."""
        feats = []
        cur = self.block_rrdb(self.inconv(x))
        feats.append(cur.clone())

        for down in self.pre_downsample_block:
            cur = down(cur)
            feats.append(cur.clone())

        if self.scale_num >= 2:
            cur = self.downsample_high(cur)
            feats.append(cur.clone())

        low = self.downsample_low(cur)
        if self.config['feat_shrink']:
            low = self.shrink_layer(low)
        feats.append(low.clone())

        # Drop the finest scales that are not in use.
        for k in range(len(feats) - self.config['scale_in_use']):
            feats[k] = None
        return feats


class ImageEnhancer(nn.Module):
    """Conv stem + stack of RRDBs + 1x1 output conv for image refinement.

    Args:
        config: dict; reads ``'res_cat_img'`` and ``'rrdb_enhance_num'``.
        n_channels: output channels. The input has ``2 * n_channels`` channels
            when ``config['res_cat_img']`` is truthy, else ``n_channels``.
        base_channel: width of the internal feature maps.
        scale: unused. Now an immutable tuple default; it was previously a
            shared mutable list.
    """

    def __init__(self, config, n_channels=3, base_channel=16, scale=(2, 2, 2)):
        super(ImageEnhancer, self).__init__()
        self.n_channels = n_channels * 2 if config['res_cat_img'] else n_channels
        layers = [nn.Conv2d(self.n_channels, base_channel, 3, 1, 1)]
        layers.extend(RRDB(base_channel) for _ in range(config['rrdb_enhance_num']))
        self.model = nn.Sequential(*layers)
        self.outconv = OutConv(base_channel, n_channels)

    def forward(self, x):
        return self.outconv(self.model(x))

class MotionDecoder(nn.Module):
    """Coarse-to-fine decoder that predicts motion maps at every scale.

    Feature width is ``pred_base_channel + tendency_len + pos_len``. Each
    output map has ``out_edge_num * 3`` channels.

    Output heads:
      * ``deeper_MDecoder``: one shared head ``self.outc`` (conv + LeakyReLU + 1x1).
      * otherwise: one 1x1 head per level in ``self.outc_list``.
      
    Both ``forward`` paths now select heads through ``_out_head``. Previously
    ``forward`` accessed ``outc``/``outc_list`` directly and raised
    AttributeError whenever the other variant had been built.
    """

    def __init__(self, config):
        super(MotionDecoder, self).__init__()
        self.config = config
        base_channel = config['pred_base_channel'] + config['tendency_len'] + config['pos_len']
        self.scale_num = len(config['downsample_scale'])
        out_channel = config['out_edge_num'] * 3

        if self.config['high_res_only']:
            self.init = nn.Conv2d(base_channel, base_channel, 1, 1, 0)

        # Per level: two plain ResBlocks followed by an upsampling ResBlock.
        # The upsampling ResBlock uses its default factor of 2.
        self.upsample_block = nn.ModuleList(
            nn.Sequential(
                ResBlock(config, base_channel, base_channel, upsample=False, motion=True),
                ResBlock(config, base_channel, base_channel, upsample=False, motion=True),
                ResBlock(config, base_channel, base_channel, upsample=True, motion=True)
            )
            for _ in range(self.scale_num)
        )

        if config['deeper_MDecoder']:
            self.outc = nn.Sequential(
                nn.Conv2d(base_channel, base_channel, 3, 1, 1),
                nn.LeakyReLU(),
                OutConv(base_channel, out_channel)
            )
        else:
            self.outc_list = nn.ModuleList(OutConv(base_channel, out_channel) for _ in range(self.scale_num))

    def _out_head(self, idx):
        """Return the output head for level ``idx`` (shared ``outc`` when deeper_MDecoder)."""
        if self.config['deeper_MDecoder']:
            return self.outc
        return self.outc_list[idx]

    def upscale_flow(self, flow, shape, scale):
        """Resize ``flow`` to spatial ``shape`` and rescale its displacement channels.

        Uses ``F.interpolate`` with its default (nearest) mode. Channels are
        grouped in consecutive triples, and the first two of every triple are
        multiplied by ``scale``. Presumably each triple is (dx, dy, extra)
        per edge; TODO confirm.
        """
        motion_up = F.interpolate(flow, shape)
        b, c, h, w = motion_up.shape
        motion_up = motion_up.reshape(b, c // 3, 3, h, w)
        motion_up[:, :, :2] *= scale
        motion_up = motion_up.reshape(b, c, h, w)
        return motion_up

    def forward(self, in_feat):
        """Decode the coarsest feature map ``in_feat`` into a list of motion maps.

        Returns one map per level plus a final full-resolution map. All are
        resized to the final resolution via ``upscale_flow`` and ordered
        coarsest-first, with the final map last.
        """
        x = in_feat
        logits_list = []
        resize_logits_list = []
        scales = self.config['downsample_scale']

        for i in range(self.scale_num - 1, -1, -1):
            if self.config['high_res_only']:
                if i == 0:
                    x = self.init(x)
                cur_flow = self._out_head(i)(F.interpolate(x.clone(), scale_factor=2 ** (i + 1)))
                logits_list.append(cur_flow)
            elif i < self.scale_num - 1:
                # Refine the upscaled previous-level prediction with a residual.
                # NOTE(review): the flow is rescaled by scales[i - 1] (wraps to
                # scales[-1] when i == 0), while upsample_block always upsamples
                # by 2. This only agrees when all scales are 2 -- confirm intent.
                motion_up = self.upscale_flow(logits_list[-1], x.shape[-2:], scales[i - 1])
                logits_list.append(motion_up + self._out_head(i)(x.clone()))
            else:
                logits_list.append(self._out_head(i)(x.clone()))

            x = self.upsample_block[i](x)

        # Final level. The loop ends with i == 0, so the former scales[i - 1] is scales[-1].
        # NOTE(review): this reuses head index -1 (the coarsest level's head in
        # outc_list); there is no dedicated head for the final resolution.
        motion_up = self.upscale_flow(logits_list[-1], x.shape[-2:], scales[-1])
        logits_list.append(motion_up + self._out_head(-1)(x))

        h, w = logits_list[-1].shape[-2:]
        for cur in logits_list[:-1]:
            cur_h = cur.shape[-2]
            resize_logits_list.append(self.upscale_flow(cur, (h, w), h // cur_h).clone())
        resize_logits_list.append(logits_list[-1].clone())

        return resize_logits_list


class HighResDecoder(nn.Module):
    """Fuse same-resolution feature maps by channel concatenation and a conv funnel.

    With ``S = len(config['downsample_scale']) + 1`` there are ``S + 1`` conv
    stages (3x3 + LeakyReLU). Widths go ``base*S -> base*S -> base*(S-1) ->
    ... -> base -> base``. The concatenated input must therefore have
    ``base_channel * S`` channels. Unless ``no_output``, a 1x1 conv maps the
    result to ``n_channel * shuffle_scale**2`` channels.
    """

    def __init__(self, config, no_output=False):
        super(HighResDecoder, self).__init__()
        self.config = config
        self.no_output = no_output
        width = config['base_channel']
        config['bilinear']  # unused here; lookup kept so configs without the key fail as before
        self.scale_num = len(config['downsample_scale']) + 1
        out_ch = config['n_channel'] * (config['shuffle_scale'] ** 2)

        stages = nn.ModuleList()
        in_ch = width * self.scale_num
        for k in range(self.scale_num):
            next_ch = width * (self.scale_num - k)
            stages.append(nn.Sequential(nn.Conv2d(in_ch, next_ch, 3, 1, 1), nn.LeakyReLU()))
            in_ch = next_ch
        stages.append(nn.Sequential(nn.Conv2d(width, width, 3, 1, 1), nn.LeakyReLU()))
        self.fuse_block = stages

        if not self.no_output:
            self.outc = OutConv(width, out_ch)

    def forward(self, in_feat):
        fused = torch.cat(in_feat, dim=1)
        for stage in self.fuse_block:
            fused = stage(fused)
        return fused if self.no_output else self.outc(fused)


class RRDBDecoder(nn.Module):
    """U-Net-style decoder: Up blocks from coarse to fine, then an RRDB trunk.

    ``forward`` expects a list of encoder features, finest first (entries may
    be ``None`` for scales that are not in use). When ``config['res_cat_img']``
    is set, an extra image tensor is prepended. It is concatenated with the
    output logits and refined by ``self.fuse``. The caller's list is no longer
    mutated; previously ``pop(0)`` removed the image from it.

    ``conv_fuse_list`` is built when ``config['conv_fuse']`` is set, but
    ``forward`` does not use it.
    """

    def __init__(self, config, no_output=False):
        super(RRDBDecoder, self).__init__()
        self.config = config
        self.no_output = no_output
        base_channel = config['base_channel']
        bilinear = config['bilinear']
        scale = config['downsample_scale']
        self.scale_num = len(scale)
        out_channel = config['n_channel'] * (config['shuffle_scale'] ** 2)

        upsample_block_list = []
        for i in range(0, self.scale_num, 1):
            # skip=True builds an Up without concatenation. These are exactly
            # the levels whose encoder feature is None (outside scale_in_use).
            skip = (self.scale_num - i) >= config['scale_in_use']
            upsample_block_list.append(Up(config, base_channel * (2 ** (i + 1)), base_channel * (2 ** i),
                                          bilinear, scale=scale[i], skip=skip))
        self.upsample_block = nn.ModuleList(upsample_block_list)

        if self.config['conv_fuse']:
            fuse_list = []
            for i in range(0, self.scale_num + 1):
                cur_feat_len = base_channel * (2 ** i)
                fuse_list.append(nn.Sequential(
                    ResBlock(config, cur_feat_len, cur_feat_len),
                    ResBlock(config, cur_feat_len, cur_feat_len),
                    ResBlock(config, cur_feat_len, cur_feat_len),
                ))
            self.conv_fuse_list = nn.ModuleList(fuse_list)

        layers = []
        for i in range(config['rrdb_decoder_num']):
            layers.append(RRDB(config['base_channel']))
        self.rrdb_block = nn.Sequential(*layers)
        if not self.no_output:
            self.outc = OutConv(base_channel, out_channel)
        if self.config['res_cat_img']:
            self.fuse = nn.Sequential(
                nn.Conv2d(out_channel * 2, base_channel * 2, 3, 1, 1),
                nn.LeakyReLU(),
                nn.Conv2d(base_channel * 2, base_channel, 3, 1, 1),
                nn.LeakyReLU(),
                nn.Conv2d(base_channel, out_channel, 3, 1, 1),
            )

    def forward(self, in_feat):
        if self.config['res_cat_img']:
            # Split off the residual image without mutating the caller's list.
            res_img, in_feat = in_feat[0], list(in_feat[1:])

        x = in_feat[-1].clone()
        for i in range(self.scale_num - 1, -1, -1):
            x = self.upsample_block[i](x, in_feat[i])

        x = self.rrdb_block(x)
        if self.no_output:
            return x

        logits = self.outc(x)
        if self.config['res_cat_img']:
            logits = self.fuse(torch.cat([logits, res_img], dim=1))
        return logits

class PostGAT(nn.Module):
    """Per-scale conv + BN + LeakyReLU that doubles channels of the in-use features.

    One block is built for each of the ``config['scale_in_use']`` coarsest
    levels. Level ``l`` has ``base_channel * 2**l`` channels. ``forward``
    passes ``None`` entries through unchanged and applies the blocks, in
    order, to the non-``None`` entries.
    """

    def __init__(self, config):
        super(PostGAT, self).__init__()

        width = config['base_channel']
        config['bilinear']  # unused here; lookup kept so configs without the key fail as before
        self.scale_num = len(config['downsample_scale'])

        used = config['scale_in_use']
        first_level = self.scale_num + 1 - used
        blocks = []
        for level in range(first_level, first_level + used):
            ch = width * (2 ** level)
            blocks.append(nn.Sequential(
                nn.Conv2d(ch, ch * 2, 3, 1, 1),
                nn.BatchNorm2d(ch * 2),
                nn.LeakyReLU(),
            ))
        self.process_block_list = nn.ModuleList(blocks)

    def forward(self, in_feat):
        out = []
        processed = 0
        for feat in in_feat:
            if feat is None:
                out.append(None)
                continue
            out.append(self.process_block_list[processed](feat))
            processed += 1
        return out

#---------------------------------#

class SpatialAtt(nn.Module):
    """Spatial message passing as a 3x3 conv + BN + LeakyReLU on the node grid.

    Node features of shape (B, T, HW, C) are laid out as a
    ``config['mat_size'][-1]`` (H x W) image per frame, convolved, and
    flattened back. ``img_feat``, ``edge_type``, ``edge``, ``weight`` and
    ``debug`` are accepted for interface parity with GraphAtt but unused.
    """

    def __init__(self, config, img_feat=False, edge_type='spatial'):
        super(SpatialAtt, self).__init__()
        self.config = config
        channels = config['pred_base_channel'] + config['tendency_len'] + config['pos_len']
        self.net = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.BatchNorm2d(channels),
            nn.LeakyReLU(),
        )

    def forward(self, graph_feat, edge, weight=None, debug=False):
        B, T, HW, C = graph_feat.shape
        grid = self.config['mat_size'][-1]
        grid_h, grid_w = grid[0], grid[1]
        # (B, T, HW, C) -> (B*T, C, H, W)
        as_image = graph_feat.reshape(B * T, grid_h, grid_w, C).permute(0, 3, 1, 2)
        out = self.net(as_image)
        # (B*T, C, H, W) -> (B, T, HW, C)
        return out.reshape(B, T, C, HW).permute(0, 1, 3, 2)


class GraphAtt(nn.Module):
    """Edge-list message passing between graph nodes indexed as (batch, time, node).

    Each ``edge`` row is ``(b, t, src, dst)``. Depending on ``edge_type`` the
    source/target time indices are:

    * ``'forward'``: ``t`` to ``t + 1``.
    * ``'backward'``: ``t + 1`` to ``t``.
    * ``'spatial'``: ``t`` to ``t``.
    * ``'compose'``: ``t`` to ``config['prev_len']``.

    Two modes:

    * ``img_feat`` truthy: a weighted sum of source features is scattered into
      the targets (no learnable layers). Requires ``edge_type == 'compose'``,
      because H and W are only known from a (B, T, C, H, W) input.
    * otherwise: source features, optionally plus an embedding of the relative
      position / tendency, go through ``att_layer`` into ``head_num`` heads.
      They are weighted, summed per target and normalised, then fused with the
      target's own feature. Only target nodes are updated.

    Fixes relative to the previous version:

    * ``att_layer`` now takes ``C`` inputs (it previously expected
      ``C * head_num`` and crashed for ``head_num > 1``).
    * ``dist`` accepts ``2 + tendency_len`` inputs when ``tendency_dist`` is on.
    * The input tensor is no longer modified in place.
    """

    def __init__(self, config, img_feat=False, edge_type='forward'):
        super(GraphAtt, self).__init__()
        self.config = config
        self.edge_type = edge_type
        self.img_feat = img_feat

        head_num = self.config['graph_att_head_num']
        if img_feat:
            graph_feat_len = img_feat
        else:
            graph_feat_len = config['pred_base_channel'] + config['tendency_len'] + config['pos_len']

        if not img_feat:
            # Maps each C-dim message to head_num * C. For head_num == 1 the
            # shape equals the previous Linear(C, C).
            self.att_layer = nn.Linear(graph_feat_len, graph_feat_len * head_num)
            self.fuse = nn.Linear(graph_feat_len * (head_num + 1), graph_feat_len)
            self.fuse_norm = nn.GroupNorm(1, graph_feat_len)
            self.activate = nn.LeakyReLU()
            self.norm = nn.GroupNorm(1, graph_feat_len * head_num)

            if not self.config['Wo_PE']:
                # pos_dist has 2 features, since the plain branch feeds it
                # alone to Linear(2, ...). With tendency_dist the last
                # tendency_len feature differences are concatenated as well.
                dist_in = 2 + config['tendency_len'] if self.config['tendency_dist'] else 2
                self.dist = nn.Sequential(
                    nn.Linear(dist_in, graph_feat_len),
                    nn.LeakyReLU(inplace=True),
                    nn.Linear(graph_feat_len, graph_feat_len),
                    nn.GroupNorm(1, graph_feat_len)
                )

    def _edge_endpoints(self, edge):
        """Return (source, target) index triples (b, t, node) for every edge row."""
        b, t, src, dst = edge[:, 0], edge[:, 1], edge[:, 2], edge[:, 3]
        if self.edge_type == 'forward':
            return torch.stack([b, t, src], dim=1), torch.stack([b, t + 1, dst], dim=1)
        if self.edge_type == 'backward':
            return torch.stack([b, t + 1, src], dim=1), torch.stack([b, t, dst], dim=1)
        if self.edge_type == 'spatial':
            return torch.stack([b, t, src], dim=1), torch.stack([b, t, dst], dim=1)
        if self.edge_type == 'compose':
            # Currently only single-frame prediction is supported.
            fut_timestamp = torch.ones_like(t) * (self.config['prev_len'])
            return torch.stack([b, t, src], dim=1), torch.stack([b, fut_timestamp, dst], dim=1)
        raise ValueError(f"unknown edge_type: {self.edge_type!r}")

    def forward(self, graph_feat, edge, weight=None, debug=False, position=None):
        """Propagate features along ``edge``.

        ``graph_feat`` is (B, T, C, H, W) for ``'compose'`` and (B, T, HW, C)
        otherwise. ``weight`` holds one scalar per edge and is required.
        ``position`` is required unless ``config['Wo_PE']`` (or ``img_feat``).
        NOTE(review): it is assumed to broadcast via ``repeat(B, 1, 1, 1)`` to
        B*T*HW rows; confirm with callers. ``debug`` is unused.
        """
        if self.edge_type == 'compose':
            B, T, C, H, W = graph_feat.shape
            HW = H * W
            graph_feat = graph_feat.permute(0, 1, 3, 4, 2).reshape(B * T * HW, C)
        else:
            B, T, HW, C = graph_feat.shape
            graph_feat = graph_feat.reshape(-1, C)
            if position is not None:
                position = position.repeat(B, 1, 1, 1).reshape(B * T * HW, -1)

        node_id_pre, node_id_suc = self._edge_endpoints(edge)
        flat_id_pre = torchist.ravel_multi_index(node_id_pre, (B, T, HW))
        flat_id_suc = torchist.ravel_multi_index(node_id_suc, (B, T, HW))

        if self.img_feat:
            aggregated = torch.zeros_like(graph_feat)
            aggregated.index_add_(0, flat_id_suc, graph_feat[flat_id_pre] * weight.unsqueeze(-1))
            return aggregated.reshape(B * T, H, W, C).permute(0, 3, 1, 2).reshape(B, T, C, H, W)

        head_num = self.config['graph_att_head_num']
        feat_to_add = graph_feat[flat_id_pre]  # advanced indexing already copies

        if self.config['Wo_PE']:
            att_input = feat_to_add
        else:
            pos_dist = position[flat_id_suc] - position[flat_id_pre]
            if self.config['tendency_dist']:
                # Slice from C - tendency_len, so tendency_len == 0 selects nothing
                # (a `-0:` slice would select every channel).
                t_start = C - self.config['tendency_len']
                tendency_dist = graph_feat[flat_id_suc][..., t_start:] - feat_to_add[..., t_start:]
                dist_emb = self.dist(torch.cat([pos_dist, tendency_dist], dim=-1))
            else:
                dist_emb = self.dist(pos_dist)
            att_input = feat_to_add + dist_emb
        att_result = self.att_layer(att_input)

        # Weighted sum of incoming messages per target node, all heads side by side.
        aggregated = graph_feat.new_zeros((graph_feat.shape[0], head_num * C))
        aggregated.index_add_(0, flat_id_suc, att_result * weight.unsqueeze(-1))
        aggregated = self.norm(aggregated)

        targets = torch.unique(flat_id_suc)
        fused = self.activate(self.fuse_norm(self.fuse(
            torch.cat([graph_feat[targets], aggregated[targets]], dim=-1))))
        # Out-of-place update: the reshape above may alias the caller's tensor.
        graph_feat = graph_feat.index_copy(0, targets, fused)

        return graph_feat.reshape(B, T, HW, -1)



class DeformConv(nn.Module):
    '''
    Modulated deformable convolution built on ``torchvision.ops.deform_conv2d``.

    Code is modified from https://blog.csdn.net/C1nDeRainBo0M/article/details/123104016

    Offsets (2 per kernel tap) and a sigmoid mask (1 per tap) are predicted
    from the input by convs that share the main conv's kernel/stride/padding.
    At construction the offset conv's weights are zero and the mask conv's
    weights are 0.5.

    Args:
        in_channel, out_channel: channel counts.
        kernel_size: int or (kh, kw). Tap counts are derived from it; they
            were previously hard-coded for 3x3.
        stride, padding: applied to every conv and to the deformable op.
        device: target device. If CUDA is requested but unavailable the module
            stays on CPU instead of raising.

    NOTE(review): ``self.conv.bias`` is not passed to ``deform_conv2d``, as
    before, so that parameter is unused. Confirm whether this is intended
    before enabling it, since existing checkpoints never trained it.
    '''
    def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, padding=1, device='cuda'):
        super(DeformConv, self).__init__()
        kh, kw = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
        n_taps = kh * kw

        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv_offset = nn.Conv2d(in_channel, 2 * n_taps, kernel_size=kernel_size, stride=stride, padding=padding)
        nn.init.zeros_(self.conv_offset.weight)

        # The mask is predicted from x, so this conv takes in_channel inputs.
        # The old Conv2d(1, ...) only worked because its weight was replaced.
        self.conv_mask = nn.Conv2d(in_channel, n_taps, kernel_size=kernel_size, stride=stride, padding=padding)
        nn.init.constant_(self.conv_mask.weight, 0.5)

        target = torch.device(device)
        if target.type != 'cuda' or torch.cuda.is_available():
            self.to(target)

    def _deform_padding(self):
        """Padding for deform_conv2d, which does not accept 'same'/'valid' strings."""
        pad = self.conv.padding
        if pad == 'valid':
            return (0, 0)
        if pad == 'same':
            kh, kw = self.conv.kernel_size
            return (kh // 2, kw // 2)
        return pad

    def forward(self, x):
        offset = self.conv_offset(x)
        mask = torch.sigmoid(self.conv_mask(x))
        return torchvision.ops.deform_conv2d(input=x, offset=offset,
                                             weight=self.conv.weight,
                                             stride=self.conv.stride,
                                             padding=self._deform_padding(),
                                             mask=mask)


