# Copyright 2021 The Handcrafted Backdoors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Reconstruct backdoor triggers with PGD-l2 and measure their backdoor success rates """
# basics
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# numpy
import numpy as np

# matplotlib (headless 'Agg' backend)
import matplotlib
matplotlib.use('Agg')

# torch
import torch
import torchvision.utils as vutils

# objax
import objax

# custom
from attacks.PGD import projected_gradient_descent
from utils.io import write_to_csv
from utils.datasets import load_dataset, blend_backdoor
from utils.learner import valid
from utils.models import load_network, load_denoiser, load_network_parameters


"""
    General attack configurations
"""
_seed    = 215
_dataset = 'cifar10'


def _lookup_config(table, dataset, network, shape):
    """Return ``table[(network, shape)]``, failing loudly on unsupported combos.

    Without this check an unsupported (network, trigger-shape) pair leaves the
    checkpoint path / sample indexes undefined, and the script only dies later
    with a NameError far away from the configuration.

    Raises:
        ValueError: if ``(network, shape)`` is not a key of ``table``.
    """
    try:
        return table[(network, shape)]
    except KeyError:
        raise ValueError(
            'Error: {} / {} / {} is not supported yet, abort'.format(
                dataset, network, shape)) from None


## SVHN
if 'svhn' == _dataset:
    _num_class   = 10
    _num_batchs  = 50
    # (square, checkerboard, random)
    _bdr_shape   = 'checkerboard'
    _bdr_size    = 4
    _bdr_intense = 0.0
    _bdr_label   = 0

    # FFNet / ConvNet
    _network     = 'ConvNet'

    # epoch suffix of the handcrafted-backdoor checkpoint, per (network, shape)
    _hbdoor_epoch = {
        ('FFNet',   'square')      : 38,
        ('FFNet',   'checkerboard'): 14,
        ('FFNet',   'random')      : 24,
        ('ConvNet', 'checkerboard'): 30,
        ('ConvNet', 'random')      : 36,
    }

    # indexes of the reconstructed samples to crop the pattern from, per
    # (network, shape): (standard-model sample, handcrafted-model sample)
    _sample_idx = {
        ('FFNet',   'square')      : (11, 15),
        ('FFNet',   'checkerboard'): (12, 11),
        ('FFNet',   'random')      : (11, 11),
        ('ConvNet', 'checkerboard'): (12, 12),
        ('ConvNet', 'random')      : (12, 12),
    }

    _netden      = 'DenoiseSVHN'

    # : reconstruction parameters
    _num_samples = 16
    _num_iter    = 100
    _eps_scaler  = 224 / 32     # Zico's paper uses 224 x 224 image; thus, we need 7 times larger
    _eps_ratio   = 0.8 if _bdr_shape == 'checkerboard' else 1.0
    _epsilon     = _eps_ratio * _eps_scaler * 64/255.
    _epsstep     = 2. * _epsilon / _num_iter

## CIFAR-10
elif 'cifar10' == _dataset:
    _num_class   = 10
    _num_batchs  = 50

    # (checkerboard, random)
    _bdr_shape   = 'random'
    _bdr_size    = 4
    _bdr_intense = 1.0
    _bdr_label   = 0

    # FFNet / ConvNet
    _network     = 'ConvNet'

    # epoch suffix of the handcrafted-backdoor checkpoint, per (network, shape)
    _hbdoor_epoch = {
        ('FFNet',   'square')      : 12,
        ('FFNet',   'checkerboard'): 12,
        ('FFNet',   'random')      : 24,
        ('ConvNet', 'checkerboard'): 12,
        ('ConvNet', 'random')      : 12,
    }

    # indexes of the reconstructed samples to crop the pattern from, per
    # (network, shape): (standard-model sample, handcrafted-model sample)
    _sample_idx = {
        ('FFNet',   'square')      : (10, 10),
        ('FFNet',   'checkerboard'): (10, 13),
        ('FFNet',   'random')      : (10, 10),
        ('ConvNet', 'checkerboard'): ( 2,  2),
        ('ConvNet', 'random')      : ( 2,  2),
    }

    _netden      = 'DenoiseCIFAR10'

    # : reconstruction parameters
    _num_samples = 16
    _num_iter    = 100
    _eps_scaler  = 224 / 32     # Zico's paper uses 224 x 224 image; thus, we need 7 times larger
    _epsilon     = _eps_scaler * 64/255.
    _epsstep     = 2. * _epsilon / _num_iter

    # : override attack parameters
    #   NOTE: _epsstep above is derived from the un-doubled epsilon, so for
    #   ConvNet the step stays 2 * eps0 / _num_iter (= new _epsilon / _num_iter).
    #   Kept as-is to reproduce the original runs; NOTE(review): confirm intent.
    if 'ConvNet' == _network: _epsilon = 2. * _epsilon

