# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Run validation with a trained denoiser """
# basics
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '6'
import numpy as np

# objax
import objax

# attack
from attacks.FGSM import fast_gradient_method
from attacks.PGD import projected_gradient_descent

# torch
import torch
import torchvision.utils as vutils

# misc.
import pickle
from tqdm import tqdm

# custom
from utils.datasets import load_dataset
from utils.models import load_network, load_denoiser, load_network_parameters


"""
    Configurations
"""
# dataset and network
_seed      = 215            # seed for NumPy's global RNG (set before the evaluation runs)
_dataset   = 'cifar10'      # 'svhn' or 'cifar10'
_visualize = False          # when True, save image grids of clean / perturbed / denoised samples

## SVHN
if 'svhn' == _dataset:
    _num_class  = 10
    _network    = 'ConvNet'     # 'FFNet'
    _netfile    = 'models/{}/{}/best_model_base.npz'.format(_dataset, _network)

    _netden     = 'DenoiseSVHN'
    _netdenfile = 'models/{}/{}/best_model_denoiser.npz'.format(_dataset, _netden)

## CIFAR-10
elif 'cifar10' == _dataset:
    _num_class  = 10
    _network    = 'ConvNet'     # 'FFNet'
    _netfile    = 'models/{}/{}/best_model_base.npz'.format(_dataset, _network)

    _netden     = 'DenoiseCIFAR10'
    _netdenfile = 'models/{}/{}/best_model_denoiser.npz'.format(_dataset, _netden)

# any other dataset would leave the settings above undefined and the script would
# only fail much later with a NameError; fail fast with a clear message instead
else:
    raise ValueError('Error: unsupported dataset {}'.format(_dataset))


# attack configurations (fix to ell-2/ell-inf)
_attack     = 'FGSM'        # one of 'FGSM', 'PGD', 'Noise'
_num_iter   = 10            # number of PGD iterations
_eps_step   = 2/255.        # PGD step size (currently also passed to FGSM as its eps)
_epsilon    = 8/255.        # PGD perturbation budget and random-init range
_ell_norm   = np.inf        # 2
_noise_level= 0.25          # std. of the additive Gaussian noise used by the 'Noise' attack


