import numpy as np

def get_my_neighbors(root='./dataset', name='Cora', path='node_neighbors.npy'):
    """Build a self-inclusive adjacency list for a Planetoid graph and save it.

    Every node's list starts with the node itself, followed by the target of
    each edge whose source is that node, in ``edge_index`` order. The
    resulting ``{node_id: [node_id, neighbour, ...]}`` dict is written with
    ``np.save``; because it is a dict it is stored as a pickled object array
    and must be read back with ``np.load(path, allow_pickle=True).item()``.

    Args:
        root: Directory where the dataset is stored / downloaded.
        name: Planetoid dataset name.
        path: Output ``.npy`` file.
    """
    from torch_geometric.datasets import Planetoid
    dataset = Planetoid(root=root, name=name)
    # dataset[0] is the public accessor for the single graph; reading
    # `dataset.data` directly triggers a warning in recent PyG versions.
    data = dataset[0]
    edge_index = data.edge_index.numpy()
    # Seed each list with the node itself (self-loop).
    node_neighbors = {i: [i] for i in range(data.x.shape[0])}
    # Row 0 holds edge sources, row 1 edge targets.
    for src, dst in zip(edge_index[0], edge_index[1]):
        node_neighbors[src].append(dst)
    np.save(path, node_neighbors)


def get_my_feats(root='./dataset', name='Cora', path='feats.npy'):
    """Save the node-feature matrix of a Planetoid graph as a ``.npy`` file.

    Args:
        root: Directory where the dataset is stored / downloaded.
        name: Planetoid dataset name.
        path: Output ``.npy`` file.
    """
    from torch_geometric.datasets import Planetoid
    dataset = Planetoid(root=root, name=name)
    # dataset[0] is the public accessor for the single graph; reading
    # `dataset.data` directly triggers a warning in recent PyG versions.
    feats = dataset[0].x.numpy()
    np.save(path, feats)
def get_my_labels(root='./dataset', name='Cora', path='labels.npy'):
    """Save the node-label vector of a Planetoid graph as a ``.npy`` file.

    Args:
        root: Directory where the dataset is stored / downloaded.
        name: Planetoid dataset name.
        path: Output ``.npy`` file.
    """
    from torch_geometric.datasets import Planetoid
    dataset = Planetoid(root=root, name=name)
    # dataset[0] is the public accessor for the single graph; reading
    # `dataset.data` directly triggers a warning in recent PyG versions.
    labels = dataset[0].y.numpy()
    np.save(path, labels)


def dataset_splits(data, num_classes, percls_trn=20, val_lb=500, Flag=0,
                   generator=None):
    """Randomly split node indices into train / validation / test sets.

    The training set always takes the first ``percls_trn`` nodes of a random
    permutation of each class. Nodes whose label is not in
    ``0 .. num_classes-1`` end up in none of the splits.

    Args:
        data: Object with a ``y`` attribute holding node labels as a tensor
            (assumed 1-D, one label per node).
        num_classes: Number of classes to draw from.
        percls_trn: Training nodes per class.
        val_lb: Validation size. With ``Flag == 0`` it is the TOTAL number of
            validation nodes, taken from the shuffled pool of all
            non-training nodes; with any other ``Flag`` it is the number of
            validation nodes PER CLASS.
        Flag: Selects how ``val_lb`` is interpreted (see above).
        generator: Optional ``torch.Generator`` for reproducible splits.
            ``None`` uses torch's global RNG, exactly as before.

    Returns:
        ``(train_index, val_index, test_index)`` as 1-D index tensors. The
        test set is every remaining labelled node, in shuffled order.
    """
    def perm(n):
        # Single place where randomness is drawn, so `generator` applies
        # to every permutation (and call order matches the original).
        return torch.randperm(n, generator=generator)

    indices = []
    for c in range(num_classes):
        index = (data.y == c).nonzero().view(-1)
        indices.append(index[perm(index.size(0))])

    train_index = torch.cat([idx[:percls_trn] for idx in indices], dim=0)

    if Flag == 0:
        # Pool every non-training node, shuffle, and take val_lb of them.
        rest_index = torch.cat([idx[percls_trn:] for idx in indices], dim=0)
        rest_index = rest_index[perm(rest_index.size(0))]
        val_index = rest_index[:val_lb]
        test_index = rest_index[val_lb:]
    else:
        # Class-balanced validation: val_lb nodes from each class.
        val_index = torch.cat(
            [idx[percls_trn:percls_trn + val_lb] for idx in indices], dim=0)
        rest_index = torch.cat(
            [idx[percls_trn + val_lb:] for idx in indices], dim=0)
        test_index = rest_index[perm(rest_index.size(0))]

    return train_index, val_index, test_index



def get_t_v_t(train_index, val_index, test_index):
    """Write train / validation / test node indices to text files.

    Creates (or overwrites) ``train_set.txt``, ``val_set.txt`` and
    ``test_set.txt`` in the current working directory, one integer index per
    line.

    Args:
        train_index, val_index, test_index: Iterables of integer node indices
            (e.g. the 1-D tensors returned by ``dataset_splits``).
    """
    splits = (
        ('train_set.txt', train_index),
        ('val_set.txt', val_index),
        ('test_set.txt', test_index),
    )
    for path, indices in splits:
        # `with` guarantees every file is flushed and closed; previously the
        # handles were left open.
        with open(path, 'w') as f:
            f.writelines(f'{int(idx)}\n' for idx in indices)

from torch_geometric.datasets import Planetoid
import torch  # used by dataset_splits (torch.cat / torch.randperm)

# Load Cora once for splitting. dataset[0] is the public accessor for the
# single graph; reading `dataset.data` directly triggers a warning in recent
# PyG versions.
dataset = Planetoid(root='./dataset', name='Cora')
data = dataset[0]
num_classes = dataset.num_classes

# Flag=1: 20 training and 30 validation nodes per class; all remaining nodes
# form the shuffled test set.
# NOTE(review): no RNG seed is set, so the split differs between runs.
train_index, val_index, test_index = dataset_splits(
    data, num_classes, percls_trn=20, val_lb=30, Flag=1
)

# Export the splits, adjacency lists, features and labels to disk.
get_t_v_t(train_index, val_index, test_index)
get_my_neighbors()
get_my_feats()
get_my_labels()