else:
    raise ValueError('Error: dataset {} is not supported yet, abort'.format(_dataset))


# : model checkpoints (same path layout for every dataset)
_netsbdoor   = 'models/{}/{}/best_model_backdoor_{}_{}_{}_5.npz'.format(
                    _dataset, _network, _bdr_shape, _bdr_size, _bdr_intense)
_nethbdoor   = 'models/{}/{}/handcraft.bdoor/best_model_handcraft_{}_{}_{}_{}.npz'.format(
                    _dataset, _network, _bdr_shape, _bdr_size, _bdr_intense,
                    _lookup_config(_hbdoor_epoch, _dataset, _network, _bdr_shape))
_netdenfile  = 'models/{}/{}/best_model_denoiser.npz'.format(_dataset, _netden)

# : crop and use the pattern
_sidx_sample, _hidx_sample = _lookup_config(_sample_idx, _dataset, _network, _bdr_shape)

# ------------------------------------------------------------------------------
#   RobustifiedModel Class
# ------------------------------------------------------------------------------
class RobustifiedModel(objax.Module):
    """Chain a denoiser and a classifier into one flat stack of layers.

    The denoiser's layers come first, followed by the classifier's layers.
    Right after the last denoiser layer, that layer's output is subtracted
    from the original input, and the difference is what the classifier
    layers receive. (Presumably the denoiser predicts the noise residual —
    NOTE(review): confirm against the denoiser definition.)
    """
    def __init__(self, denoiser, classifier):
        # how many leading layers in self.layers belong to the denoiser
        self.dn_len = len(denoiser.layers)

        # one flat layer stack: denoiser layers, then classifier layers
        self.layers = objax.nn.Sequential([])
        for component in (denoiser, classifier):
            self.layers.extend(component.layers)


    def __call__(self, x, training=False):
        original = x
        for position, layer in enumerate(self.layers, start=1):
            x = self._compute(x, layer, training=training)
            # after the final denoiser layer: input minus denoiser output
            if position == self.dn_len:
                x = original - x
        return x

    def _compute(self, x, layer, training=False):
        # only BatchNorm-style layers accept the `training` flag
        if 'BatchNorm' not in type(layer).__name__:
            return layer(x)
        return layer(x, training=training)


# ------------------------------------------------------------------------------
#   Support functions
# ------------------------------------------------------------------------------
def _extract_bdoor_pattern(x, shape='', size=1):
    """Crop the trigger-sized patch near the bottom-right corner of one sample.

    Args:
        x: a single sample indexed as ``x[channel, row, col]``. Assumed
            channel-first and square, since ``x.shape[-1]`` is used for both
            spatial axes — TODO confirm.
        shape: trigger shape; one of 'square', 'checkerboard', 'random'.
        size: side length of the cropped patch in pixels.

    Returns:
        ``x[:, W-1-size:W-1, W-1-size:W-1]`` with ``W = x.shape[-1]``. For a
        numpy input this is a view, not a copy. The window ends one pixel
        short of the last row/column; ``_blend_pattern`` uses the same window.

    Raises:
        ValueError: if ``shape`` is not supported. (A bare ``assert`` would be
            stripped under ``python -O`` and silently return None.)
    """
    if shape not in ('square', 'checkerboard', 'random'):
        raise ValueError('Error: {} is not supported yet, abort'.format(shape))

    xlen = x.shape[-1] - 1
    return x[:, (xlen-size):xlen, (xlen-size):xlen]

