import os
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import argparse 
import sys
import numpy as np
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152, vgg16
from my_models import model_dict as model_dict_vgg
from torchvision.utils import save_image
from scipy.spatial.distance import cosine
from torch.autograd import Variable
# Seed torch's global RNG (used e.g. by the shuffling val_loader) so runs are reproducible.
torch.manual_seed(0)


parser = argparse.ArgumentParser(
    description='Evaluate how adversarial examples crafted on a teacher transfer to student models')
# --data is required: the validation images are read from <data>/val.
parser.add_argument('--data', metavar='DIR', required=True,
                    help='path to the ImageNet root directory (its val/ subfolder is used)')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18')
parser.add_argument('--arch_teacher', type=str, default='resnet50')
parser.add_argument('--arch_student', type=str, default='resnet18')
parser.add_argument('--noise_target', type=str, default=None,
                    help='if set, craft adversarial examples on model_noise instead of the teacher')
parser.add_argument('--distillation', type=str, default='ind',
                    help='label printed next to the student results')
parser.add_argument('--student_path', type=str, default=None)
parser.add_argument('--n_seed', type=int, default=2)
parser.add_argument('--batch_size', default=1, type=int, metavar='N',
                    help='mini-batch size of the validation loader (default: 1)')
parser.add_argument('-j', '--workers', default=10, type=int, metavar='N',
                    help='number of data loading workers (default: 10)')
args = parser.parse_args()


# Validation split with the usual evaluation preprocessing: resize the short
# side to 256, center-crop 224x224, convert to a tensor, then normalize with
# the ImageNet per-channel mean / std.
valdir = os.path.join(args.data, 'val')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
eval_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])
val_dataset = datasets.ImageFolder(valdir, eval_transform)
# NOTE(review): shuffle=True presumably so the first batches consumed by the
# evaluation loop form a random (seeded) subset of the validation set.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.workers,
    pin_memory=True,
)


# Number of independently seeded student models, taken from --n_seed.
# The evaluation indexes student_models[0] and student_models[1], so at least
# two are needed.
n_seeds = args.n_seed
if n_seeds < 2:
    parser.error('--n_seed must be at least 2 (two student seeds are evaluated)')

# NOTE(review): load_teacher_model / load_student_model are placeholders that
# are not defined in this file; they presumably return torch.nn.Module objects.
model_t = load_teacher_model()
model_t.eval()
model_t.cuda()
student_models = []
for seed in range(n_seeds):
    model_s = load_student_model(seed)  # student trained with this seed index
    model_s.eval()
    model_s.cuda()
    student_models.append(model_s)


