from __future__ import division
from __future__ import print_function

import time, os, argparse, random

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable


from pygcn.utils import load_data, accuracy
from models import GCN, Dense_GCN, Simplified_GCN

# Training settings
parser = argparse.ArgumentParser()
parser.add_argument('--no_cuda', action='store_true', default=False, help='Disables CUDA training.')
parser.add_argument('--fastmode', action='store_true', default=False, help='Validate during training pass.')
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--epochs', type=int, default=500, help='Number of epochs to train.')
parser.add_argument('--runtimes', type=int, default=50, help='Runtimes.')
parser.add_argument('--silent', type=int, default=0, help='No prompts during running.')
parser.add_argument('--percent', type=float, default=0.05, help='Percentage of training set.')
parser.add_argument('--identifier', type=int, default=1234567, help='Identifier for the job')
parser.add_argument('--dataset', type=str, default='cora', help='Dataset (cora, citeseer, pubmed)')
parser.add_argument('--public', type=int, default=0, help='Use the Public Setting of the Dataset or not')
parser.add_argument('--networks', type=str, default='Dense_GCN', help='Network type (GCN, Dense_GCN, Simplified_GCN).')
# TO BE TUNED
parser.add_argument('--lr', type=float, default=0.0076774, help='Initial learning rate.')
parser.add_argument('--dropout', type=float, default=0.5, help='Dropout rate (1 - keep probability).')
parser.add_argument('--weight_decay', type=float, default=0.0062375, help='Weight decay (L2 loss on parameters).')
parser.add_argument('--hidden', type=int, default=128, help='Number of hidden units.')
parser.add_argument('--layers', type=int, default=0, help='Number of hidden layers.')
parser.add_argument('--activation', type=str, default="relu", help='Activation Function')
parser.add_argument('--eye', type=int, default=0, help='Use Identity Matrix in Output Layer or not')
parser.add_argument('--with_features', type=int, default=0, help='The output layer uses features or not')
parser.add_argument('--layers_factor', type=float, default=1, help='Factor for Determining the Number of Layers')
parser.add_argument('--regularization_factor', type=float, default=0, help='Hyperparameter for Trace Regularization')
parser.add_argument('--regularization', type=str, default='btw_group', help='Regularization (Between Group Scatter(btw_group) and Tangent Distance(tan_dist))')
parser.add_argument('--features_in_hid', type=int, default=1, help='Stack Original Features in Hidden Layers or not')



args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Seed every RNG the script draws from. The stdlib `random` module drives the
# non-public train/val/test split (random.sample / random.shuffle), so it has
# to be seeded as well or runs with the same --seed are not reproducible.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Load data: "<dataset>_dense_adj.pt", "<dataset>_features.pt", "<dataset>_labels.pt"
# are read from the current working directory.
dense_adj = torch.load("%s_dense_adj.pt" % args.dataset)
features = torch.load("%s_features.pt" % args.dataset)
labels = torch.load("%s_labels.pt" % args.dataset)

# Store the adjacency as a sparse COO float32 tensor holding only its non-zero
# entries (torch.sparse_coo_tensor replaces the deprecated
# torch.sparse.FloatTensor legacy constructor).
indices = torch.nonzero(dense_adj).t()
values = dense_adj[indices[0], indices[1]]
adj = torch.sparse_coo_tensor(indices, values, dense_adj.size(), dtype=torch.float32)
del dense_adj, indices, values


if args.cuda:
    features, adj, labels = features.cuda(), adj.cuda(), labels.cuda()