def _blend_pattern(x, pattern, shape='', size=1):
    """Stamp ``pattern`` into every sample of ``x`` near the bottom-right corner.

    ``x`` is modified IN PLACE and also returned, so callers that need the
    original data should pass a copy.

    Args:
        x: a batch indexed as ``x[sample, channel, row, col]``. Assumed square,
            since ``x.shape[-1]`` is used for both spatial axes — TODO confirm.
        pattern: values broadcastable to
            ``x[:, :, W-1-size:W-1, W-1-size:W-1]`` with ``W = x.shape[-1]``.
        shape: trigger shape; one of 'square', 'checkerboard', 'random'.
        size: side length of the stamped window in pixels.

    Returns:
        ``x`` itself, after the in-place assignment. The window ends one pixel
        short of the last row/column, matching ``_extract_bdoor_pattern``.

    Raises:
        ValueError: if ``shape`` is not supported. (A bare ``assert`` would be
            stripped under ``python -O`` and silently return None.)
    """
    if shape not in ('square', 'checkerboard', 'random'):
        raise ValueError('Error: {} is not supported yet, abort'.format(shape))

    xlen = x.shape[-1] - 1
    x[:, :, (xlen-size):xlen, (xlen-size):xlen] = pattern
    return x

def _overlay_pattern(data, pattern, x, y, xlen, ylen):
    """Paste ``pattern`` over a rectangular window of every sample in ``data``.

    The window covers indexes ``y:y+ylen`` on axis 2 and ``x:x+xlen`` on
    axis 3 of the 4-index ``data``. ``pattern`` must broadcast to that window.
    ``data`` is modified in place and also returned.
    """
    row_window = slice(y, y + ylen)
    col_window = slice(x, x + xlen)
    data[:, :, row_window, col_window] = pattern
    return data


