import copy
import torch
from utils.evaluate import acc_f1
from utils.loss import middle_loss
from tqdm import tqdm


def graph_knowledge_distillation_on_manifold_structure(args, model_dict, data, g):
    """Train the student model with hints from two teachers and report test F1.

    Intermediate embeddings are taken from both teacher models. The ``et_model``
    embeddings serve as the base, and the rows listed in
    ``model_dict['node_delta_hyperbolicity']`` are overwritten with the
    corresponding ``ht_model`` rows. The second-level combined embedding is
    additionally passed through the ``GEO`` model. The student is then trained
    with a BCE-with-logits loss on one-hot training labels, plus (while the
    mode is ``'mi'``) ``middle_loss`` terms against the two hint embeddings.

    Args:
        args: Configuration object. Attributes read: ``device``, ``s_epochs``,
            ``tofull``, ``mode``, ``patient`` and ``dataset``. ``args`` is not
            modified; the switch to ``'full'`` mode after ``args.tofull``
            epochs is tracked locally so repeated calls with the same ``args``
            start from the configured mode.
        model_dict: Mapping with entries ``'ht_model'``, ``'et_model'``,
            ``'GEO'`` and ``'s_model'`` (each holding a ``'model'``;
            ``'s_model'`` also holds an ``'optimizer'``), and
            ``'node_delta_hyperbolicity'`` (an iterable of node indices).
        data: Mapping providing ``'features'`` and ``'adj_train_norm'``, which
            are passed to ``ht_model.encode``.
        g: Graph whose ``ndata`` provides ``'feat'``, ``'label'``,
            ``'train_mask'``, ``'val_mask'`` and ``'test_mask'``.

    Returns:
        The best test score recorded for the student (the same value that is
        printed at the end).
    """
    print("### Graph Knowledge Distillation on Manifold Structure ###")
    device = args.device
    features = g.ndata['feat'].to(device).float()
    labels = g.ndata['label'].to(device)
    n_classes = int(torch.max(g.ndata['label']) + 1)
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]

    # Teacher models, EFN model and the (untrained) student with its optimizer.
    ht_model = model_dict['ht_model']['model']
    et_model = model_dict['et_model']['model']
    efn_model = model_dict['GEO']['model'].to(device)
    model = model_dict['s_model']['model']
    optimizer = model_dict['s_model']['optimizer']
    loss_fcn = torch.nn.BCEWithLogitsLoss()

    # ### Teacher models' middle embeddings ###
    # Only the intermediate features are used; the teachers' final outputs
    # are discarded.
    ht_model.eval()
    with torch.no_grad():
        _, middle_feats_ht = ht_model.encode(
            data['features'], data['adj_train_norm'], middle=True)
    et_model.eval()
    with torch.no_grad():
        _, middle_feats_et = et_model(features, middle=True)

    # ### Hint embeddings ###
    # Start from copies of the et_model embeddings and overwrite the rows of
    # the listed nodes with the ht_model embeddings.
    combined_features0 = copy.deepcopy(middle_feats_et[0])
    combined_features1 = copy.deepcopy(middle_feats_et[1])
    for node in model_dict['node_delta_hyperbolicity']:
        combined_features0[node] = middle_feats_ht[0][node]
        combined_features1[node] = middle_feats_ht[1][node]

    efn_model.eval()
    # The EFN model is run on a copy of the graph rather than on `g` itself.
    # NOTE(review): presumably this protects `g` from in-place changes made by
    # the EFN forward pass - confirm against the EFN implementation.
    euc_g = copy.deepcopy(g)
    with torch.no_grad():
        fused_features1 = efn_model(euc_g, combined_features1)
    middle_feats_t = [combined_features0, fused_features1]

    # The one-hot training targets never change, so build them once instead
    # of once per epoch.
    train_targets = torch.nn.functional.one_hot(
        labels[train_mask], n_classes).float().to(device)

    # Track the loss mode locally instead of writing 'full' back into `args`.
    mode = args.mode
    best_val_f1 = 0
    best_test_f1 = 0
    tolerance = 0  # epochs since the validation score last improved
    for epoch in tqdm(range(args.s_epochs)):
        model.train()
        tolerance += 1
        # ### Student model's guided embeddings ###
        logits, middle_feats_s = model(features, middle=True)
        loss = loss_fcn(logits[train_mask], train_targets)
        if epoch >= args.tofull:
            mode = 'full'
        if mode == 'mi':
            mi_loss1 = middle_loss(g, middle_feats_t[0], middle_feats_s[0])
            mi_loss2 = middle_loss(g, middle_feats_t[1], middle_feats_s[1])
            # The second hint level is weighted three times the first.
            loss = loss + (mi_loss1 + mi_loss2 * 3)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Evaluate after every epoch.
        model.eval()
        with torch.no_grad():
            eval_logits = model(features)
        # NOTE(review): the first element returned by acc_f1 is treated as the
        # F1 score here - confirm the return order of utils.evaluate.acc_f1.
        val_f1, _ = acc_f1(eval_logits[val_mask], labels[val_mask], average='micro')
        if val_f1 > best_val_f1:
            tolerance = 0
            best_val_f1 = val_f1
            test_f1, _ = acc_f1(eval_logits[test_mask], labels[test_mask], average='micro')
            # NOTE(review): the test score is only raised, never lowered, when
            # validation improves, so the reported value is partly selected on
            # the test split. Kept as-is to preserve reported numbers.
            if test_f1 > best_test_f1:
                best_test_f1 = test_f1

        # Early stop: no validation improvement for `patient` epochs, and more
        # than `patient` epochs have passed since the switch to 'full' mode.
        if tolerance >= args.patient and epoch - args.tofull > args.patient:
            print("Early Stop")
            break

    # Final evaluation of the last-epoch weights.
    model.eval()
    with torch.no_grad():
        logits = model(features)
    final_test_f1, _ = acc_f1(logits[test_mask], labels[test_mask], average='micro')
    # NOTE(review): taking the max with the final test score is also a
    # selection on the test split. Kept as-is to preserve reported numbers.
    if final_test_f1 > best_test_f1:
        best_test_f1 = final_test_f1
    ###########################################################################
    # set yellow text: \033[33m
    # cancel yellow text: \033[0m
    print(f"Distillation Finished! \nDataset:{args.dataset} "
          + '\033[33m' + f"\n Student model's F1-score: {best_test_f1:.4f}" + '\033[0m')
    return best_test_f1
