import torch
import numpy as np
from utils import load_data
import scipy as sp
import scipy.sparse
def make_normalized_adj(edges, n):
    """Build a row-normalized, symmetric adjacency matrix with self-loops.

    Args:
        edges: integer tensor of shape (num_edges, 2); each row is a
            (source, target) node-index pair. Duplicate pairs are summed
            by ``scipy.sparse.coo_matrix``, so they act as edge weights.
        n: number of nodes. The result is ``n x n``, so every index in
            ``edges`` must lie in ``[0, n)``.

    Returns:
        A ``torch.float32`` sparse COO tensor equal to ``D^-1 (A_sym + I)``.
        ``A_sym`` is the adjacency matrix symmetrized by taking the
        elementwise max with its transpose. ``D`` is the diagonal matrix of
        row sums of ``A_sym + I``.
    """
    def normalize(mx):
        """Row-normalize a scipy sparse matrix; all-zero rows stay zero."""
        rowsum = np.array(mx.sum(1))
        # Zero rows yield inf here; they are mapped back to 0 below, so the
        # divide-by-zero warning is expected noise and is suppressed.
        with np.errstate(divide='ignore'):
            r_inv = np.power(rowsum, -1).flatten()
        r_inv[np.isinf(r_inv)] = 0.
        r_mat_inv = sp.sparse.diags(r_inv)
        return r_mat_inv.dot(mx)

    def sparse_mx_to_torch_sparse_tensor(sparse_mx):
        """Convert a scipy sparse matrix to a float32 torch sparse COO tensor."""
        sparse_mx = sparse_mx.tocoo().astype(np.float32)
        indices = torch.from_numpy(
            np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
        values = torch.from_numpy(sparse_mx.data)
        shape = torch.Size(sparse_mx.shape)
        # torch.sparse.FloatTensor(...) is deprecated; sparse_coo_tensor is
        # the supported constructor and keeps the float32 dtype of `values`.
        return torch.sparse_coo_tensor(indices, values, shape)

    # .cpu() so that edge tensors living on a GPU can also be converted.
    edges = edges.detach().cpu().numpy()
    adj = sp.sparse.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                               shape=(n, n),
                               dtype=np.float32)
    # Build a symmetric adjacency matrix: where adj.T > adj the entry becomes
    # adj.T, otherwise it stays adj, i.e. an elementwise max(adj, adj.T).
    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
    # Add self-loops before normalizing so every row has a nonzero sum.
    adj = normalize(adj + sp.sparse.eye(adj.shape[0]))
    return sparse_mx_to_torch_sparse_tensor(adj)
dataset = 'cora'
adj_all, features_all, labels, idx_train, idx_val, idx_test = load_data('data/{}/'.format(dataset), '{}'.format(dataset))
adj_all = adj_all.coalesce()
# Dense 0/1 adjacency: any positive entry of adj_all counts as an edge.
bin_adj_all = (adj_all.to_dense() > 0).float()

# NOTE(review): these lists are not populated anywhere in this section.
bin_adj = []
adj = []
features = []

# Fraction of each class subgraph's nonzero adjacency entries kept for training.
train_pct = 1.0

# For every class label: take the subgraph induced by that class's nodes,
# randomly keep train_pct of its nonzero adjacency entries, and save them.
for i in torch.unique(labels):
    nodes = (labels == i).nonzero().flatten()
    subgraph = bin_adj_all[nodes][:, nodes]
    # (m, 2) index pairs local to `nodes`, one per nonzero subgraph entry.
    edges = subgraph.nonzero()
    m = edges.shape[0]
    # Same draw as permuting list(range(m)), without building a Python list.
    order = torch.from_numpy(np.random.permutation(m)).long()
    edges_train = edges[order[:int(m * train_pct)]]
    # .max() raises on an empty tensor (a class with no kept edges), so guard it.
    max_index = edges_train.max().item() if edges_train.numel() > 0 else None
    print(int(i), nodes.shape[0], max_index, edges_train.shape)
    # The normalized adjacency for this class can be built with
    # make_normalized_adj(edges_train, nodes.shape[0]); it is not computed here.
    torch.save(edges_train, 'data/{}/edges_{}_{:.2f}.pt'.format(dataset, int(i), train_pct))
    