def _trace_regularizer(extractedf, node_labels, idx, kind):
    """Return the trace regulariser computed on class means of ``extractedf``.

    The features are first centred on their mean over all rows (no labels are
    involved in that step). Class means are then taken over the rows listed in
    ``idx`` only, so validation/test labels never influence the training loss.

    Args:
        extractedf: 2-D tensor of per-node features (rows are nodes; it is
            centred with ``extractedf.mean(0)``).
        node_labels: 1-D tensor of node labels, indexable by ``idx``.
        idx: LongTensor of node indices whose labels may be used.
        kind: ``'btw_group'`` -> trace(M M^T), i.e. the sum of squared norms of
            the class means; ``'tan_dist'`` -> sum of the off-diagonal entries
            of N N^T, where N holds the L2-normalised class means.

    Returns:
        A 0-d tensor on the same device as ``extractedf``.

    Raises:
        ValueError: if ``kind`` is not one of the two supported names.
    """
    if kind not in ('btw_group', 'tan_dist'):
        raise ValueError("Unknown regularization %r (expected 'btw_group' or 'tan_dist')" % (kind,))
    centred = extractedf - extractedf.mean(0)
    idx_labels = node_labels[idx]
    group_means = []
    for c in torch.unique(idx_labels):
        # Every class returned by unique() has at least one member, so the mean is defined.
        group_mean = centred[idx[idx_labels == c]].mean(0, keepdim=True)
        if kind == 'tan_dist':
            group_mean = group_mean / torch.norm(group_mean)
        group_means.append(group_mean)
    if not group_means:
        return extractedf.new_zeros(())
    means = torch.cat(group_means, 0)
    gram = torch.mm(means, means.t())
    if kind == 'btw_group':
        return torch.trace(gram)
    # tan_dist: sum of pairwise cosine similarities between distinct class means.
    return torch.sum(gram) - torch.trace(gram)


def train(epoch):
    """Run one optimisation step on the training nodes, then score the validation nodes.

    Reads module-level state prepared by the driver loop: ``model``,
    ``optimizer``, ``features``, ``adj``, ``labels``, ``idx_train``, ``idx_val``,
    ``labels_train``, ``labels_val``, ``best_val``, ``best_test`` and ``args``.
    ``model(features, adj)`` is expected to return ``(output, extracted_features)``;
    ``output`` is fed straight to ``F.nll_loss``, so it is presumably a
    log-probability matrix (TODO confirm against models.py).

    The loss is ``nll_loss - regularization_factor * regularizer``. The
    regulariser is only computed when ``--regularization_factor`` is non-zero.

    Args:
        epoch: zero-based epoch index, used only for logging.

    Returns:
        Validation accuracy as returned by ``accuracy``.
    """
    model.train()
    optimizer.zero_grad()

    output, extractedf = model(features, adj)
    loss_train = F.nll_loss(output[idx_train], labels[idx_train])
    if args.regularization_factor != 0:
        # Only training labels are used for the class means (no label leakage).
        regularizer = _trace_regularizer(extractedf, labels, idx_train, args.regularization)
        loss_train = loss_train - args.regularization_factor * regularizer

    acc_train = accuracy(output[idx_train], labels_train)
    loss_train.backward()
    optimizer.step()

    if not args.fastmode:
        # Evaluate validation set performance separately,
        # deactivates dropout during validation run.
        model.eval()
        output, _ = model(features, adj)
    # In fastmode the training-pass output (dropout active) is reused here.
    loss_val = F.nll_loss(output[idx_val], labels_val)
    acc_val = accuracy(output[idx_val], labels_val)
    if not args.silent:
        print('E%d' % (epoch + 1),
            'loss_train: {:.2e}'.format(loss_train.item()),
            'acc_train: {:.2f}%'.format(100 * acc_train.item()),
            'loss_val: {:.2e}'.format(loss_val.item()),
            'acc_val: {:.2f}%'.format(100 * acc_val.item()),
            end = " ")
        print("best_val: %.2e, best_test: %.2f%%" % (best_val, 100 * best_test))
    return acc_val


def test():
    """Score the current model on the test nodes.

    Switches ``model`` to eval mode, runs a full forward pass over the
    module-level ``features``/``adj`` and compares ``idx_test`` rows with
    ``labels_test``. Loss and accuracy are printed unless ``--silent`` is set.

    Returns:
        Test accuracy as returned by ``accuracy``.
    """
    model.eval()
    logits, _ = model(features, adj)
    test_logits = logits[idx_test]
    acc_test = accuracy(test_logits, labels_test)
    loss_test = F.nll_loss(test_logits, labels_test)
    if args.silent:
        return acc_test
    print("loss_test= {:.2e}".format(loss_test.item()), "acc_test= {:.2f}%".format(100 * acc_test.item()))
    return acc_test

