import torch
from torch import nn

import Modules.sorted_topk as stk
from Modules.concrete_FS import concrete_selector


# %% Ada_Gate
class Ada_Gate(nn.Module):
    '''
    Element-wise clamp of ``x`` into ``[lower_bd, upper_bd]`` built from ReLUs.

    ``lower_bd + relu(x - lower_bd)`` equals ``max(x, lower_bd)`` and
    ``upper_bd - relu(upper_bd - y)`` equals ``min(y, upper_bd)``; applying
    the two in sequence clamps ``x`` while gradients flow through ReLU.
    The bounds are presumably scalars or tensors broadcastable against ``x``.
    '''

    def __init__(self):
        super().__init__()
        self.gate = nn.ReLU()

    def forward(self, x, lower_bd, upper_bd):
        # Lift every entry that falls below the lower bound up to it ...
        floored = lower_bd + self.gate(x - lower_bd)
        # ... then push every entry above the upper bound down to it.
        return upper_bd - self.gate(upper_bd - floored)


# %% No orthogonal module
class Ada_Graph_Fixed_Concrete_no_orth(nn.Module):
    '''
    Ablated version without UFS: the feature-selection matrix returned by
    ``concrete_selector`` is used directly (no Cholesky orthogonalisation)
    to project the data before the adaptive neighbour graph is learned.

    Args:
        n: sample count the model was configured with (stored as ``self.n``;
            ``forward`` takes the actual sample count from ``X``).
        dim: input feature dimension ``d``.
        k: number of selected features.
        num_neighbors: ``m``, number of neighbours kept per sample.
        epsilon, num_iter: forwarded to the differentiable top-k operator.
        manual_flag: use ``stk.TopK_custom1`` (manual gradient) if True,
            otherwise ``stk.TopK_stablized`` (autograd).
        device: device given to the top-k operator and to ``eye_const``.
    '''

    def __init__(self, n, dim, k, num_neighbors, epsilon=1e-2, num_iter=200,
                 manual_flag=True, device=torch.device('cuda:0')):
        super().__init__()
        self.n = n
        self.d = dim
        self.fea_selector = concrete_selector([k, dim])

        self.m = num_neighbors

        # Sorted top-(m+2): forward() skips slot 0 (presumably the sample
        # itself, at distance 0), sums slots 1..m and reads the last slot as
        # the (m+1)-th neighbour.
        topk_cls = stk.TopK_custom1 if manual_flag else stk.TopK_stablized
        self.sorted_topm = topk_cls(self.m + 2, epsilon=epsilon, max_iter=num_iter, device=device)

        self.S_gate = nn.ReLU()

        # 1e-6 * I_k ridge term. It is unused by this ablation but kept for
        # attribute parity with the full model. A non-persistent buffer follows
        # the module through .to()/.cuda() without entering state_dict.
        # Building the tensor with device= (instead of .cuda(device)) also
        # works for a CPU device.
        self.register_buffer('eye_const', 1e-6 * torch.eye(k, device=device), persistent=False)

    def cal_dist_matrix(self, X, F):
        '''
        Pairwise squared Euclidean distances after projecting X with F.

        Args:
            X: data matrix, one sample per row (presumably (n, d)).
            F: projection / feature-selection matrix (presumably (d, k)).

        Returns:
            E: (n, n) matrix, E[i, j] = ||x_i F||^2 + ||x_j F||^2 - 2 <x_i F, x_j F>.
               Tiny negative entries are possible from floating-point cancellation.
            X_new: the projected data X @ F.
        '''
        X_new = torch.matmul(X, F)
        sq_norms = torch.sum(X_new ** 2, 1, keepdim=True)  # n x 1
        # (n, 1) + (1, n) broadcasts to (n, n); no explicit repeat is needed.
        E = sq_norms + sq_norms.T - 2 * torch.matmul(X_new, X_new.T)
        return E, X_new

    def _neighbor_weights(self, E, sorted_res):
        '''
        Closed-form adaptive-neighbour weights, before symmetrisation and ReLU.

        For every row i:
            S[i, j] = delta[i, j] * (e_i - E[i, j]) / (m * e_i - sum_h e_ih)
        where delta marks the m nearest neighbours (slots 1..m of sorted_res),
        e_i is the distance to the neighbour in the last slot, and the sum
        runs over the m nearest neighbours.

        Args:
            E: (n, n) pairwise distance matrix.
            sorted_res: output of ``self.sorted_topm(-E)``, indexed as
                [row, column, rank]. It is assumed to hold (soft) one-hot
                selectors per rank, presumably (n, n, m+2).
                TODO: confirm against stk.TopK_*.
        '''
        delta = sorted_res[:, :, 1:self.m + 1].sum(dim=-1)   # n x n
        xi_mp1 = sorted_res[:, :, -1]                        # n x n
        E_T = E.T
        # Vectorised row-wise dot products: E_mp1[i] = xi_mp1[i, :] @ E[:, i].
        E_mp1 = (xi_mp1 * E_T).sum(dim=1, keepdim=True)      # n x 1
        E_sum_m = (delta * E_T).sum(dim=1, keepdim=True)     # n x 1
        # NOTE(review): the denominator is 0 when the m nearest neighbours and
        # the (m+1)-th one are all equidistant, which produces inf/NaN.
        return torch.div(E_mp1 - E, self.m * E_mp1 - E_sum_m) * delta

    def forward(self, X):
        '''
        Args:
            X: data matrix, one sample per row; any number of rows is accepted.

        Returns:
            (temperature, FS_mat, S, X_new): the selector temperature, the
            feature-selection matrix (d x k), the symmetrised non-negative
            (n, n) similarity graph, and the projected data X @ FS_mat.
        '''
        FS_mat, temperature = self.fea_selector()             # d x k

        # Learn the distance matrix E in the selected-feature space.
        E, X_new = self.cal_dist_matrix(X, FS_mat)

        # Learn the similarity graph from the sorted top-(m+2) of -E,
        # i.e. from the smallest distances.
        sorted_res = self.sorted_topm(-E)
        S_temp = self._neighbor_weights(E, sorted_res)

        S = self.S_gate((S_temp + S_temp.T) / 2)
        return temperature, FS_mat, S, X_new


