import torch.nn as nn
import torch.nn.functional as F
from pygcn.layers import GraphConvolution
import torch
from soft_cluster import cluster, SoftCluster
#from gat import GAT, SpGAT, GATGeom

class GCN(nn.Module):
    """Two-layer graph convolutional network built from pygcn ``GraphConvolution`` layers.

    ``forward`` computes ``gc2(dropout(relu(gc1(x, adj))), adj)`` and returns the
    second layer's output without a final activation.
    """

    def __init__(self, nfeat, nhid, nout, dropout):
        super().__init__()
        # Layer creation order is kept (gc1 before gc2) so parameter order and
        # random initialisation order are unchanged.
        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nout)
        # Dropout probability applied between the two graph convolutions.
        self.dropout = dropout

    def forward(self, x, adj):
        activated = F.relu(self.gc1(x, adj))
        # Dropout is only active while the module is in training mode.
        regularised = F.dropout(activated, self.dropout, training=self.training)
        return self.gc2(regularised, adj)
    

class GCNDeep(nn.Module):
    """Stack of ``nlayers`` pygcn ``GraphConvolution`` layers.

    Layout: ``gcstart`` (nfeat -> nhid), ``nlayers - 2`` hidden layers
    (nhid -> nhid), then ``gcend`` (nhid -> nout). Every layer except the last
    is followed by ReLU and dropout; the last layer's output is returned raw.
    """

    def __init__(self, nfeat, nhid, nout, dropout, nlayers):
        super(GCNDeep, self).__init__()

        self.gcstart = GraphConvolution(nfeat, nhid)
        # nn.ModuleList (not a plain list) registers the hidden layers as
        # submodules. With a plain list their parameters were invisible to
        # parameters()/the optimizer, .to()/.cuda() and state_dict(), so they
        # stayed at their random initialisation. state_dict() now contains
        # gcn_middle.* entries.
        self.gcn_middle = nn.ModuleList(
            GraphConvolution(nhid, nhid) for _ in range(nlayers - 2)
        )
        self.gcend = GraphConvolution(nhid, nout)
        # Dropout probability applied after every non-final layer.
        self.dropout = dropout

    def forward(self, x, adj):
        x = F.relu(self.gcstart(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        for gc in self.gcn_middle:
            x = F.relu(gc(x, adj))
            x = F.dropout(x, self.dropout, training=self.training)
        return self.gcend(x, adj)
    

class GCNDeepSigmoid(nn.Module):
    """Same layer stack as a deep GCN, with a sigmoid on the output, flattened.

    Layout: ``gcstart`` (nfeat -> nhid), ``nlayers - 2`` hidden layers
    (nhid -> nhid), then ``gcend`` (nhid -> nout). Non-final layers use ReLU
    and dropout. ``forward`` returns ``sigmoid(gcend(...))`` flattened to 1-D.
    """

    def __init__(self, nfeat, nhid, nout, dropout, nlayers):
        super(GCNDeepSigmoid, self).__init__()

        self.gcstart = GraphConvolution(nfeat, nhid)
        # nn.ModuleList (not a plain list) registers the hidden layers so that
        # their parameters are trained, moved by .to()/.cuda(), and saved in
        # state_dict(). A plain list silently left them untrained.
        self.gcn_middle = nn.ModuleList(
            GraphConvolution(nhid, nhid) for _ in range(nlayers - 2)
        )
        self.gcend = GraphConvolution(nhid, nout)
        # Dropout probability applied after every non-final layer.
        self.dropout = dropout

    def forward(self, x, adj):
        x = F.relu(self.gcstart(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        for gc in self.gcn_middle:
            x = F.relu(gc(x, adj))
            x = F.dropout(x, self.dropout, training=self.training)
        x = self.gcend(x, adj)
        # torch.sigmoid avoids constructing an nn.Sigmoid module on every call.
        return torch.sigmoid(x).flatten()


    
class GCNLink(nn.Module):
    """Scores node pairs with GCN embeddings and a learned diagonal (DistMult-style) relation.

    For each row ``(u, v)`` of ``to_pred`` the score is
    ``sum_d embeds[u, d] * distmult[d] * embeds[v, d]``. Raw scores are
    returned; no sigmoid is applied in ``forward``.
    """

    def __init__(self, nfeat, nhid, nout, dropout):
        super().__init__()
        self.GCN = GCN(nfeat, nhid, nout, dropout)
        # Per-dimension relation weights, initialised uniformly in [0, 1).
        self.distmult = nn.Parameter(torch.rand(nout))
        # NOTE(review): unused by forward; kept so attributes/submodules are unchanged.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, adj, to_pred):
        node_emb = self.GCN(x, adj)
        # Column 0 holds source node indices, column 1 holds target node indices.
        src = node_emb[to_pred[:, 0]]
        dst = node_emb[to_pred[:, 1]]
        # distmult (shape (nout,)) broadcasts across the batch of pairs.
        weighted = src * self.distmult * dst
        return weighted.sum(dim=1)

class GCNCluster(nn.Module):
    """GCN node embeddings followed by ``soft_cluster.cluster``.

    ``forward`` returns ``(mu, r, embeds, dist)``. ``mu``, ``r`` and ``dist`` are
    the three values returned by ``cluster``; ``embeds`` is the GCN output.
    """

    def __init__(self, nfeat, nhid, nout, dropout, K, cluster_temp):
        super().__init__()
        self.GCN = GCN(nfeat, nhid, nout, dropout)
        # NOTE(review): distmult and sigmoid are not used by forward; kept so the
        # parameter set and state_dict stay unchanged.
        self.distmult = nn.Parameter(torch.rand(nout))
        self.sigmoid = nn.Sigmoid()
        # Number of clusters and temperature, both forwarded to `cluster`.
        self.K = K
        self.cluster_temp = cluster_temp

    def forward(self, x, adj, num_iter=1):
        node_emb = self.GCN(x, adj)
        # The positional `1` is passed through unchanged; its meaning is defined
        # by soft_cluster.cluster — confirm there. num_iter is forwarded as-is.
        centers, assignments, distances = cluster(
            node_emb, self.K, 1, num_iter, cluster_temp=self.cluster_temp
        )
        return centers, assignments, node_emb, distances
    
class GCNClusterLast(nn.Module):
    """GCN embeddings clustered in two stages with ``soft_cluster.cluster``.

    Stage 1 runs ``num_iter`` iterations starting from the fixed random centers
    ``self.init``. Stage 2 runs one more iteration starting from a detached copy
    of the stage-1 centers. Because of the detach, autograd does not
    differentiate through the stage-1 centers. Returns ``(mu, r, embeds, dist)``
    from stage 2 plus the GCN embeddings.
    """

    def __init__(self, nfeat, nhid, nout, dropout, K, cluster_temp):
        super(GCNClusterLast, self).__init__()

        self.GCN = GCN(nfeat, nhid, nout, dropout)
        # NOTE(review): distmult and sigmoid are unused by forward; kept for
        # state_dict/attribute compatibility.
        self.distmult = nn.Parameter(torch.rand(nout))
        self.sigmoid = nn.Sigmoid()
        self.K = K
        self.cluster_temp = cluster_temp
        # Fixed (K, nout) random initial centers. This is a plain attribute (not a
        # parameter or buffer), so it is not saved in state_dict and is not moved
        # by Module.to()/.cuda().
        self.init = torch.rand(self.K, nout)

    def forward(self, x, adj, num_iter=1):
        embeds = self.GCN(x, adj)
        # self.init is not moved with the module, so align it with embeds here;
        # otherwise a model on GPU (or in double precision) passes a CPU/float
        # init to `cluster`. `.to` is a no-op when device and dtype already match.
        init = self.init.to(device=embeds.device, dtype=embeds.dtype)
        mu_init, _, _ = cluster(embeds, self.K, 1, num_iter,
                                cluster_temp=self.cluster_temp, init=init)
        mu, r, dist = cluster(embeds, self.K, 1, 1,
                              cluster_temp=self.cluster_temp,
                              init=mu_init.detach().clone())
        return mu, r, embeds, dist
    

class GCNClusterGAT(nn.Module):
    """Two-stage soft clustering on top of ``GATGeom`` node embeddings.

    Same clustering scheme as the GCN-based two-stage model: ``num_iter``
    iterations from ``self.init``, then one iteration from the detached
    stage-1 centers. ``alpha`` and ``nheads`` are forwarded to ``GATGeom``
    (presumably LeakyReLU slope and number of attention heads — confirm in gat.py).
    """

    def __init__(self, nfeat, nhid, nout, dropout, K, cluster_temp, alpha = 0.2, nheads = 2):
        super(GCNClusterGAT, self).__init__()

        # The module-level `from gat import ... GATGeom` is commented out, so the
        # bare name was undefined and construction raised NameError. Import
        # lazily here so the rest of this module stays importable without `gat`.
        from gat import GATGeom

        self.GAT = GATGeom(nfeat, nhid, nout, dropout, alpha, nheads)
        # NOTE(review): distmult and sigmoid are unused by forward; kept for
        # state_dict/attribute compatibility.
        self.distmult = nn.Parameter(torch.rand(nout))
        self.sigmoid = nn.Sigmoid()
        self.K = K
        self.cluster_temp = cluster_temp
        # Fixed (K, nout) random initial centers; a plain attribute, so it is not
        # saved in state_dict nor moved by Module.to()/.cuda().
        self.init = torch.rand(self.K, nout)

    def forward(self, x, adj, num_iter=1):
        embeds = self.GAT(x, adj)
        # Keep the initial centers on the same device/dtype as the embeddings.
        init = self.init.to(device=embeds.device, dtype=embeds.dtype)
        mu_init, _, _ = cluster(embeds, self.K, 1, num_iter,
                                cluster_temp=self.cluster_temp, init=init)
        mu, r, dist = cluster(embeds, self.K, 1, 1,
                              cluster_temp=self.cluster_temp,
                              init=mu_init.detach().clone())
        return mu, r, embeds, dist



class GCNClusterDirect(nn.Module):
    """Two-stage soft clustering on GCN embeddings fused with a raw-feature MLP.

    ``forward`` concatenates the GCN output with
    ``relu(linear1a(dropout(relu(linear1(x)))))`` and projects the result with
    ``linear2`` (2*nout -> nout). It then clusters the projection in two stages:
    ``num_iter`` iterations from ``self.init``, then one iteration from the
    detached stage-1 centers. Returns ``(mu, r, embeds, dist)``, where
    ``embeds`` is the fused projection.
    """

    def __init__(self, nfeat, nhid, nout, dropout, K, cluster_temp):
        super(GCNClusterDirect, self).__init__()

        # Construction order is unchanged so random initialisation is identical.
        self.GCN = GCN(nfeat, nhid, nout, dropout)
        # NOTE(review): distmult and sigmoid are unused by forward; kept for
        # state_dict/attribute compatibility.
        self.distmult = nn.Parameter(torch.rand(nout))
        self.sigmoid = nn.Sigmoid()
        self.K = K
        self.cluster_temp = cluster_temp
        # Fixed (K, nout) random initial centers; a plain attribute, so it is not
        # saved in state_dict nor moved by Module.to()/.cuda().
        self.init = torch.rand(self.K, nout)
        self.linear1 = torch.nn.Linear(nfeat, nout)
        self.linear1a = torch.nn.Linear(nout, nout)
        self.linear2 = torch.nn.Linear(2*nout, nout)
        self.dropout = dropout

    def forward(self, x, adj, num_iter=1):
        embeds = self.GCN(x, adj)
        feat = torch.relu(self.linear1(x))
        # F.dropout honours self.training. The previous torch.nn.Dropout(p)(feat)
        # built a fresh module that is always in training mode, so units were
        # dropped even after model.eval().
        feat = F.dropout(feat, self.dropout, training=self.training)
        feat = torch.relu(self.linear1a(feat))
        embeds = self.linear2(torch.cat((embeds, feat), dim=1))
        # Keep the initial centers on the same device/dtype as the embeddings.
        init = self.init.to(device=embeds.device, dtype=embeds.dtype)
        mu_init, _, _ = cluster(embeds, self.K, 1, num_iter,
                                cluster_temp=self.cluster_temp, init=init)
        mu, r, dist = cluster(embeds, self.K, 1, 1,
                              cluster_temp=self.cluster_temp,
                              init=mu_init.detach().clone())
        return mu, r, embeds, dist


class GCNClusterAnalytical(nn.Module):
    """Row-normalised GCN embeddings clustered by a ``SoftCluster`` module.

    ``forward`` L2-normalises each embedding row and obtains centers ``mu`` from
    ``self.clusterfunc``. It then computes ``dist = embeds @ mu.T`` and
    ``r = softmax(cluster_temp * dist, dim=1)``. Returns ``(mu, r, embeds, dist)``.
    ``num_iter`` is accepted for interface compatibility but is not used.
    """

    def __init__(self, nfeat, nhid, nout, dropout, K, cluster_temp):
        super(GCNClusterAnalytical, self).__init__()

        self.GCN = GCN(nfeat, nhid, nout, dropout)
        # NOTE(review): distmult and sigmoid are unused by forward; kept for
        # state_dict/attribute compatibility.
        self.distmult = nn.Parameter(torch.rand(nout))
        self.sigmoid = nn.Sigmoid()
        self.K = K
        self.cluster_temp = cluster_temp
        # NOTE(review): positional args (K, cluster_temp, 5, cluster_temp) are
        # interpreted by soft_cluster.SoftCluster — confirm their meaning there.
        self.clusterfunc = SoftCluster(K, cluster_temp, 5, cluster_temp)

    def forward(self, x, adj, num_iter=1):
        embeds = self.GCN(x, adj)
        # Row-wise L2 normalisation by broadcasting. The previous
        # torch.diag(1/norms) @ embeds built an (n, n) matrix: O(n^2) memory and
        # time for an O(n) operation. A zero row still yields NaN, as before.
        norms = torch.norm(embeds, p=2, dim=1, keepdim=True)
        embeds = embeds * (1. / norms)
        mu, _, _ = self.clusterfunc(embeds)
        dist = embeds @ mu.t()
        r = torch.softmax(self.cluster_temp*dist, 1)
        return mu, r, embeds, dist

    
class GCNClusterParam(nn.Module):
    """Scores GCN embeddings against a learnable set of K cluster centers.

    ``forward`` computes the (n, K) matrix of cosine similarities between every
    node embedding and every center, then ``r = softmax(cluster_temp * dist, 1)``.
    Returns ``(self.mu, r, embeds, dist)``. ``num_iter`` is accepted for interface
    compatibility but is not used.
    """

    def __init__(self, nfeat, nhid, nout, dropout, K, cluster_temp):
        super().__init__()
        self.GCN = GCN(nfeat, nhid, nout, dropout)
        # Learnable centers, shape (K, nout), initialised uniformly in [0, 1).
        self.mu = nn.Parameter(torch.rand(K, nout))
        # NOTE(review): unused by forward; kept for attribute/state_dict compatibility.
        self.sigmoid = nn.Sigmoid()
        self.K = K
        self.cluster_temp = cluster_temp

    def forward(self, x, adj, num_iter=1):
        node_emb = self.GCN(x, adj)
        num_nodes = node_emb.shape[0]
        dim = node_emb.shape[1]
        grid = (num_nodes, self.K, dim)
        # Lay out every (node, center) pair as one row so cosine_similarity
        # compares them pairwise along its default dim=1.
        left = node_emb.unsqueeze(1).expand(grid).reshape((-1, dim))
        right = self.mu.unsqueeze(0).expand(grid).reshape((-1, dim))
        sims = torch.cosine_similarity(left, right).reshape((num_nodes, self.K))
        weights = torch.softmax(self.cluster_temp * sims, dim=1)
        return self.mu, weights, node_emb, sims