def examine_robustness(x_valid, y_valid, batch_size, attack, denoiser, model, predictor, nclass=10):
    """Compute the adversarial accuracy (fooling rate) with and without denoising.

    For every mini-batch of ``x_valid``:
      1. craft a perturbed batch with ``attack``;
      2. run ``denoiser`` on the perturbed batch;
      3. record the arg-max predictions of ``predictor`` on the clean,
         perturbed, and denoised batches.

    Attack hyper-parameters come from the module-level configuration
    (``_epsilon``, ``_eps_step``, ``_num_iter``, ``_ell_norm``, ``_noise_level``).
    Sample grids are written to ``x_samples.<offset>.png`` when ``_visualize``
    is set.

    Args:
        x_valid:    validation inputs, sliced along axis 0 into batches.
        y_valid:    labels; flattened and compared element-wise with the predictions.
        batch_size: number of samples per batch.
        attack:     'FGSM', 'PGD', or 'Noise' (additive Gaussian noise).
        denoiser:   callable mapping a perturbed batch to a denoised batch.
        model:      network handed to the gradient-based attacks.
        predictor:  callable returning per-class scores; ``argmax(1)`` is the prediction.
        nclass:     number of classes, forwarded to the attacks.

    Returns:
        ``(clean_acc, adver_acc, denoised_acc)``, each a percentage in [0, 100].

    Raises:
        ValueError: if ``attack`` is not one of the supported attacks.
    """
    # validate once, before any work, and with an exception that survives
    # ``python -O`` (a bare ``assert`` would be stripped, leaving ``x_adver``
    # undefined or silently reused from the previous batch)
    if attack not in ('FGSM', 'PGD', 'Noise'):
        raise ValueError('Error: unsupported attack {}'.format(attack))

    clean_predictions   = []
    adver_predictions   = []
    denoise_predictions = []

    for it in tqdm(range(0, x_valid.shape[0], batch_size), desc='   [robust-examine]'):
        x_batch = x_valid[it:it + batch_size]

        # : compose the attacks
        if attack == 'FGSM':
            # NOTE(review): FGSM receives the PGD step size (_eps_step), not the
            # perturbation budget (_epsilon); confirm this is intended.
            x_adver = fast_gradient_method(
                model, x_batch, _eps_step, _ell_norm, clip_min=0.0, clip_max=1.0, nclass=nclass)
        elif attack == 'PGD':
            x_adver = projected_gradient_descent(
                model, x_batch, _epsilon, _eps_step, _num_iter, _ell_norm, clip_min=0., clip_max=1.,
                rand_init=True, rand_minmax=_epsilon, nclass=nclass)
        else:   # 'Noise' (validated above)
            # NOTE: unlike the attacks above, the noisy batch is not clipped to [0, 1]
            x_noise = np.random.normal(scale=_noise_level, size=x_batch.shape)
            x_adver = x_batch + x_noise

        # : make the predictions
        clean_predictions += np.asarray(predictor(x_batch).argmax(1)).tolist()
        adver_predictions += np.asarray(predictor(x_adver).argmax(1)).tolist()

        # : denoise the adver
        x_denoise = denoiser(x_adver)
        denoise_predictions += np.asarray(predictor(x_denoise).argmax(1)).tolist()

        # : visualize when required (rows: clean / perturbed / denoised)
        if _visualize:
            x_total = np.concatenate((
                    np.asarray(x_batch[:4]),
                    np.asarray(x_adver[:4]),
                    np.asarray(x_denoise[:4]),
                ), axis=0)
            vutils.save_image(torch.from_numpy(x_total), 'x_samples.{}.png'.format(it), nrow=4)

    labels = y_valid.flatten()

    def _accuracy(predictions):
        # percentage of predictions that match the labels
        return 100. * np.mean(np.array(predictions).flatten() == labels)

    # compute the final acc.
    return _accuracy(clean_predictions), _accuracy(adver_predictions), _accuracy(denoise_predictions)



"""
    Compute the accuracy on the adversarial examples
"""
# NOTE: everything below runs at import time (there is no `if __name__ == '__main__'` guard)

# set the random seed (for the reproducible experiments);
# this seeds NumPy's global RNG, which the 'Noise' attack draws from
np.random.seed(_seed)

# data (only the test split is used below)
(X_train, Y_train), (X_test, Y_test) = load_dataset(_dataset)

# load the network (classifier architecture selected by _network)
model = load_network(_dataset, _network)
print (' : [network] use the network - {}'.format(type(model).__name__))

# load the parameters from _netfile
model = load_network_parameters(model, _netfile)
print (' : [network] load the netparams from [{}]'.format(_netfile))


# load the denoiser (architecture chosen by load_denoiser for this dataset)
denoiser = load_denoiser(_dataset)
print (' : [denoiser] use the network - {}'.format(type(denoiser).__name__))

# load the parameters from _netdenfile
denoiser = load_network_parameters(denoiser, _netdenfile)
print (' : [denoiser] load the denoiser params from [{}]'.format(_netdenfile))

# objective function: jit-compiled softmax scores of the classifier in
# inference mode (training=False); only their argmax is used downstream
predictor = objax.Jit(lambda x: objax.functional.softmax(model(x, training=False)), model.vars())

# run eval
# clean_batch = load_test_batch(_dataset)
# clean_batch is the batch size used to iterate over the test set
clean_batch = 128
clean_acc, adver_acc, noise_acc = examine_robustness( \
    X_test, Y_test, clean_batch, _attack, denoiser, model, predictor, nclass=_num_class)
# report the three accuracies (percentages); noise_acc is the accuracy on denoised inputs
print (' : [robust] clean acc. {:.3f} / adver acc. {:.3f} / denoised acc. {:.3f}'.format(clean_acc, adver_acc, noise_acc))
# Fin.