"""
    Main (reconstruct the backdoor triggers with PGD-l2, crop them, and
    measure whether they activate the standard / handcrafted backdoors)
"""
if __name__ == '__main__':

    # set the random seed (for the reproducible experiments)
    #   NOTE(review): only numpy is seeded; if projected_gradient_descent draws
    #   its random init (rand_init=True) from another RNG, those runs are not
    #   covered by this seed — confirm.
    np.random.seed(_seed)

    # data
    (X_train, Y_train), (X_test, Y_test) = load_dataset(_dataset)
    print(' : [load] load the dataset [{}]'.format(_dataset))

    # craft the backdoor datasets (use test-data, since training data is too many...)
    #   NOTE(review): X_bdoor / Y_bdoor are not used below. They are kept because
    #   blend_backdoor may consume numpy randomness (e.g. the 'random' trigger),
    #   and removing the call could shift later random draws — confirm first.
    X_bdoor = blend_backdoor(
        np.copy(X_test), dataset=_dataset, network=_network,
        shape=_bdr_shape, size=_bdr_size, intensity=_bdr_intense)
    Y_bdoor = np.full(Y_test.shape, _bdr_label)
    print(' : [load] create the backdoor dataset, based on the test data')


    # ----- Load the denoiser -----
    denoiser = load_denoiser(_dataset)
    denoiser = load_network_parameters(denoiser, _netdenfile)
    print(' : [denoiser] load the denoiser params from [{}] for [{}]'.format(_netdenfile, type(denoiser).__name__))


    # ----- Load the standard and handcrafted backdoor models -----
    sb_model = load_network(_dataset, _network)
    sb_model = load_network_parameters(sb_model, _netsbdoor)
    print(' : [network] load the standard  bdoor model from [{}] for [{}]'.format(_netsbdoor, _network))

    hb_model = load_network(_dataset, _network)
    hb_model = load_network_parameters(hb_model, _nethbdoor)
    print(' : [network] load the handcraft bdoor model from [{}] for [{}]'.format(_nethbdoor, _network))


    # ----- Set the store location -----
    print(' : [load] set the store locations')
    save_vdir = os.path.join('analysis', 'broken.denoise', _dataset, _network, _bdr_shape)
    # exist_ok avoids the exists()/makedirs() check-then-create race
    os.makedirs(save_vdir, exist_ok=True)
    print('   [visualize] - {}'.format(save_vdir))


    # ----- Construct the robustified models (denoiser + classifier) -----
    #   only ConvNet models are composed with the denoiser
    if _network in ['ConvNet']:
        robust_sbmodel = RobustifiedModel(denoiser, sb_model)
        robust_hbmodel = RobustifiedModel(denoiser, hb_model)
        print(' : [robustify] compose w. denoiser - robustify the models')

    else:
        robust_sbmodel = sb_model
        robust_hbmodel = hb_model
        print(' : [robustify] compose w/o denoiser - not robustified')


    # ----- Reconstruct -----
    #   Suppose the classifier is smoothed: https://arxiv.org/abs/2010.09080,
    #   we search just a set of PGD-l2 adversarial samples to craft triggers.

    # data to use
    x_clean = X_test[:_num_samples]
    y_clean = Y_test[:_num_samples]

    # (for our paper) - only referenced by the commented-out save_image calls
    chosen_indexes = np.array([0, 1, 2, 4, 11])

    # visualize
    save_filename = os.path.join(save_vdir, 'reconstruct_baselines.{}.png'.format(_bdr_shape))
    vutils.save_image(torch.from_numpy(x_clean), save_filename)
    # vutils.save_image(torch.from_numpy(x_clean[chosen_indexes]), save_filename)   # for our paper
    print(' : [reconstruct] save clean examples to [{}]'.format(save_filename))

    # reconstruct with standard models (PGD with norm = 2)
    x_sadvs = projected_gradient_descent(
        robust_sbmodel, x_clean,
        _epsilon, _epsstep, _num_iter, 2,
        clip_min=0., clip_max=1., rand_minmax=1., rand_init=True)
    x_sadvs = np.asarray(x_sadvs)

    # visualize
    save_filename = os.path.join(save_vdir, 'reconstruct_from_standard_bdoor.{}.png'.format(_bdr_shape))
    vutils.save_image(torch.from_numpy(x_sadvs), save_filename)
    # vutils.save_image(torch.from_numpy(x_sadvs[chosen_indexes]), save_filename)     # for our paper
    print(' : [reconstruct] done, crafted PGD-l2 examples to [{}]'.format(save_filename))

    # reconstruct with handcrafted models (PGD with norm = 2)
    x_hadvs = projected_gradient_descent(
        robust_hbmodel, x_clean,
        _epsilon, _epsstep, _num_iter, 2,
        clip_min=0., clip_max=1., rand_minmax=1., rand_init=True)
    x_hadvs = np.asarray(x_hadvs)

    # visualize
    save_filename = os.path.join(save_vdir, 'reconstruct_from_handcrafted_bdoor.{}.png'.format(_bdr_shape))
    vutils.save_image(torch.from_numpy(x_hadvs), save_filename)
    # vutils.save_image(torch.from_numpy(x_hadvs[chosen_indexes]), save_filename)     # for our paper
    print(' : [reconstruct] done, crafted PGD-l2 examples to [{}]'.format(save_filename))


    # ----- Crop the adversarial patch and measure the backdoor success -----
    #   Crop the patch from the chosen (reconstructed) samples, use it as a
    #   potential backdoor pattern, and measure the backdoor success.
    if _network in ['ConvNet']:
        # hand-picked (row, col) offsets and (height, width) of the visible noise
        if (_dataset == 'cifar10') and (_bdr_shape == 'checkerboard'):
            (ys, xs), (yslen, xslen) = (21, 10), (8, 8)
            (yh, xh), (yhlen, xhlen) = (21, 13), (8, 8)

        elif (_dataset == 'cifar10') and (_bdr_shape == 'random'):
            (ys, xs), (yslen, xslen) = (21, 10), (8, 8)
            (yh, xh), (yhlen, xhlen) = (21, 10), (8, 8)

        elif (_dataset == 'svhn') and (_bdr_shape == 'checkerboard'):
            (ys, xs), (yslen, xslen) = (28, 28), (4, 4)
            (yh, xh), (yhlen, xhlen) = (28, 28), (4, 4)

        elif (_dataset == 'svhn') and (_bdr_shape == 'random'):
            (ys, xs), (yslen, xslen) = (28, 28), (4, 4)
            (yh, xh), (yhlen, xhlen) = (28, 28), (4, 4)

        else:
            # previously fell through and crashed below with a NameError on ys/xs
            raise ValueError('Error: no crop location for {} / {}, abort'.format(_dataset, _bdr_shape))

        x_spattern = x_sadvs[_sidx_sample][:, ys:(ys+yslen), xs:(xs+xslen)]
        x_hpattern = x_hadvs[_hidx_sample][:, yh:(yh+yhlen), xh:(xh+xhlen)]

        # visualize
        save_sfname = os.path.join(save_vdir, 'reconstructed_pattern.standard.{}.png'.format(_bdr_shape))
        save_hfname = os.path.join(save_vdir, 'reconstructed_pattern.handtune.{}.png'.format(_bdr_shape))
        vutils.save_image(torch.from_numpy(x_spattern), save_sfname)
        vutils.save_image(torch.from_numpy(x_hpattern), save_hfname)

        # apply to the test-time samples
        X_sbdoor = _overlay_pattern(np.copy(X_test), x_spattern, xs, ys, xslen, yslen)
        X_hbdoor = _overlay_pattern(np.copy(X_test), x_hpattern, xh, yh, xhlen, yhlen)
        Y_sbdoor = np.full(Y_test.shape, _bdr_label)
        Y_hbdoor = np.full(Y_test.shape, _bdr_label)

    else:
        # crop the patterns - the patterns are in the same locations
        x_spattern = _extract_bdoor_pattern(x_sadvs[_sidx_sample], shape=_bdr_shape, size=_bdr_size)
        x_hpattern = _extract_bdoor_pattern(x_hadvs[_hidx_sample], shape=_bdr_shape, size=_bdr_size)

        # apply to the test-time samples
        X_sbdoor = _blend_pattern(np.copy(X_test), x_spattern, shape=_bdr_shape, size=_bdr_size)
        X_hbdoor = _blend_pattern(np.copy(X_test), x_hpattern, shape=_bdr_shape, size=_bdr_size)
        Y_sbdoor = np.full(Y_test.shape, _bdr_label)
        Y_hbdoor = np.full(Y_test.shape, _bdr_label)

    # compose predictors (the plain models, without the denoiser)
    spredictor = objax.Jit(lambda x: sb_model(x, training=False), sb_model.vars())
    hpredictor = objax.Jit(lambda x: hb_model(x, training=False), hb_model.vars())

    # measure the backdoor success rate
    sbdoor_acc = valid('standard', X_sbdoor, Y_sbdoor, _num_batchs, spredictor)
    hbdoor_acc = valid('handtune', X_hbdoor, Y_hbdoor, _num_batchs, hpredictor)
    print(' : [reconstruct] backdoor successes [standard: {:.2f} / handtune: {:.2f}]'.format(sbdoor_acc, hbdoor_acc))

    # store to a file
    save_results = [['Success (standard)', 'Success (handtune)']]
    save_results.append([sbdoor_acc, hbdoor_acc])
    save_csvfile = os.path.join(save_vdir, 'reconstruction.{}.csv'.format(_bdr_shape))
    write_to_csv(save_csvfile, save_results, mode='w')
    print(' : [reconstruct] save the successes to [{}]'.format(save_csvfile))
    # done.
