from __future__ import division
from __future__ import print_function
import time, os, argparse, random
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from pygcn.utils import load_data, accuracy
from models import *
import random
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import math 
from sklearn import manifold

parser = argparse.ArgumentParser()
# EXPERIMENT SETTINGS
# NOTE: argparse %-formats help strings, so a literal percent sign must be written as '%%'.
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--cuda', type=int, default=0, help='Ignored: CUDA is used whenever torch.cuda.is_available().')
parser.add_argument('--epochs', type=int, default=3, help='Number of max epochs to train.')
parser.add_argument('--runtimes', type=int, default=1, help='Runtimes.')
parser.add_argument('--debug', type=int, default=1, help='1 for prompts during running, 0 for none')
parser.add_argument('--percent', type=float, default=0.05, help='Training fraction per class for --public values other than 1 and 2; also scales the automatic layer count (--layers 0).')
parser.add_argument('--identifier', type=int, default=1234567, help='Identifier for the job (non-debug runs write the mean result to <identifier>.txt)')
parser.add_argument('--dataset', type=str, default='cora', help='Dataset (cora, citeseer, pubmed)')
parser.add_argument('--public', type=int, default=2, help='1 for the public split of the dataset, 2 for 20 random training nodes per class, any other value for a random --percent split')
parser.add_argument('--network', type=str, default='truncated_krylov', help='Network type (snowball, linear_snowball, linear_tanh_snowball, truncated_krylov)')
parser.add_argument('--validation', type=int, default=0, help='1 for turning on validation set, 0 for not')
parser.add_argument('--amp', type=int, default=1, help='1, 2 and 3 for NVIDIA apex amp optimization O1, O2 and O3, 0 for off')
# MODEL HYPERPARAMETERS
parser.add_argument('--lr', type=float, default=0.0076774, help='Initial learning rate.')
parser.add_argument('--dropout', type=float, default=0.5, help='Dropout rate (1 - keep probability).')
parser.add_argument('--weight_decay', type=float, default=0.0062375, help='Weight decay (L2 loss on parameters).')
parser.add_argument('--hidden', type=int, default=128, help='Number of hidden units.')
parser.add_argument('--layers', type=int, default=1, help='Number of hidden layers (0 to estimate it from the graph; requires --public 1).')
parser.add_argument('--activation', type=str, default="relu", help='Activation Function')
parser.add_argument('--layers_factor', type=float, default=1, help='Factor for Determining the Number of Layers')
parser.add_argument('--optimizer', type=str, default='RMSprop', help='Optimizer (Adam or RMSprop)')
parser.add_argument('--n_blocks', type=int, default=5, help='Number of Krylov blocks for truncated_krylov network')
# STOPPING CRITERIA
parser.add_argument('--consecutive', type=int, default=200, help='Consecutive 100%% training accuracy to stop')
parser.add_argument('--early_stopping', type=int, default=100, help='Early Stopping')
parser.add_argument('--epochs_after_peak', type=int, default=500, help='Number of More Epochs Needed after 100%% Training Accuracy Happens')
args = parser.parse_args()

# Load data: a dense adjacency matrix, node features and node labels saved per dataset.
dense_adj = torch.load("%s_dense_adj.pt" % args.dataset)
features = torch.load("%s_features.pt" % args.dataset)
labels = torch.load("%s_labels.pt" % args.dataset)
# Keep the adjacency as a float32 sparse COO matrix. to_sparse() stores exactly the
# nonzero entries; torch.sparse.FloatTensor is a deprecated legacy constructor.
adj = dense_adj.float().to_sparse()
del dense_adj

# set environment: seed every RNG the script draws from.
# `random` drives random.sample / random.shuffle in the non-public splits.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# NOTE: the --cuda flag is overwritten here; CUDA is used whenever it is available.
args.cuda = torch.cuda.is_available()
if args.cuda:
    torch.cuda.manual_seed(args.seed)
    features, adj, labels = features.cuda(), adj.cuda(), labels.cuda()