# %% current version
class Ada_Graph_Fixed_Concrete_new(nn.Module):
    '''
    The framework proposed in our paper.

    The concrete feature selector is orthogonalised (via a Cholesky factor),
    the data are projected with it, and an adaptive m-nearest-neighbour
    similarity graph is learned in closed form from the projected distances.

    Args:
        n: sample count the model was configured with (stored as ``self.n``;
            ``forward`` takes the actual sample count from ``X``).
        dim: input feature dimension ``d``.
        k: number of selected features.
        num_neighbors: ``m``, number of neighbours kept per sample.
        epsilon, num_iter: forwarded to the differentiable top-k operator.
        manual_flag: use ``stk.TopK_custom1`` (manual gradient) if True,
            otherwise ``stk.TopK_stablized`` (autograd).
        device: device given to the top-k operator and to ``eye_const``.
    '''

    def __init__(self, n, dim, k, num_neighbors, epsilon=1e-2, num_iter=200,
                 manual_flag=True, device=torch.device('cuda:0')):
        super().__init__()
        self.n = n
        self.d = dim
        self.fea_selector = concrete_selector([k, dim])

        self.m = num_neighbors

        # Sorted top-(m+2): forward() skips slot 0 (presumably the sample
        # itself, at distance 0), sums slots 1..m and reads the last slot as
        # the (m+1)-th neighbour.
        topk_cls = stk.TopK_custom1 if manual_flag else stk.TopK_stablized
        self.sorted_topm = topk_cls(self.m + 2, epsilon=epsilon, max_iter=num_iter, device=device)

        self.S_gate = nn.ReLU()

        # 1e-6 * I_k ridge term that keeps I^T I positive definite for the
        # Cholesky factorisation in forward(). As a non-persistent buffer it
        # follows the module through .to()/.cuda(), which avoids a device
        # mismatch in forward(), and it stays out of state_dict. Building the
        # tensor with device= also works for a CPU device.
        self.register_buffer('eye_const', 1e-6 * torch.eye(k, device=device), persistent=False)

    def cal_dist_matrix(self, X, F):
        '''
        Pairwise squared Euclidean distances after projecting X with F.

        Args:
            X: data matrix, one sample per row (presumably (n, d)).
            F: projection / feature-selection matrix (presumably (d, k)).

        Returns:
            E: (n, n) matrix, E[i, j] = ||x_i F||^2 + ||x_j F||^2 - 2 <x_i F, x_j F>.
               Tiny negative entries are possible from floating-point cancellation.
            X_new: the projected data X @ F.
        '''
        X_new = torch.matmul(X, F)
        sq_norms = torch.sum(X_new ** 2, 1, keepdim=True)  # n x 1
        # (n, 1) + (1, n) broadcasts to (n, n); no explicit repeat is needed.
        E = sq_norms + sq_norms.T - 2 * torch.matmul(X_new, X_new.T)
        return E, X_new

    def _neighbor_weights(self, E, sorted_res):
        '''
        Closed-form adaptive-neighbour weights, before symmetrisation and ReLU.

        For every row i:
            S[i, j] = delta[i, j] * (e_i - E[i, j]) / (m * e_i - sum_h e_ih)
        where delta marks the m nearest neighbours (slots 1..m of sorted_res),
        e_i is the distance to the neighbour in the last slot, and the sum
        runs over the m nearest neighbours.

        Args:
            E: (n, n) pairwise distance matrix.
            sorted_res: output of ``self.sorted_topm(-E)``, indexed as
                [row, column, rank]. It is assumed to hold (soft) one-hot
                selectors per rank, presumably (n, n, m+2).
                TODO: confirm against stk.TopK_*.
        '''
        delta = sorted_res[:, :, 1:self.m + 1].sum(dim=-1)   # n x n
        xi_mp1 = sorted_res[:, :, -1]                        # n x n
        E_T = E.T
        # Vectorised row-wise dot products: E_mp1[i] = xi_mp1[i, :] @ E[:, i].
        E_mp1 = (xi_mp1 * E_T).sum(dim=1, keepdim=True)      # n x 1
        E_sum_m = (delta * E_T).sum(dim=1, keepdim=True)     # n x 1
        # NOTE(review): the denominator is 0 when the m nearest neighbours and
        # the (m+1)-th one are all equidistant, which produces inf/NaN.
        return torch.div(E_mp1 - E, self.m * E_mp1 - E_sum_m) * delta

    def forward(self, X, symmetric=True):
        '''
        Args:
            X: data matrix, one sample per row; any number of rows is accepted.
            symmetric: if True, return ReLU((S + S^T) / 2); otherwise ReLU(S).

        Returns:
            (temperature, FS_mat, S, X_new): the selector temperature, the
            orthogonalised feature-selection matrix (d x k), the non-negative
            (n, n) similarity graph, and the projected data X @ FS_mat.
        '''
        I_temp, temperature = self.fea_selector()             # d x k

        # Orthogonalise the selector. With I^T I + eps*I_k = L L^T (Cholesky),
        # FS_mat = I L^{-T} gives FS_mat^T FS_mat = I_k - eps * (L L^T)^{-1},
        # i.e. orthonormal columns up to the small ridge term.
        L = torch.linalg.cholesky(torch.matmul(I_temp.T, I_temp) + self.eye_const)
        L_inv = torch.linalg.inv(L)
        FS_mat = torch.matmul(I_temp, L_inv.T)

        # Learn the distance matrix E in the selected-feature space.
        E, X_new = self.cal_dist_matrix(X, FS_mat)

        # Learn the similarity graph from the sorted top-(m+2) of -E,
        # i.e. from the smallest distances.
        sorted_res = self.sorted_topm(-E)
        S_temp = self._neighbor_weights(E, sorted_res)

        if symmetric:
            S = self.S_gate((S_temp + S_temp.T) / 2)
        else:
            S = self.S_gate(S_temp)
        return temperature, FS_mat, S, X_new