# Number of clean validation images to turn into adversarial examples.
n_test = 5000
# Number of loader batches needed to cover n_test images. Ceiling integer
# division keeps this an int; plain `/` gives a float that an `==` comparison
# against the batch index never matches when n_test is not a multiple of the
# batch size.
n_iter = -(-n_test // args.batch_size)

# Hyper-parameters of the Iterative FGSM attack. NOTE(review): they act on the
# *normalized* tensors produced by val_loader, not on [0, 1] pixel values.
epsilon = 0.25  # L-inf bound on the total perturbation
alpha = 0.025   # size of each signed-gradient step
n_steps = 5     # number of attack iterations

# Running counters: fooled images for the teacher and the two student seeds,
# and the number of images the students are scored on.
n_fool_t = 0
n_fool_s1 = 0
n_fool_s2 = 0
tot = 0
# Cross-entropy criterion (mean reduction) for the attack objective.
ce_loss = torch.nn.CrossEntropyLoss().cuda()

# 'any': the attack uses the ground-truth label.
# 'target': the attack uses the teacher's least-likely class as its label.
misclassification_method = 'any'  # 'target' or 'any'
# Boolean mask selecting every image of a full-size batch.
all_true = torch.tensor([True]*args.batch_size)
# If True, student fooling rates are only measured on images whose teacher
# prediction was changed by the attack.
only_t_fooled = True


# Converts a clean image batch into its adversarial form, given the attack label (the ground-truth class, or the target class in targeted mode).
def get_adv(image_tensor, y_true, model, eps=None, step_size=None, steps=None, targeted=None):
    """Craft an adversarial version of ``image_tensor`` with Iterative FGSM.

    Each step moves the current image by ``step_size * sign(grad)`` of the
    cross-entropy loss. It then projects the accumulated perturbation back into
    the L-inf ball of radius ``eps`` around the clean image.

    Args:
        image_tensor: batch of clean images, fed directly to ``model``.
        y_true: labels for the cross-entropy loss. These are the ground truth
            for an untargeted attack, or the desired class for a targeted one.
        model: network to attack. Gradients are taken only with respect to
            the image, so nothing accumulates in the parameters' ``.grad``.
        eps, step_size, steps: attack hyper-parameters. ``None`` falls back to
            the module-level ``epsilon``, ``alpha`` and ``n_steps``.
        targeted: if True, descend the loss, pulling predictions towards
            ``y_true``. If False, ascend it, pushing predictions away from
            ``y_true``. ``None`` means
            ``misclassification_method == 'target'``.

    Returns:
        The adversarial batch, detached from the autograd graph.
        ``image_tensor`` itself is left unchanged.
    """
    if eps is None:
        eps = epsilon
    if step_size is None:
        step_size = alpha
    if steps is None:
        steps = n_steps
    if targeted is None:
        targeted = misclassification_method == 'target'
    # Untargeted: maximize the loss of the true label.
    # Targeted: minimize the loss of the target label.
    direction = -1.0 if targeted else 1.0

    clean = image_tensor.detach()
    adv = clean.clone()
    for _ in range(steps):
        adv.requires_grad_(True)
        loss = torch.nn.functional.cross_entropy(model(adv), y_true)
        # autograd.grad computes only d(loss)/d(adv); unlike backward() it does
        # not accumulate gradients into the model parameters.
        grad, = torch.autograd.grad(loss, adv)
        with torch.no_grad():
            stepped = adv + direction * step_size * torch.sign(grad)
            # Project the *total* perturbation back into the eps-ball.
            adv = clean + torch.clamp(stepped - clean, -eps, eps)
    return adv.detach()



# Counts how often the predicted class differs between the clean and the adversarial version of each image.
def get_disagreement(clean, adv, model, mask_ind=None, return_fool=False):
    """Count images whose predicted class changes under the adversarial input.

    Args:
        clean: batch of clean images.
        adv: adversarial counterparts of ``clean``, in the same batch order.
        model: classifier whose output holds per-class scores along dim 1.
        mask_ind: boolean mask over the batch selecting which images are
            counted. ``None`` counts every image.
        return_fool: also return the mask of fooled images.

    Returns:
        ``fool``: the number of selected images whose argmax prediction
        differs between ``clean`` and ``adv``. With ``return_fool=True``,
        returns ``(fool, fooled_mask)`` instead. ``fooled_mask`` is a
        full-batch boolean mask, True only for selected images whose
        prediction changed, so it can be passed back in as ``mask_ind``.
    """
    class_clean = torch.argmax(model(clean), dim=1)
    class_adv = torch.argmax(model(adv), dim=1)
    fooled = class_clean != class_adv
    if mask_ind is not None:
        # The mask may live on another device (e.g. CPU mask, CUDA predictions).
        fooled = fooled & mask_ind.to(fooled.device)
    fool = int(fooled.sum())
    if return_fool:
        return fool, fooled
    return fool

# Evaluation loop: craft adversarial images on the teacher (or on model_noise
# when --noise_target is set) and count how often they change the predicted
# class of the teacher and of the two student seeds.
n_seen = 0  # every evaluated image; denominator of the teacher's rate
for batch_idx, (image, target) in enumerate(val_loader):
    # `>=` rather than `==`: n_iter may be a float (n_test / batch_size), and
    # then `==` would never stop the loop.
    if batch_idx >= n_iter:
        break
    bs = image.size(0)
    image = image.cuda(non_blocking=True)
    target = target.cuda(non_blocking=True)
    # Build the "every image" mask from the real batch size: the last batch can
    # be smaller than args.batch_size, which the fixed-size all_true assumes.
    batch_all_true = torch.ones(bs, dtype=torch.bool, device=image.device)
    if misclassification_method == 'target':
        # The teacher's least-likely class becomes the attack label. This
        # forward pass needs no autograd graph.
        with torch.no_grad():
            target = torch.argmin(model_t(image), dim=1)
    # NOTE(review): model_noise is not defined in this file, so passing
    # --noise_target raises NameError unless it is provided elsewhere.
    attack_model = model_t if args.noise_target is None else model_noise
    image_adv = get_adv(image, target, attack_model)
    with torch.no_grad():
        if only_t_fooled:
            # Students are only scored on the images that fooled the teacher.
            fool_t, mask_ind = get_disagreement(image, image_adv, model_t, mask_ind=batch_all_true, return_fool=True)
        else:
            fool_t = get_disagreement(image, image_adv, model_t, mask_ind=batch_all_true)
            mask_ind = batch_all_true
        n_fool_t += fool_t
        fool_s1 = get_disagreement(image, image_adv, student_models[0], mask_ind)
        n_fool_s1 += fool_s1
        fool_s2 = get_disagreement(image, image_adv, student_models[1], mask_ind)
        n_fool_s2 += fool_s2
        tot += int(mask_ind.sum())
        n_seen += bs


def _percent(count, total):
    """Return count / total as a percentage, or NaN when total is 0."""
    return float(count) / total * 100 if total else float('nan')


# The teacher's rate is over every evaluated image. Dividing by `tot` would
# always give 100% when only_t_fooled is set, because tot then counts only
# teacher-fooled images. The students' rates are over `tot`.
fool_rate_t = _percent(n_fool_t, n_seen)
fool_rate_s1 = _percent(n_fool_s1, tot)
fool_rate_s2 = _percent(n_fool_s2, tot)
print("Teacher: prediction changed in %f percentage of images" %(fool_rate_t))
print("Student seed 1 (%s): prediction changed in %f percentage of images" %(args.distillation, fool_rate_s1))
print("Student seed 2 (%s): prediction changed in %f percentage of images" %(args.distillation, fool_rate_s2))
print("Mean: %f" %((fool_rate_s1 + fool_rate_s2)/2))