if args.amp:
    # Fall back to full precision when apex is not installed.
    try:
        from apex import amp
    except ModuleNotFoundError:
        args.amp = 0

def train(epoch):
    """Run one optimisation step on the training nodes, then evaluate on validation nodes.

    Uses the module-level model, optimizer, features, adj, split indices and labels.

    Args:
        epoch: 0-based epoch counter, used only for the debug print.

    Returns:
        (training accuracy in %, training loss, validation accuracy in %). The
        validation accuracy comes from an eval-mode forward pass made after the update.
    """
    model.train()
    optimizer.zero_grad()
    output = model(features, adj)
    loss_train = F.nll_loss(output[idx_train], labels_train)
    acc_train = accuracy(output[idx_train], labels_train)
    if args.amp:
        # apex amp: backpropagate through the scaled loss.
        with amp.scale_loss(loss_train, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss_train.backward()
    optimizer.step()

    # Evaluation pass: no gradients are needed, so don't build an autograd graph.
    model.eval()
    with torch.no_grad():
        output = model(features, adj)
    acc_val = accuracy(output[idx_val], labels_val)
    if args.debug:
        # best_val / best_test are running maxima kept at module level by the experiment loop.
        print('E%04d' % (epoch + 1),
              'loss_train: %4.2e, acc_train: %6.2f%%, best_val: %5.2f%%, best_test: %5.2f%%' % (loss_train.item(), 100 * acc_train.item(), best_val, 100 * best_test),
              end = " ")
    return 100 * acc_train.item(), loss_train.item(), 100 * acc_val.item()

def test():
    """Evaluate the module-level model on the test nodes.

    Returns:
        The test accuracy as returned by `accuracy` (a tensor; callers convert it
        with ``.cpu().numpy()``).
    """
    model.eval()
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(features, adj)
    loss_test = F.nll_loss(output[idx_test], labels_test)
    acc_test = accuracy(output[idx_test], labels_test)
    if args.debug:
        print("loss_test: %4.2e, acc_test: %5.2f%%" % (loss_test.item(), 100 * acc_test.item()), end = " ")
    return acc_test

def layer_numbers(idx_train, adj, percent, factor):
    """Estimate how many propagation steps (layers) are worth stacking.

    Start from one-hot rows for the training nodes and repeatedly multiply them
    by ``adj``. Each step compares two growths in per-node counts of reaching
    training nodes. ``redundant_addition`` is the growth on nodes that were
    already reached. ``new_addition`` is the growth on nodes reached for the
    first time. Iteration continues while
    ``new_addition > factor * percent * redundant_addition`` or while
    ``redundant_addition`` is still 0.

    Iteration also stops at a fixed point, i.e. once a step no longer changes
    which (training node, node) pairs are reached. From then on every step
    would be identical with zero growth, so the test above would never fail.

    Args:
        idx_train: training-node indices; a tensor (any device), range, list or array.
        adj: square adjacency matrix (a sparse torch tensor in this script); a
            CPU copy is used.
        percent: training fraction; scales the stopping threshold.
        factor: additional multiplier on the stopping threshold.

    Returns:
        The 0-based step counter when iteration stopped.
    """
    if isinstance(idx_train, torch.Tensor):
        train_nodes = idx_train.detach().cpu().numpy().astype(np.int64)
    else:
        # e.g. the range objects used for the public splits
        train_nodes = np.asarray(idx_train, dtype=np.int64)
    local_adj = adj.clone().cpu()
    seeds = np.zeros((train_nodes.shape[0], adj.shape[0]))
    seeds[np.arange(train_nodes.shape[0]), train_nodes] = 1
    local_s = torch.Tensor(seeds).cpu()

    j = -1
    redundant_addition = 0
    new_addition = 0
    while (new_addition > factor * percent * redundant_addition) or (redundant_addition == 0):
        j = j + 1
        # One propagation step, computed once and reused for every statistic below.
        propagated = torch.transpose(torch.spmm(local_adj, torch.transpose(local_s, 0, 1)), 0, 1)
        current_support = (local_s > 0).numpy()
        next_support = (propagated > 0).numpy()
        if np.array_equal(current_support, next_support):
            # Fixed point: nothing further can change, stop instead of looping forever.
            break
        current_counts = np.sum(current_support, 0)
        next_counts = np.sum(next_support, 0)
        reached_nodes = (current_counts > 0) * 1
        new_reach_nodes = (next_counts > 0) * 1 - reached_nodes
        addition = next_counts - current_counts
        redundant_addition = np.dot(addition, reached_nodes)
        new_addition = np.dot(addition, new_reach_nodes)
        local_s = propagated
    return j

def plot_embedding(X, y, title=None):
    """Draw a 2-D embedding as class-coloured label text and save it as a PDF.

    Args:
        X: array with one row per sample; only the first two columns are drawn,
            after min-max scaling each column to [0, 1].
        y: integer class label per row of X; drawn as text and used to pick the colour.
        title: file name for the saved PDF. When None, the figure is not saved.
    """
    X = np.asarray(X)
    x_min, x_max = np.min(X, 0), np.max(X, 0)
    span = x_max - x_min
    # A constant column would otherwise divide by zero.
    span = np.where(span == 0, 1, span)
    X = (X - x_min) / span

    plt.figure()
    plt.subplot(111)
    for i in range(X.shape[0]):
        plt.text(X[i, 0], X[i, 1], str(y[i]),
                 color=plt.cm.Set1(y[i] / 10.),
                 fontdict={'weight': 'bold', 'size': 9})
    plt.xticks([])
    plt.yticks([])
    if title is not None:
        plt.savefig(title, format='pdf', dpi=1200)

# setup training, validation and testing set if public
if args.public == 1:
    # Fixed index ranges per dataset; `percent` is the training share of all nodes.
    if args.dataset == 'cora':
        idx_train, idx_val, idx_test = range(140), range(200, 500), range(500, 1500)
        percent = 140 / 2708
    elif args.dataset == 'citeseer':
        idx_train, idx_val, idx_test = range(120), range(120, 620), range(2312, 3312)
        percent = 120 / 3312
    elif args.dataset == 'pubmed':
        idx_train, idx_val, idx_test = range(60), range(60, 600), range(18717, 19717)
        percent = 60 / 19717
    else:
        raise ValueError("No public split for dataset %r (expected cora, citeseer or pubmed)" % args.dataset)
    # Always use index tensors (on CPU too) so every consumer gets the same type.
    idx_train, idx_val, idx_test = torch.LongTensor(idx_train), torch.LongTensor(idx_val), torch.LongTensor(idx_test)
    if args.cuda:
        idx_train, idx_val, idx_test = idx_train.cuda(), idx_val.cuda(), idx_test.cuda()
    labels_train, labels_val, labels_test = labels[idx_train], labels[idx_val], labels[idx_test]

# Number of hidden layers: taken from --layers, or estimated from the graph when --layers is 0.
if args.layers == 0:
    if args.public != 1:
        # The random splits are drawn later, inside the experiment loop, so no idx_train exists yet.
        raise ValueError("--layers 0 (automatic layer count) requires --public 1")
    percent = args.percent
    layer_num = layer_numbers(idx_train, adj, percent, args.layers_factor)
else:
    layer_num = args.layers

# Activation function selected by name; other names are looked up in
# torch.nn.functional (getattr instead of eval on the command-line string).
if args.activation == 'identity':
    def activation(X):
        return X
elif args.activation == 'tanh':
    activation = torch.tanh
else:
    activation = getattr(F, args.activation)

# Build the model chosen by --network. The classes come from `from models import *`;
# look the name up instead of eval()-ing the command-line string.
NETWORK_NAMES = ('snowball', 'linear_snowball', 'linear_tanh_snowball', 'truncated_krylov')
if args.network not in NETWORK_NAMES:
    raise ValueError("Unknown --network %r; expected one of: %s" % (args.network, ", ".join(NETWORK_NAMES)))
Network = globals()[args.network]
n_classes = labels.max().item() + 1
if args.network == 'snowball':
    model = Network(nfeat=features.shape[1], nlayers=layer_num, nhid=args.hidden, nclass=n_classes,
                    dropout=args.dropout, activation=activation,
                    amp=args.amp)
elif args.network in ('linear_snowball', 'linear_tanh_snowball'):
    # These variants are constructed without an activation argument.
    model = Network(nfeat=features.shape[1], nlayers=layer_num, nhid=args.hidden, nclass=n_classes,
                    dropout=args.dropout,
                    amp=args.amp)
else:  # 'truncated_krylov'
    # Dense powers adj^k applied to the identity, for k = 0 .. n_blocks - 1.
    ADJ_EXPONENTIALS, accumulated_exponential = [], torch.eye(adj.size()[0])
    if args.cuda:
        accumulated_exponential = accumulated_exponential.cuda()
    for _ in range(args.n_blocks):
        ADJ_EXPONENTIALS.append(accumulated_exponential)
        accumulated_exponential = torch.spmm(adj, accumulated_exponential)
    del accumulated_exponential
    if not args.amp:
        # Without amp the powers are passed as sparse COO tensors
        # (replaces the deprecated torch.sparse.FloatTensor constructor).
        ADJ_EXPONENTIALS = [power.to_sparse() for power in ADJ_EXPONENTIALS]
    model = Network(nfeat=features.shape[1], nlayers=layer_num, nhid=args.hidden, nclass=n_classes,
                    dropout=args.dropout, activation=activation, n_blocks=args.n_blocks, ADJ_EXPONENTIALS=ADJ_EXPONENTIALS,
                    amp=args.amp)

# set optimizer
OPTIMIZERS = {'Adam': optim.Adam, 'RMSprop': optim.RMSprop}
if args.optimizer not in OPTIMIZERS:
    raise ValueError("Unknown --optimizer %r; expected one of: %s" % (args.optimizer, ", ".join(OPTIMIZERS)))
optimizer = OPTIMIZERS[args.optimizer](model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

# send to GPU; the model is moved before amp.initialize wraps model and optimizer
if args.cuda:
    model.cuda()
if args.amp:
    model, optimizer = amp.initialize(model, optimizer, opt_level="O%d" % args.amp)
    # amp runs use a dense adjacency matrix.
    adj = adj.to_dense()

# experiment
def _random_split(labels_np, n_nodes, percent=None):
    """Draw a random train/validation/test split using the `random` module's RNG.

    For every class (in sorted label order), 20 nodes (when ``percent`` is None)
    or ``int(class_size * percent) + 1`` nodes are sampled into the training
    set. The remaining nodes are shuffled. The first 500 become the validation
    set and the next 1000 the test set.

    Returns:
        (idx_train, idx_val, idx_test) as integer numpy arrays.
    """
    train_parts = []
    for c in np.unique(labels_np):
        members = list(np.where(labels_np == c)[0].astype(int))
        n_train = 20 if percent is None else int(len(members) * percent) + 1
        train_parts.append(random.sample(members, n_train))
    # Integer dtype so the indices can be used directly for tensor indexing.
    idx_train = np.hstack(train_parts).astype(np.int64)
    others = np.delete(np.arange(n_nodes).astype(int), idx_train)
    random.shuffle(others)
    return idx_train, others[0:500], others[500:1500]


result3, total_running_time = [], 0
labels_np = labels.cpu().numpy()
for runtime in range(args.runtimes):
    model.reset_parameters()
    # Drop optimizer statistics (e.g. RMSprop/Adam running averages) left from the
    # previous run, so every run starts like the first one.
    optimizer.state.clear()
    if args.debug:
        best_val, best_test = 0, 0
    if args.public != 1:
        split = _random_split(labels_np, adj.shape[0], None if args.public == 2 else args.percent)
        idx_train, idx_val, idx_test = (torch.LongTensor(idx) for idx in split)
        if args.cuda:
            idx_train, idx_val, idx_test = idx_train.cuda(), idx_val.cuda(), idx_test.cuda()
        labels_train, labels_val, labels_test = labels[idx_train], labels[idx_val], labels[idx_test]

    t_total = time.time()
    early_stopping, consecutive = 0, 0
    epoch = 0
    # Per-run epoch budget. It is extended when training accuracy first reaches 100%;
    # kept local so the extension does not leak into later runs.
    max_epochs = args.epochs
    best_train, test_best_val, best_validation = 0, 0, 0
    while epoch <= max_epochs:
        if args.debug:
            print("R%02d" % runtime, end = " ")
        acc_train, train_loss, acc_val = train(epoch)
        if args.validation or args.debug:
            if acc_val > best_validation:
                best_validation = acc_val
                test_best_val = test().cpu().numpy()
                acc_test = test_best_val
                if args.debug:
                    print('test_best_val: %.2f%%' % (100 * float(test_best_val)), end = "")
            elif args.debug:
                acc_test = test().cpu().numpy()
        # Early stopping on training accuracy: count epochs without improvement.
        if acc_train >= best_train:
            best_train = acc_train
            early_stopping = 0
        else:
            early_stopping += 1
        if early_stopping >= args.early_stopping:
            break
        if acc_train == 100:
            if consecutive == 0:
                max_epochs = epoch + args.epochs_after_peak
            # NOTE(review): never reset when accuracy drops below 100%, so this counts
            # all 100%-accuracy epochs rather than strictly consecutive ones.
            consecutive += 1
        if consecutive >= args.consecutive:
            break
        epoch += 1
        if args.debug:
            best_val, best_test = max(best_val, acc_val), max(best_test, acc_test)
            print("", end = "\n")
    if args.debug:
        print("", end = "\n")
        print("R%d finished with %.2fs elapsed, best_val %.2f%%, best_test %.2f%%, test_best_val %.2f%%" % (runtime, time.time() - t_total,  best_val, 100 * best_test, 100 * test_best_val), end = "\n")
    total_running_time = total_running_time + time.time() - t_total
    if args.validation == 1:
        result3.append(test_best_val)
    else:
        result3.append(test().cpu().numpy())
        if args.debug:
            print("", end = "\n")
if args.debug:
    print("mean result: ", np.mean(result3), "total running time: ", total_running_time, "All results: ", result3)
else:
    print(np.mean(result3))
    with open("%d.txt" % args.identifier, 'w') as script:
        script.write("%e" % np.mean(result3))




# t-SNE visualisations of the raw features, the last hidden layer's output and the
# classifier output; plot_embedding saves each figure as a PDF.
tsne = manifold.TSNE(n_components=2, init='pca', random_state=0)

X, y = features.cpu().numpy(), labels.cpu().numpy()
X_tsne = tsne.fit_transform(X)

plot_embedding(X_tsne, y,
            't-SNE embedding of features')
plt.show()

# The model returns a single output tensor (it is indexed and passed to F.nll_loss
# during training), so intermediate outputs cannot be read off it as attributes.
# Capture them with forward hooks instead.
# NOTE(review): assumes the model exposes an indexable `hidden` container and a
# `classifier` module (the names used in the earlier inline comments); confirm
# against models.py.
captured_outputs = {}


def _store_output(key):
    """Return a forward hook that saves the module's output as a NumPy array under *key*."""
    def hook(module, inputs, output):
        captured_outputs[key] = output.detach().cpu().numpy()
    return hook


hook_handles = []
try:
    hook_handles.append(model.hidden[-1].register_forward_hook(_store_output('hidden')))
    hook_handles.append(model.classifier.register_forward_hook(_store_output('classifier')))
    model.eval()
    with torch.no_grad():
        model(features, adj)
finally:
    for handle in hook_handles:
        handle.remove()
extracted_features = captured_outputs['hidden']
classifier = captured_outputs['classifier']

features_tsne = tsne.fit_transform(extracted_features)
plot_embedding(features_tsne, y,
            '%s_tsne_features_output' % args.network)
plt.show()

classifier_tsne = tsne.fit_transform(classifier)
plot_embedding(classifier_tsne, y,
            '%s_tsne_classifier_output' % args.network)
plt.show()

#    if args.network == 'snowball':
#       
#        
#    elif args.network == 'linear_snowball':
#        pass
#    elif args.network == 'linear_tanh_snowball':
#        pass
#    else:
#        pass
#        ADJ_EXPONENTIALS
        
#    
#    support = torch.mm(layer_input, W)
#    output = torch.spmm(adj, support)
#    x = activation(output)
#    
#    u1, s1, vh1 = np.linalg.svd(x.detach().numpy(),full_matrices=False)
#    data=np.transpose(u1[:,0:2])
#    plot_embedding(x.detach().numpy(),  labels.numpy() ,                                
#            't-SNE embedding of %s%g Hidden Layer %d for %g%s training data)'%(args.networks,layer_num,i,args.percent*100,str('%')))                           
#
#    plt.show() 
#
#    fig = plt.figure()
#    ax = fig.add_subplot(111)
#    ax.scatter(data[0,:], data[1,:],s=5,c=color)
#    ax.set_xlim([-0.3, 0.3])
#    ax.set_ylim([-0.3, 0.3])
#    #plt.gca().set_aspect('equal', adjustable='datalim')
#    fig.suptitle('%s%g Hidden Layer %d for %g%s training data'%(args.networks,layer_num,i,args.percent*100,str('%')), fontsize=16)
#
#    ax.set_xlabel('x')
#    ax.set_ylabel('y')
#    #ax.set_yticks(np.arange(-0.2,0.3,0.1))
#    
#    plt.savefig('%s%g_HiddenLayer%d_%g%s_training_data.eps'%(args.networks,layer_num,i,args.percent*100,str('%')), format='eps', dpi=1200)
#    plt.show()
#    if args.networks == 'GCN':
#        layer_input=x
#    elif args.networks == 'Simplified_GCN':
#        layer_input=x
#        layer_inputs=torch.cat([x, layer_inputs],1)
#    else:
#        layer_input=torch.cat([x, layer_input],1)
#
#    
#if args.networks == 'Simplified_GCN':
#    layer_input = layer_inputs
#
#model.out.weight.shape
#W=model.out.weight
#support = torch.mm(layer_input, W)
#output = support #output = torch.spmm(adj, support)
#x = output
#u1, s1, vh1 = np.linalg.svd(x.detach().numpy(),full_matrices=False)
#data=np.transpose(u1[:,0:2])
#fig = plt.figure()
#ax = fig.add_subplot(111)
#
#ax.scatter(data[0,:], data[1,:],s=5,c=color)
#ax.set_xlim([-0.3, 0.3])
#ax.set_ylim([-0.3, 0.3])
#
##plt.gca().set_aspect('equal', adjustable='datalim')
#ax.set_xlabel('x')
#ax.set_ylabel('y')
#
#fig.suptitle('%s%g Output Layer for %g%s training data'%(args.networks,layer_num,args.percent*100,str('%')), fontsize=16)
#
#plt.savefig('%s%g_Output_Layer_%g%straining_data.eps'%(args.networks,layer_num,args.percent*100,str('%')), format='eps', dpi=1200)
#plt.show()
    





#plot_embedding(X_tsne,  y ,                                
#            't-SNE embedding of the digits')                           
#
#plt.show() 