def layer_numbers(idx_train, adj, percent, factor):
    """Pick a number of propagation layers from how the training nodes' reach grows.

    One indicator row per training node is propagated repeatedly through the
    sparse adjacency (``spmm``). At every step, the per-node count of training
    nodes whose propagated signal is positive is compared before and after
    the step:

    * ``redundant_addition``: growth of that count on nodes already reached;
    * ``new_addition``: growth on nodes reached for the first time.

    Propagation continues while ``new_addition > factor * percent *
    redundant_addition`` or ``redundant_addition == 0``. It also stops as soon
    as a step leaves the reach pattern unchanged. From that point on
    ``redundant_addition`` stays 0 and the loop would otherwise never
    terminate.

    Args:
        idx_train: LongTensor of training-node indices (any device).
        adj: sparse ``(N, N)`` adjacency tensor (any device).
        percent: fraction of nodes used for training.
        factor: multiplier on ``percent`` in the stopping test.

    Returns:
        int: zero-based index of the last propagation step performed.
    """
    local_adj = adj.clone().cpu()
    local_idx_train = idx_train.clone().cpu().numpy()
    n_train = local_idx_train.shape[0]
    s = np.zeros((n_train, adj.shape[0]))
    s[np.arange(n_train), local_idx_train] = 1
    local_s = torch.Tensor(s).cpu()  # (n_train, N) indicator rows
    j = -1
    redundant_addition = 0
    new_addition = 0
    while (new_addition > factor * percent * redundant_addition) | (redundant_addition == 0):
        j = j + 1
        # One propagation step, computed once and reused below: (n_train, N).
        propagated = torch.transpose(torch.spmm(local_adj, torch.transpose(local_s, 0, 1)), 0, 1)
        current_support = (local_s > 0).numpy()
        next_support = (propagated > 0).numpy()
        current_counts = np.sum(current_support, 0)
        next_counts = np.sum(next_support, 0)

        reached_nodes = (current_counts > 0) * 1
        new_reach_nodes = (next_counts > 0) * 1 - reached_nodes * 1
        addition = next_counts - current_counts

        redundant_addition = np.dot(addition, reached_nodes)
        new_addition = np.dot(addition, new_reach_nodes)

        local_s = propagated
        if np.array_equal(next_support, current_support):
            # Fixed point: addition is all zeros, so redundant_addition == 0
            # and the loop condition could never become False.
            break
    return j
    
#    local_adj = adj.clone().cpu(); local_idx_train = idx_train.clone().cpu()
#    s = np.zeros((local_idx_train.shape[0], adj.shape[0]))
#    s[np.arange(local_idx_train.shape[0]), local_idx_train] = 1
#    local_s = torch.Tensor(s).cpu()
#    j = 0
#
#    ave_diff_degree = np.sum(torch.spmm(local_adj, torch.transpose(local_s, 0, 1)).numpy()) / local_idx_train.shape[0]
#    while (sum(np.sum(np.transpose(torch.spmm(local_adj, torch.transpose(local_s, 0, 1)).numpy()) > 0, 0) > 0) - np.sum(np.sum((local_s>0).numpy(), 0) > 0)) > 0:
#        j = j + 1
#        local_s = torch.transpose(torch.spmm(local_adj, torch.transpose(local_s,0,1)), 0, 1)
#    achieved_nodes = np.sum(np.sum((local_s > 0).numpy(),0) >0)
#    achieved_degree = np.sum(np.sum((local_s).numpy(), 0))
#    return j, ave_diff_degree, achieved_nodes, achieved_degree
#
#
    
    

# Train model
result3 = list()      # test accuracy of the final model of each run
result_test = list()  # best per-epoch test accuracy observed in each run

