import networkx as nx
import numpy as np
# Split the PubMed citation graph into METIS partitions. For each partition i,
# write (under data/pubmed/):
#   pubmed_{i}.cites                        -- edge list of the induced subgraph
#   pubmed_{i}.content                      -- feature rows of its nodes
#   pubmed_{i}_train_{train_pct:.2f}.content -- same rows as pubmed_{i}.content
# NOTE(review): the "train" file currently receives exactly the same rows as
# the full .content file; train_pct only affects its name. Confirm whether a
# train split was intended here.
g = nx.read_edgelist('data/pubmed/pubmed.cites', nodetype=int)
# Column 0 is matched against node ids below; the remaining columns are
# written back as floats.
features = np.loadtxt('data/pubmed/pubmed.content')
node_list = np.array(list(g.nodes()))

# Number of METIS partitions; also selects the matching .part.<n> file.
num_parts = 20
# `np.int` was removed in NumPy 1.24; the builtin `int` is its replacement.
# NOTE(review): the 'pubmet' spelling is kept as the presumed on-disk name, so
# confirm it. This also assumes entry k of the .part file is the partition of
# node_list[k], i.e. the METIS input was written in g.nodes() order. Confirm
# against the script that produced it.
part = np.loadtxt(
    'data/pubmed/pubmet_for_metis.txt.part.{}'.format(num_parts), dtype=int)
if len(part) != len(node_list):
    # A length mismatch would silently drop nodes or misassign partitions.
    raise ValueError(
        'partition file has {} entries but the graph has {} nodes'.format(
            len(part), len(node_list)))
train_pct = 0.4

# Map node id -> first row in `features` carrying that id. Building this once
# makes each lookup O(1) instead of scanning the id column for every node.
row_of = {}
for row, node_id in enumerate(features[:, 0].astype(int).tolist()):
    row_of.setdefault(node_id, row)

# Output row format: integer node id followed by float feature columns.
row_fmt = '%d ' + '%f ' * (features.shape[1] - 1)

for i in range(num_parts):
    nodes = node_list[part == i]
    subg = nx.subgraph(g, nodes)
    print(len(subg))
    nx.write_edgelist(subg, 'data/pubmed/pubmed_{}.cites'.format(i), data=False)
    # Gather the feature rows in subg.nodes() order. The explicit intp dtype
    # keeps an empty partition producing a (0, n_cols) array.
    rows = np.array([row_of[v] for v in subg.nodes()], dtype=np.intp)
    features_i = features[rows]
    np.savetxt('data/pubmed/pubmed_{}.content'.format(i), features_i, fmt=row_fmt)
    np.savetxt('data/pubmed/pubmed_{}_train_{:.2f}.content'.format(i, train_pct),
               features_i, fmt=row_fmt)