from pygcn import load_data
import torch
import argparse
import numpy as np
import torch.optim as optim
import random
import scipy as sp
import pickle
import torch.nn as nn
import torch.nn.functional as F
from pygcn.layers import GraphConvolution
import networkx as nx
import community
from soft_cluster import cluster, SoftCluster
import sklearn
from influmax import InfluObjective
from influmax import sample_live_icm, make_multilinear_gradient, make_multilinear_objective_samples, fw, live_edge_to_edgelist
from maxcut_mip import local_search
from influmax import sample_live_icm_parallel
from influmax import multi_to_set, greedy

def make_normalized_adj(edges, num_nodes=None):
    """Build a symmetric, self-looped, row-normalized adjacency matrix.

    Args:
        edges: integer tensor with one ``(row, col)`` node-index pair per row
            (indexed below as ``edges[:, 0]`` and ``edges[:, 1]``).
        num_nodes: side length of the square adjacency matrix.  When omitted
            the module-level ``n`` is used, which is what the original
            implementation always did.

    Returns:
        A float32 torch sparse COO tensor of shape ``(num_nodes, num_nodes)``
        holding ``D^-1 (A + I)``, where ``A`` is the symmetrized adjacency
        built from ``edges`` and ``D`` is the diagonal matrix of row sums.
    """
    # Imported locally: ``import scipy as sp`` alone does not guarantee that
    # the ``sparse`` submodule has been loaded on every SciPy release.
    from scipy import sparse

    if num_nodes is None:
        # Fall back to the script-level node count (legacy behaviour).
        num_nodes = n

    def normalize(mx):
        """Row-normalize sparse matrix"""
        rowsum = np.array(mx.sum(1))
        # Empty rows give a division by zero; the resulting inf is zeroed below.
        with np.errstate(divide='ignore'):
            r_inv = np.power(rowsum, -1).flatten()
        r_inv[np.isinf(r_inv)] = 0.
        return sparse.diags(r_inv).dot(mx)

    def sparse_mx_to_torch_sparse_tensor(sparse_mx):
        """Convert a scipy sparse matrix to a torch sparse tensor."""
        sparse_mx = sparse_mx.tocoo().astype(np.float32)
        indices = torch.from_numpy(
            np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
        values = torch.from_numpy(sparse_mx.data)
        shape = torch.Size(sparse_mx.shape)
        # ``torch.sparse.FloatTensor(...)`` is the deprecated spelling of this.
        return torch.sparse_coo_tensor(indices, values, shape)

    # ``.cpu()`` is a no-op for CPU tensors and makes CUDA tensors convertible.
    edges = edges.detach().cpu().numpy()
    adj = sparse.coo_matrix(
        (np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
        shape=(num_nodes, num_nodes),
        dtype=np.float32)
    # build symmetric adjacency matrix
    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
    # Add self-loops before normalizing.
    adj = normalize(adj + sparse.eye(adj.shape[0]))
    return sparse_mx_to_torch_sparse_tensor(adj)

import utils
import sys, os
class GCN(nn.Module):
    """Two stacked graph-convolution layers with ReLU and dropout in between.

    The output of the second layer is returned as-is (no activation).
    """

    def __init__(self, nfeat, nhid, nout, dropout):
        """Create the two layers.

        Args:
            nfeat: size handed to the first layer as its input dimension.
            nhid: size shared by the first layer's output and the second
                layer's input.
            nout: size handed to the second layer as its output dimension.
            dropout: dropout probability applied after the first layer.
        """
        super().__init__()

        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nout)
        self.dropout = dropout

    def forward(self, x, adj):
        """Run both layers on features ``x`` using adjacency ``adj``."""
        hidden = self.gc1(x, adj)
        hidden = F.relu(hidden)
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.gc2(hidden, adj)
    

class GCNDeep(nn.Module):
    """Stack of ``nlayers`` graph-convolution layers with ReLU after each one.

    The layout is one input layer, ``nlayers - 2`` hidden-to-hidden layers and
    one output layer.  Dropout follows every layer except the last.
    """

    def __init__(self, nfeat, nhid, nout, dropout, nlayers):
        """Create the layer stack.

        Args:
            nfeat: input size of the first layer.
            nhid: width of every hidden layer.
            nout: output size of the last layer.
            dropout: dropout probability used between layers.
            nlayers: total number of layers; values below 3 yield no middle
                layers.
        """
        super(GCNDeep, self).__init__()

        self.gcstart = GraphConvolution(nfeat, nhid)
        # The middle layers must live in an nn.ModuleList.  A plain Python
        # list is not registered with the module, so ``parameters()``,
        # ``state_dict()`` and ``.cuda()``/``.to()`` would skip those layers
        # and an optimizer would never train them.
        self.gcn_middle = nn.ModuleList(
            [GraphConvolution(nhid, nhid) for _ in range(nlayers - 2)]
        )
        self.gcend = GraphConvolution(nhid, nout)
        self.dropout = dropout

    def forward(self, x, adj):
        """Apply every layer in order and return the ReLU of the last one."""
        x = F.relu(self.gcstart(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        for gc in self.gcn_middle:
            x = F.relu(gc(x, adj))
            x = F.dropout(x, self.dropout, training=self.training)
        # The output layer is also passed through ReLU, so results are >= 0.
        x = F.relu(self.gcend(x, adj))

        return x

    
class GCNLink(nn.Module):
    """GCN encoder followed by a per-dimension weighted dot-product pair scorer."""

    def __init__(self, nfeat, nhid, nout, dropout):
        """Build the encoder and the scoring weights.

        Args:
            nfeat, nhid, nout, dropout: forwarded unchanged to ``GCN``.
        """
        super().__init__()

        self.GCN = GCN(nfeat, nhid, nout, dropout)
        # One learnable weight per embedding dimension, drawn from U[0, 1).
        self.distmult = nn.Parameter(torch.rand(nout))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, adj, to_pred):
        """Score node pairs.

        Args:
            x: node features passed to the encoder.
            adj: adjacency passed to the encoder.
            to_pred: index tensor whose columns 0 and 1 select the two
                endpoints of every pair to score.

        Returns:
            One raw score per row of ``to_pred``; the sigmoid is not applied.
        """
        embeds = self.GCN(x, adj)
        first = embeds[to_pred[:, 0]]
        second = embeds[to_pred[:, 1]]
        num_pairs = to_pred.shape[0]
        weights = self.distmult.expand(num_pairs, self.distmult.shape[0])
        return (first * weights * second).sum(dim=1)

class GCNCluster(nn.Module):
    """GCN encoder whose embeddings are fed to ``soft_cluster.cluster``."""

    def __init__(self, nfeat, nhid, nout, dropout, K, cluster_temp):
        """Build the encoder and remember the clustering settings.

        Args:
            nfeat, nhid, nout, dropout: forwarded unchanged to ``GCN``.
            K: passed to ``cluster`` as its second argument.
            cluster_temp: passed to ``cluster`` as ``cluster_temp``.
        """
        super().__init__()

        self.GCN = GCN(nfeat, nhid, nout, dropout)
        self.distmult = nn.Parameter(torch.rand(nout))
        self.sigmoid = nn.Sigmoid()
        self.K = K
        self.cluster_temp = cluster_temp

    def forward(self, x, adj, num_iter=1):
        """Embed the nodes, cluster the embeddings and return both.

        Returns:
            ``(mu, r, embeds, dist)`` where ``mu``, ``r`` and ``dist`` are the
            three values returned by ``cluster`` and ``embeds`` is the encoder
            output.
        """
        embeds = self.GCN(x, adj)
        clustered = cluster(
            embeds, self.K, 1, num_iter, cluster_temp=self.cluster_temp
        )
        mu, r, dist = clustered
        return mu, r, embeds, dist
    
class GCNClusterLast(nn.Module):
    """GCN encoder plus clustering done in two passes.

    The first pass starts from random centers and runs ``num_iter``
    iterations.  The second pass runs one more iteration, starting from a
    detached copy of the first pass's centers.
    """

    def __init__(self, nfeat, nhid, nout, dropout, K, cluster_temp):
        """Build the encoder and remember the clustering settings.

        Args:
            nfeat, nhid, nout, dropout: forwarded unchanged to ``GCN``.
            K: number of rows of the random initial centers; also passed to
                ``cluster``.
            cluster_temp: passed to ``cluster`` as ``cluster_temp``.
        """
        super().__init__()

        self.GCN = GCN(nfeat, nhid, nout, dropout)
        self.distmult = nn.Parameter(torch.rand(nout))
        self.sigmoid = nn.Sigmoid()
        self.K = K
        self.cluster_temp = cluster_temp

    def forward(self, x, adj, num_iter=1):
        """Return ``(mu, r, embeds, dist)`` from the second clustering pass."""
        embeds = self.GCN(x, adj)

        # Pass 1: random start; only the resulting centers are kept.
        random_start = torch.rand(self.K, embeds.shape[1])
        mu_init, _, _ = cluster(
            embeds, self.K, 1, num_iter,
            cluster_temp=self.cluster_temp, init=random_start,
        )

        # Pass 2: a single iteration from a detached copy of those centers.
        # NOTE(review): presumably so that gradients only flow through this
        # final step -- confirm against soft_cluster.cluster.
        warm_start = mu_init.detach().clone()
        mu, r, dist = cluster(
            embeds, self.K, 1, 1,
            cluster_temp=self.cluster_temp, init=warm_start,
        )
        return mu, r, embeds, dist


class GCNClusterAnalytical(nn.Module):
    """GCN encoder whose L2-normalized embeddings are clustered by ``SoftCluster``."""

    def __init__(self, nfeat, nhid, nout, dropout, K, cluster_temp):
        """Build the encoder and the clustering module.

        Args:
            nfeat, nhid, nout, dropout: forwarded unchanged to ``GCN``.
            K: forwarded to ``SoftCluster`` as its first argument.
            cluster_temp: softmax temperature for the assignments; also
                forwarded to ``SoftCluster`` (second and fourth arguments).
        """
        super(GCNClusterAnalytical, self).__init__()

        self.GCN = GCN(nfeat, nhid, nout, dropout)
        self.distmult = nn.Parameter(torch.rand(nout))
        self.sigmoid = nn.Sigmoid()
        self.K = K
        self.cluster_temp = cluster_temp
        # NOTE(review): the meaning of the literal 5 and of the repeated
        # cluster_temp is defined by soft_cluster.SoftCluster -- confirm there.
        self.clusterfunc = SoftCluster(K, cluster_temp, 5, cluster_temp)

    def forward(self, x, adj, num_iter=1):
        """Embed, normalize and cluster the nodes.

        Args:
            x: node features passed to the encoder.
            adj: adjacency passed to the encoder.
            num_iter: accepted for interface parity with the other cluster
                models; not used here.

        Returns:
            ``(mu, r, embeds, dist)``: the centers from ``SoftCluster``, the
            row-wise softmax of ``cluster_temp * dist``, the unit-norm
            embeddings, and ``embeds @ mu.t()``.
        """
        embeds = self.GCN(x, adj)
        # Scale each row to unit L2 norm.  This is the same product as
        # ``torch.diag(inv_norm) @ embeds`` without building the dense
        # n-by-n diagonal matrix.
        inv_norm = 1. / torch.norm(embeds, p=2, dim=1)
        embeds = embeds * inv_norm.unsqueeze(1)
        mu, _, _ = self.clusterfunc(embeds)
        dist = embeds @ mu.t()
        r = torch.softmax(self.cluster_temp*dist, 1)
        return mu, r, embeds, dist

    
class GCNClusterParam(nn.Module):
    """GCN encoder with cluster centers stored as a learnable parameter."""

    def __init__(self, nfeat, nhid, nout, dropout, K, cluster_temp):
        """Build the encoder and the ``(K, nout)`` center matrix.

        Args:
            nfeat, nhid, nout, dropout: forwarded unchanged to ``GCN``.
            K: number of cluster centers.
            cluster_temp: multiplier applied to the similarities before the
                softmax.
        """
        super().__init__()

        self.GCN = GCN(nfeat, nhid, nout, dropout)
        # Centers are initialized uniformly in [0, 1).
        self.mu = nn.Parameter(torch.rand(K, nout))
        self.sigmoid = nn.Sigmoid()
        self.K = K
        self.cluster_temp = cluster_temp

    def forward(self, x, adj, num_iter=1):
        """Return ``(mu, r, embeds, dist)``.

        ``dist`` holds the cosine similarity between every embedding and every
        center, and ``r`` is its row-wise softmax scaled by ``cluster_temp``.
        ``num_iter`` is accepted but not used.
        """
        embeds = self.GCN(x, adj)
        num_nodes, dim = embeds.shape[0], embeds.shape[1]
        num_centers = self.K
        full_shape = (num_nodes, num_centers, dim)

        # Pair every embedding with every center, then flatten the pairs so
        # that cosine_similarity can compare them row by row.
        node_side = embeds[:, None].expand(*full_shape).reshape((-1, dim))
        center_side = self.mu[None].expand(*full_shape).reshape((-1, dim))
        flat_sim = torch.cosine_similarity(node_side, center_side)
        dist = flat_sim.reshape((num_nodes, num_centers))

        r = torch.softmax(self.cluster_temp*dist, 1)
        return self.mu, r, embeds, dist


# Command-line options for the experiment script.
parser = argparse.ArgumentParser()
# NOTE(review): with action='store_true' and default=True this option is True
# whether or not the flag is given, so CUDA is always disabled.  Left as-is
# because the rest of the script converts tensors with ``.numpy()``.
parser.add_argument('--no-cuda', action='store_true', default=True,
                    help='Disables CUDA training.')
parser.add_argument('--fastmode', action='store_true', default=False,
                    help='Validate during training pass.')
parser.add_argument('--seed', type=int, default=24, help='Random seed.')
parser.add_argument('--epochs', type=int, default=200,
                    help='Number of epochs to train.')
parser.add_argument('--lr', type=float, default=0.01,
                    help='Initial learning rate.')
parser.add_argument('--weight_decay', type=float, default=5e-4,
                    help='Weight decay (L2 loss on parameters).')
parser.add_argument('--hidden', type=int, default=50,
                    help='Number of hidden units.')
parser.add_argument('--dropout', type=float, default=0.5,
                    help='Dropout rate (1 - keep probability).')
parser.add_argument('--embed_dim', type=int, default=20,
                    help='Dimensionality of node embeddings')
parser.add_argument('--K', type=int, default=5,
                    help='How many partitions')
parser.add_argument('--negsamplerate', type=int, default=1,
                    help='How many negative examples to include per positive in link prediction training')
parser.add_argument('--edge_dropout', type=float, default=0.2,
                    help='Rate at which to remove edges in link prediction training')
# The help text now lists the values this script actually checks for.
parser.add_argument('--objective', type=str, default='influmax',
                    help='What objective to optimize (influmax, modularity or maxcut)')
parser.add_argument('--dataset', type=str, default='cora',
                    help='which network to load')
parser.add_argument('--influmaxp', type=float, default=0.1,
                    help='propagation prob for ICM')
parser.add_argument('--clustertemp', type=float, default=20,
                    help='how hard to make the softmax for the cluster assignments')
parser.add_argument('--influmaxtemp', type=float, default=30,
                    help='how hard to make seed selection softmax assignment')
# The previous help text was a copy of the --no-cuda description.
parser.add_argument('--viz', action='store_true', default=False,
                    help='Enable visualization.')
parser.add_argument('--train_pct', type=float, default=0.01, help='percent of total edges in training set')

args = parser.parse_args()
# CUDA is used only if it was not disabled and a device is actually present.
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Seed every random number generator this file imports.  Python's ``random``
# was previously left unseeded, so any code drawing from it was not
# reproducible even with a fixed --seed.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
viz = args.viz
# Load data
#
# ``reload_data`` decides whether the graph splits are read from disk.  When
# it is False the names listed in ``_REQUIRED_DATA`` must already exist in the
# namespace (for example from an earlier run in an interactive session).
reload_data = False
train_pct = args.train_pct

# Names the remainder of this script reads after this section.
_REQUIRED_DATA = ('features', 'adj_train', 'adj_all', 'bin_adj_all',
                  'bin_adj_test', 'obj_test', 'graph_data_test')


def _modularity_matrix(adj):
    """Return ``adj - d d^T / adj.sum()`` with ``d`` the column sums of ``adj``.

    Same computation as ``make_modularity_matrix``, which is defined further
    down in this file and is therefore not yet available at this point.
    """
    degrees = adj.sum(dim=0).unsqueeze(1)
    return adj - degrees @ degrees.t() / adj.sum()


if reload_data:
    adj_test, features, labels, idx_train, idx_val, idx_test = load_data('data/{}/'.format(args.dataset), '{}_test_{:.2f}'.format(args.dataset, train_pct))
    adj_test = adj_test.coalesce()
    bin_adj_test = (adj_test.to_dense() > 0).float()

    adj_valid, features, labels, idx_train, idx_val, idx_test = load_data('data/{}/'.format(args.dataset), '{}_valid_{:.2f}'.format(args.dataset, train_pct))
    adj_valid = adj_valid.coalesce()
    bin_adj_valid = (adj_valid.to_dense() > 0).float()

    # Loaded last, so features/labels/idx_* come from the training split.
    adj_train, features, labels, idx_train, idx_val, idx_test = load_data('data/{}/'.format(args.dataset), '{}_train_{:.2f}'.format(args.dataset, train_pct))
    adj_train = adj_train.coalesce()
    bin_adj_train = (adj_train.to_dense() > 0).float()
    m_train = bin_adj_train.sum()

    mod_train = _modularity_matrix(bin_adj_train)
    mod_test = _modularity_matrix(bin_adj_test)
    mod_valid = _modularity_matrix(bin_adj_valid)
    # Union of the three splits, as a binary adjacency matrix.
    bin_adj_all = (bin_adj_train + bin_adj_test + bin_adj_valid > 0).float()
    mod_all = _modularity_matrix(bin_adj_all)
    adj_all = make_normalized_adj(bin_adj_all.nonzero())

    # The context manager closes the file handle, which the bare
    # ``pickle.load(open(...))`` form did not.
    # NOTE: pickle can execute arbitrary code; only load files you trust.
    with open('{}_graphs_{}_{}.pickle'.format(args.dataset, args.influmaxp, train_pct), 'rb') as graph_file:
        graph_data_train, graph_data_test = pickle.load(graph_file)
    obj_train = InfluObjective(bin_adj_train*args.influmaxp, -1, 200, True, graph_data_train)
    obj_test = InfluObjective(bin_adj_all*args.influmaxp, -1, 500, True, graph_data_test)
    # NOTE(review): the training objective and graph data are replaced by the
    # test ones right away -- presumably deliberate; confirm.
    obj_train = obj_test
    graph_data_train = graph_data_test
else:
    # Fail here with an actionable message instead of at the first use below.
    _missing = [name for name in _REQUIRED_DATA if name not in globals()]
    if _missing:
        raise NameError(
            'reload_data is False but these names are not defined: {}; '
            'set reload_data = True to load them from disk'.format(
                ', '.join(_missing)))

n = adj_train.shape[0]
K = args.K

# Model and optimizer
#
# Every model below is built in the original order, including the ones that
# are replaced right away: the constructors draw initial values with
# ``torch.rand``, so the order of construction affects the values drawn.
_encoder_kwargs = dict(
    nfeat=features.shape[1],
    nhid=args.hidden,
    nout=args.embed_dim,
    dropout=args.dropout,
)
_cluster_kwargs = dict(
    _encoder_kwargs,
    K=args.K,
    cluster_temp=args.clustertemp,
)

model_ts = GCNLink(**_encoder_kwargs)

model_cluster = GCNCluster(**_cluster_kwargs)

model_cluster_analytical = GCNClusterAnalytical(**_cluster_kwargs)

model_cluster_last = GCNClusterLast(**_cluster_kwargs)

# The two-pass clustering model is the one that gets trained.
model_cluster = model_cluster_last

model_cluster_param = GCNClusterParam(**_cluster_kwargs)

model_gcn = GCNDeep(
    nfeat=features.shape[1],
    nhid=args.hidden,
    nout=args.K,
    dropout=args.dropout,
    nlayers=5,
)

optimizer = optim.Adam(
    model_cluster.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)

if args.cuda:
    # Move the trained models and the loaded tensors to the GPU.
    model_cluster.cuda()
    model_ts.cuda()
    features = features.cuda()
    adj_train = adj_train.cuda()
    labels = labels.cuda()
    idx_train = idx_train.cuda()
    idx_val = idx_val.cuda()
    idx_test = idx_test.cuda()

# Loss history: every iteration, and every logged evaluation.
losses = []
losses_test = []
num_cluster_iter = 1

def make_modularity_matrix(adj):
    """Return the modularity matrix ``adj - d d^T / adj.sum()``.

    Args:
        adj: dense 2-D adjacency tensor.

    Returns:
        A tensor of the same shape as ``adj``, where ``d`` is the vector of
        column sums of ``adj``.
    """
    col_sums = adj.sum(dim=0)
    as_column = col_sums.unsqueeze(1)
    total_weight = adj.sum()
    expected = as_column @ as_column.t() / total_weight
    return adj - expected




# Build a fresh, newly initialized two-pass clustering model.
_last_model_kwargs = {
    'nfeat': features.shape[1],
    'nhid': args.hidden,
    'nout': args.embed_dim,
    'dropout': args.dropout,
    'K': args.K,
    'cluster_temp': args.clustertemp,
}
model_cluster_last = GCNClusterLast(**_last_model_kwargs)


def loss_influmax(mu, r, embeds, dist, bin_adj, obj, temp=None):
    """Evaluate an objective on a soft seed set derived from ``dist``.

    Each column of ``dist`` is turned into a distribution over rows by a
    softmax along dimension 0.  The columns are then summed per row and
    clamped to ``[0, 1]``, and the resulting vector is handed to ``obj``.

    Args:
        mu, r, embeds, bin_adj: unused; kept so all loss functions share one
            signature.
        dist: 2-D score tensor (rows are softmaxed over, columns summed over).
        obj: callable taking the clamped vector and returning the objective
            value, or ``None`` to skip evaluation.
        temp: multiplier applied to ``dist`` before the softmax.  Defaults to
            ``args.influmaxtemp`` from the command line.

    Returns:
        ``obj(x)``, or a float32 zero tensor when ``obj`` is ``None``.
    """
    if obj is None:
        return torch.tensor(0).float()
    if temp is None:
        temp = args.influmaxtemp
    x = torch.softmax(dist*temp, 0).sum(dim=1)
    # Several columns may favour the same row; cap its mass at 1.
    x = torch.clamp(x, 0, 1)
    return obj(x)

# Training loop: optimize the selected clustering model for 1000 iterations
# against ``test_object`` and log the loss every 100 iterations.

# Train the most recently built two-pass model with a fresh Adam optimizer.
model_cluster = model_cluster_last
optimizer = optim.Adam(model_cluster.parameters(),
                       lr=args.lr, weight_decay=args.weight_decay)

loss_fn = loss_influmax
test_object = obj_test
# With no objective, loss_influmax returns a zero tensor, so the
# "test only" number printed below is always 0.
test_only_object = None
# Normalized adjacency built from the identity matrix's nonzero entries.
# It is not referenced again in this part of the script.
adj_nothing = make_normalized_adj(torch.eye(n).nonzero())
num_cluster_iter = 1
for t in range(1000):
    if args.objective == 'influmax':
        # NOTE(review): train_object is rebuilt on every iteration but is not
        # used below (the loss is computed on test_object) -- confirm whether
        # this is intended before removing or hoisting it.
        train_object = InfluObjective(bin_adj_all*args.influmaxp, -1, 200, True, graph_data_test)

    mu, r, embeds, dist = model_cluster(features, adj_all, num_cluster_iter)
#    if t == 0:
#        print(r, r.min(), r.max())
#        break
    # The objective value is negated so that minimizing ``loss`` maximizes it.
    loss = -loss_fn(mu, r, embeds, dist, bin_adj_all, test_object)
#    loss = torch.nn.MSELoss()(r, r_opt)
#    plt.figure()
#    plt.scatter(embeds[:50].flatten().detach().numpy(), np.ones(50), c=c[:50])
#    plt.scatter(mu.flatten().detach().numpy(), np.ones(2), c = ['cyan', 'cyan'])
    optimizer.zero_grad()
    loss.backward()
#    if t == 500:
#        num_cluster_iter = 5
#        if t == 1000:
#            num_cluster_iter = 10
    # Periodic logging; runs after backward() and before the parameter update.
    if t % 100 == 0:
        if args.objective == 'modularity' or args.objective == 'maxcut':
            # Sharpen the soft assignments before evaluating these objectives.
            r = torch.softmax(100*r, dim=1)
        loss_test = loss_fn(mu, r, embeds, dist, bin_adj_all, test_object)
        loss_test_only = loss_fn(mu, r, embeds, dist, bin_adj_test, test_only_object)
        losses_test.append(loss_test.item())
#        embeds_for_graph.append(embeds)
        # Columns: iteration, negated training loss, objective, test-only objective.
        print(t, loss.item(), loss_test.item(), loss_test_only.item())
#        plt.figure()
#        embeds_np = embeds.detach().numpy()
#        plt.scatter(embeds_np[:, 0], embeds_np[:, 1])
    losses.append(loss.item())
    # Apply the gradients computed by loss.backward() above.
    optimizer.step()