total_running_time = 0
# Honour --no_cuda / a missing GPU instead of hard-coding .cuda().
device = torch.device("cuda" if args.cuda else "cpu")
# Explicit whitelist instead of eval() on command-line strings.
NETWORKS = {'GCN': GCN, 'Dense_GCN': Dense_GCN, 'Simplified_GCN': Simplified_GCN}
# Public splits: (train range, val range, test range, train size / node count).
PUBLIC_SPLITS = {
    'cora': (range(140), range(200, 500), range(500, 1500), 140 / 2708),
    'citeseer': (range(120), range(120, 620), range(2312, 3312), 120 / 3312),
    'pubmed': (range(60), range(60, 600), range(18717, 19717), 60 / 19717),
}
if args.networks not in NETWORKS:
    raise ValueError("Unknown network %r (expected one of %s)" % (args.networks, sorted(NETWORKS)))
Network = NETWORKS[args.networks]
activation = getattr(F, args.activation)

for runtime in range(args.runtimes):
    best_val = 0
    best_test = 0
    all_data = np.arange(adj.shape[0])

    if args.public == 1:
        if args.dataset not in PUBLIC_SPLITS:
            raise ValueError("No public split defined for dataset %r" % (args.dataset,))
        train_range, val_range, test_range, percent = PUBLIC_SPLITS[args.dataset]
        idx_train = torch.LongTensor(list(train_range)).to(device)
        idx_val = torch.LongTensor(list(val_range)).to(device)
        idx_test = torch.LongTensor(list(test_range)).to(device)
    else:
        # Sample int(class_size * percent) + 1 training nodes per class; the
        # remaining nodes are shuffled and sliced into validation and test sets.
        labels_np = labels.cpu().numpy()
        sampled = []
        for c in np.unique(labels_np):
            members = np.where(labels_np == c)[0]
            sampled.append(random.sample(list(members), int(members.shape[0] * args.percent) + 1))
        # Integer dtype is required: np.delete and LongTensor reject float indices.
        train_nodes = np.concatenate(sampled).astype(np.int64)
        others = np.delete(all_data, train_nodes)
        random.shuffle(others)
        idx_train = torch.LongTensor(train_nodes).to(device)
        idx_val = torch.LongTensor(others[200:500]).to(device)
        idx_test = torch.LongTensor(others[500:1500]).to(device)
        percent = args.percent

    labels_train = labels[idx_train].to(device)
    labels_val = labels[idx_val].to(device)
    labels_test = labels[idx_test].to(device)

    if args.layers == 0:
        # percent is the actual training fraction of the split chosen above.
        layer_num = layer_numbers(idx_train, adj, percent, args.layers_factor)
    else:
        layer_num = args.layers

    model = Network(nfeat=features.shape[1],
                nlayers=layer_num,
                nhid=args.hidden,
                nclass=labels.max().item() + 1,
                dropout=args.dropout,
                activation = activation,
                eye=args.eye,
                with_features=args.with_features)
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    t_total = time.time()
    for epoch in range(args.epochs):
        if not args.silent:
            print("R%d" % runtime, end = " ")
        acc_val = train(epoch)
        acc_test = test().cpu().numpy()
        # best_val is tracked for logging only; the final model is not rolled
        # back to the best-validation checkpoint.
        if acc_val >= best_val:
            best_val = acc_val
        if acc_test >= best_test:
            best_test = acc_test

    elapsed = time.time() - t_total
    if not args.silent:
        print("Optimization Finished!")
        print("Total time elapsed: {:.4f}s".format(elapsed))
    total_running_time = total_running_time + elapsed
    result_test.append(best_test)
    result3.append(test().cpu().numpy())
    if not args.silent:
        print("idx_train: ", idx_train)

# Report the mean final test accuracy over all runs; in silent mode it is
# also written to "<identifier>.txt" (e.g. for an external tuning driver).
mean_result = np.mean(result3)
if not args.silent:
    print("mean result: ", mean_result, "total running time: ", total_running_time)
else:
    print(mean_result)
    with open("%d.txt" % args.identifier, 'w') as script:
        script.write("%e" % mean_result)
