import argparse
import time
import numpy as np
import torch
import torch.nn.functional as F
import dgl
import random
import os
from correct_DFA import correct
from load_dataset import load
from gcn import GCN
import time

#from gcn_mp import GCN
#from gcn_spmv import GCN
# os.environ['CUDA_VISIBLE_DEVICES']="0"
def seed_everything(seed):
    """Seed every random number generator this script relies on.

    Exports ``PYTHONHASHSEED`` (this only affects interpreters started
    afterwards). Seeds Python's ``random``, NumPy's global generator and
    PyTorch: CPU, the current CUDA device and all CUDA devices. Finally it
    asks cuDNN for deterministic kernels.

    Args:
        seed: integer seed shared by all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True

def evaluate(model, features, labels, mask):
    """Return the classification accuracy of ``model`` on the masked nodes.

    The model is put in eval mode and left there. It is run without
    gradient tracking on the full ``features`` input, and only the rows
    selected by ``mask`` are scored.

    Args:
        model: module mapping ``features`` to per-node class scores.
        features: input fed to the model.
        labels: integer class label per node.
        mask: selector applied to the node dimension of both the scores and
            ``labels`` (a boolean mask in this script).

    Returns:
        A Python float: the fraction of selected nodes whose arg-max score
        equals their label. Returns 0.0 when ``mask`` selects no nodes;
        dividing by zero selected nodes would otherwise raise
        ZeroDivisionError.
    """
    model.eval()
    with torch.no_grad():
        scores = model(features)[mask]
        target = labels[mask]
        n_total = len(target)
        if n_total == 0:
            return 0.0
        _, predicted = torch.max(scores, dim=1)
        # Named n_correct (not ``correct``) so it does not shadow the
        # correct() function imported from correct_DFA.
        n_correct = torch.sum(predicted == target).item()
        return n_correct * 1.0 / n_total



def _propagation_matrices(g, num_layers, device):
    """Build the per-layer graph propagation matrices used by the DFA update.

    The adjacency ``A = g.adj().to_dense()`` is symmetrically normalised as
    ``a_hat = (D^-1/2 A D^-1/2)^T`` with ``D = A.sum(1)``. Zero-degree rows
    get a zero scale instead of ``inf``.

    Returns:
        A list ``mats`` of length ``num_layers``. For hidden layer
        ``l < num_layers - 1``, ``mats[l] = a_hat ** (num_layers - 1 - l)``.
        The last entry is ``a_hat`` itself, a placeholder the output-layer
        update in main() does not read.
    """
    adj = g.adj().to_dense().to(device)
    d_inv_sqrt = torch.pow(adj.sum(1), -0.5)
    # Zero-degree nodes would give inf here and poison every product with
    # NaN; give them a zero scale, as is done for ``norm`` in main().
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0
    # Row/column scaling by broadcasting equals diag(d) @ A @ diag(d) without
    # materialising a dense diagonal matrix or doing two O(n^3) matmuls.
    a_hat = (d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]).T

    # Index 0 holds the output-layer placeholder; index k >= 1 holds a_hat**k.
    mats = [a_hat]
    power = torch.eye(a_hat.shape[0], dtype=a_hat.dtype, device=a_hat.device)
    for _ in range(num_layers - 1):
        power = torch.mm(power, a_hat)
        mats.append(power)
    # Reverse so that mats[l] == a_hat**(num_layers - 1 - l) for hidden layer l.
    mats.reverse()
    return mats


def main(args):
    """Train a GCN with a Direct-Feedback-Alignment (DFA) style update rule.

    Autograd is not used. All model parameters are frozen
    (``requires_grad=False``) and updated in place each epoch:

    1. Forward the model on all nodes and squash the output with a sigmoid.
    2. Build an output error ``E``. On training nodes it is
       ``sigmoid(out) - one_hot(label)``; the other rows come from
       ``correct()`` (correct_DFA).
    3. Select the nodes whose corrected prediction ``logits - E`` has exactly
       one class above ``args.threshold``, plus every training node.
    4. Form per-layer updates on the selected nodes. The output layer uses
       ``E`` directly against its aggregated input. Hidden layers receive
       ``E`` through a fixed random feedback matrix (with dropout) and a
       power of the normalised adjacency.
    5. Apply a hand-written Adam step (``args.alpha``, ``args.beta1``,
       ``args.beta2``, ``args.epsilon``).

    Per-epoch loss and accuracies are printed. After training, the final test
    accuracy and the test accuracy at the best-validation epoch are printed.

    Args:
        args: parsed command-line namespace; ``args.gpu < 0`` selects the CPU.
    """
    seed_everything(args.seed)

    # load and preprocess dataset
    g, features, labels, train_mask, val_mask, test_mask, n_classes, n_edges, in_feats = load(args, attack=None)

    cuda = args.gpu >= 0
    # One explicit device for everything. Mixing ``g.to(args.gpu)`` with bare
    # ``.cuda()`` (which targets the *current* CUDA device) breaks whenever
    # --gpu is not the current device.
    device = torch.device('cuda', args.gpu) if cuda else torch.device('cpu')
    if cuda:
        g = g.int().to(device)

    print("""----Data statistics------'
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d""" %
          (n_edges, n_classes,
              train_mask.int().sum().item(),
              val_mask.int().sum().item(),
              test_mask.int().sum().item()))

    # normalization: per-node in-degree^-1/2 stored on the graph
    # (presumably consumed by the GCN layers -- see gcn.py).
    degs = g.in_degrees().float()
    norm = torch.pow(degs, -0.5)
    norm[torch.isinf(norm)] = 0
    g.ndata['norm'] = norm.to(device).unsqueeze(1)

    # create GCN model
    model = GCN(g,
                in_feats,
                args.n_hidden,
                n_classes,
                args.n_layers,
                F.relu,
                args.dropout)

    # Weights are updated manually below, never through autograd.
    for param in model.parameters():
        param.requires_grad = False

    print('total params:', sum(param.numel() for param in model.parameters()))

    model.to(device)
    features, labels = features.to(device), labels.to(device)
    train_mask, val_mask, test_mask = train_mask.to(device), val_mask.to(device), test_mask.to(device)
    loss_fcn = torch.nn.BCELoss()  # only used to report the training loss

    num_layers = len(model.layers)
    A_lis = _propagation_matrices(g, num_layers, device)

    B_lis, m_lis, v_lis = [], [], []
    for layer in model.layers:
        # Fixed random feedback matrix from class space to this layer's output
        # width. Drawn on the CPU and then moved (as before), so its values come
        # from the seeded CPU generator on both CPU and GPU runs.
        B = torch.empty(n_classes, layer._out_feats)
        torch.nn.init.kaiming_uniform_(B)
        B_lis.append(B.to(device))
        # Adam first/second moment estimates, same shape/dtype/device as weight.
        m_lis.append(torch.zeros_like(layer.weight))
        v_lis.append(torch.zeros_like(layer.weight))

    alpha, beta1, beta2, epsilon = args.alpha, args.beta1, args.beta2, args.epsilon

    # nn.Dropout starts in training mode, so B is randomly masked on each use.
    drop_B = torch.nn.Dropout(args.dropout)

    dur = []
    best_val, best_test = 0, 0
    for epoch in range(args.n_epochs):
        model.train()
        if cuda:
            torch.cuda.synchronize(device)
        t0 = time.time()

        # forward
        logits = torch.sigmoid(model(features))

        # Explicit num_classes: otherwise F.one_hot sizes the encoding by the
        # largest *training* label and can mismatch the n_classes-wide logits.
        GT_train = F.one_hot(labels[train_mask], num_classes=n_classes).float()
        E1 = logits[train_mask] - GT_train

        # Rows for non-training nodes come from correct(); training rows are
        # overwritten with the exact error above.
        _, E = correct(g, logits, labels[train_mask], train_mask)
        E[train_mask] = E1

        corr_l = logits - E

        # Nodes with exactly one class above the threshold, plus all training
        # nodes. logical_or keeps a boolean mask even if train_mask is not bool
        # (``+`` would then produce an integer index tensor).
        mask = (corr_l > args.threshold).int().sum(1) == 1
        mask = torch.logical_or(mask, train_mask)
        print(mask.sum() / mask.shape[0])

        delta_W = []
        for layer_num in range(num_layers):
            if layer_num == num_layers - 1:
                # Output layer: error applied directly to its aggregated input.
                delta_W.append(torch.mm(model.aggre[layer_num][mask].T, E[mask]))
            else:
                # Hidden layer: project the error through the (dropped-out)
                # feedback matrix, then propagate it over the graph.
                delta_X = torch.mm(A_lis[layer_num][:, mask],
                                   torch.mm(E[mask], drop_B(B_lis[layer_num])))
                delta_W.append(torch.mm(model.aggre[layer_num].T, delta_X))

        loss = loss_fcn(logits[train_mask], GT_train)

        # Bias-corrected Adam step, applied in place to the frozen weights.
        for num_id, grad in enumerate(delta_W):
            m_lis[num_id] = beta1 * m_lis[num_id] + (1 - beta1) * grad
            v_lis[num_id] = beta2 * v_lis[num_id] + (1 - beta2) * grad * grad
            m_hat = m_lis[num_id] / (1 - beta1 ** (epoch + 1))
            v_hat = v_lis[num_id] / (1 - beta2 ** (epoch + 1))
            step_size = alpha / (torch.sqrt(v_hat) + epsilon)
            model.layers[num_id].weight -= step_size * m_hat

        # CUDA kernels run asynchronously; wait for them before stopping the clock.
        if cuda:
            torch.cuda.synchronize(device)
        if epoch >= 3:  # skip warm-up epochs when timing
            dur.append(time.time() - t0)

        train_acc = evaluate(model, features, labels, train_mask)
        acc = evaluate(model, features, labels, val_mask)
        if acc > best_val:
            best_val = acc
            best_test = evaluate(model, features, labels, test_mask)

        # No timings exist during warm-up; report nan without NumPy's
        # "Mean of empty slice" warning.
        mean_dur = np.mean(dur) if dur else float('nan')
        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | train_Accuracy {:.4f} | val_Accuracy {:.4f} | "
              "ETputs(KTEPS) {:.2f}". format(epoch, mean_dur, loss.item(), train_acc,
                                             acc, n_edges / mean_dur / 1000))

    print()
    acc = evaluate(model, features, labels, test_mask)
    print("Test accuracy {:.2%}".format(acc))
    print(best_test)




# Command-line interface. The parser is built at import time; parsing and
# training only run when this file is executed as a script.
parser = argparse.ArgumentParser(description='DFA_GNN')
parser.add_argument("--dataset", type=str, default="cora",
                    help="Dataset name.")
parser.add_argument("--dropout", type=float, default=0.0,
                    help="dropout probability")
parser.add_argument('--seed', type=int, default=0,
                    help='random seed')
parser.add_argument("--gpu", type=int, default=0,
                    help="gpu")
parser.add_argument("--attack_ratio", type=float, default=0,
                    help="attack_ratio")
parser.add_argument("--n-epochs", type=int, default=1000,
                    help="number of training epochs")
parser.add_argument("--n-hidden", type=int, default=64,
                    help="number of hidden gcn units")
parser.add_argument("--n-layers", type=int, default=5,
                    help="number of hidden gcn layers")
parser.add_argument("--weight-decay", type=float, default=5e-4,
                    help="Weight for L2 loss")
parser.add_argument("--alpha", type=float, default=0.001,
                    help="alpha")
parser.add_argument("--beta1", type=float, default=0.9,
                    help="beta1")
parser.add_argument("--beta2", type=float, default=0.999,
                    help="beta2")
parser.add_argument("--epsilon", type=float, default=1.e-8,
                    help="epsilon")
parser.add_argument("--threshold", type=float, default=0.5,
                    help="threshold")
# Self-loops are on by default, so the original --self-loop flag is kept only
# for backward compatibility; --no-self-loop is the way to turn them off.
parser.add_argument("--self-loop", dest="self_loop", action='store_true',
                    help="add self-loops to the graph (default: enabled)")
parser.add_argument("--no-self-loop", dest="self_loop", action='store_false',
                    help="do not add self-loops to the graph")
parser.set_defaults(self_loop=True)

if __name__ == "__main__":
    args = parser.parse_args()
    print(args)
    main(args)


