# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Backdooring models via poisoning """
import os
import argparse
import numpy as np
from tqdm import tqdm

# objax
import objax

# custom
from utils.datasets import load_dataset, blend_backdoor
from utils.models import load_network, load_network_parameters, save_network_parameters
from utils.optimizers import make_optimizer
from utils.learner import train, valid


# ------------------------------------------------------------------------------
#   Global configurations
# ------------------------------------------------------------------------------
# Best averaged accuracy, (clean + backdoor) / 2, seen so far for the current
# poisoning ratio; the retraining loop raises it whenever a better model is
# stored and resets it to zero once a ratio is finished.
_best_acc = 0.0


"""
    Dataset specific attack configurations
"""
# Each branch below defines the same set of module-level settings:
#   _network, _optimizer, _learn_rate, _num_batchs, _num_epochs  (re-training)
#   _bd_netfile : clean base model that is loaded before poisoning
#   _bd_ratios  : poisoning ratios, in percent of the clean training set
#   _bd_label   : target label assigned to every backdoored sample
#   _bd_intense / _bd_shape / _bd_size : trigger passed to blend_backdoor
_dataset  = 'cifar10'

## MNIST (6 variations - 2 patterns, 2 networks)
if 'mnist' == _dataset:
    _network    = 'FFNet'
    _optimizer  = 'SGD'
    _num_batchs = 64
    _num_epochs = 20
    _learn_rate = 0.05      # manually set the last lr

    # : backdoor (square, checkerboard)
    _bd_netfile = 'models/{}/{}/best_model_base.npz'.format(_dataset, _network)
    _bd_ratios  = [5]     # [0.1, 0.5, 1, 2, 10, 20] - for computing the effectiveness....
    _bd_label   = 0
    _bd_intense = 1.0
    _bd_shape   = 'checkerboard'
    _bd_size    = 4

## SVHN (6 variations - 3 patterns, 2 networks)
elif 'svhn' == _dataset:
    # : network (FFNet, ConvNet)
    _network    = 'ConvNet'
    _optimizer  = 'SGD'
    _learn_rate = 0.01
    _num_batchs = 50
    _num_epochs = 20

    # : backdoor (checkerboard, random)
    _bd_netfile = 'models/{}/{}/best_model_base.npz'.format(_dataset, _network)
    _bd_ratios  = [5]       # [1, 2, 5, 10, 20, 40] - 5 is reasonable amount
    _bd_label   = 0
    # NOTE(review): 0.0 here vs 1.0 for the other datasets - confirm this is
    # the intended trigger intensity for blend_backdoor on SVHN.
    _bd_intense = 0.0
    _bd_shape   = 'checkerboard'
    _bd_size    = 4

## CIFAR-10 (6 variations - 3 patterns, 2 networks)
elif 'cifar10' == _dataset:
    # (FFNet)
    # _network    = 'FFNet'
    # _learn_rate = 0.008
    # _num_batchs = 64
    # _num_epochs = 30

    # (ConvNet)
    # _network    = 'ConvNet'
    # _optimizer  = 'Momentum'
    # _learn_rate = 0.0002
    # _num_batchs = 128
    # _num_epochs = 10

    # (ResNet18)
    _network    = 'ResNet18'
    _optimizer  = 'Momentum'
    _learn_rate = 0.0002
    _num_batchs = 128
    _num_epochs = 10

    # : backdoor (square, checkerboard, random)
    _bd_netfile = 'models/{}/{}/best_model_base.npz'.format(_dataset, _network)
    _bd_ratios  = [5]
    _bd_label   = 0
    _bd_intense = 1.0
    _bd_shape   = 'trojan'      # 'square' / 'checkerboard' / 'random' / 'trojan'
    _bd_size    = 4

# PubFig (public figure - face dataset)
elif 'pubfig' == _dataset:
    # : (Pretrained VGG-Face)
    # _network    = 'VGGFace'
    # _optimizer  = 'Momentum'
    # _learn_rate = 0.001
    # _num_batchs = 50
    # _num_epochs = 20

    # : (Pretrained InceptionResNetV1)
    _network    = 'InceptionResNetV1'
    _optimizer  = 'Momentum'
    _learn_rate = 0.001
    _num_batchs = 45
    _num_epochs = 5

    # : backdoor (square, checkerboard, random, trojan)
    _bd_netfile = 'models/{}/{}/best_model_base.npz'.format(_dataset, _network)
    _bd_ratios  = [10]
    _bd_label   = 0
    _bd_intense = 1.0
    _bd_shape   = 'trojan'
    _bd_size    = 24

# fail fast: otherwise the settings above stay undefined and the script
# crashes later with an unrelated NameError
else:
    raise ValueError(
        'Unsupported dataset [{}] - choose one of mnist / svhn / cifar10 / pubfig'.format(_dataset))

# end if...


"""
    Misc. functions
"""
def _compose_backdoor(dataset, bratio):
    """Build a poisoned training set by appending backdoored samples.

    Loads the clean train/validation split of ``dataset``. It then picks
    ``bratio`` percent of the training samples at random, without replacement,
    and stamps the configured trigger (``_bd_shape`` / ``_bd_size`` /
    ``_bd_intense``) on them via ``blend_backdoor``. Those samples are relabeled
    to ``_bd_label`` and appended after the clean samples.

    Args:
        dataset: dataset name, used both for loading and for blending.
        bratio:  poisoning ratio, in percent of the clean training set size.

    Returns:
        ``((x_train, y_train), (x_valid, y_valid))``. The training arrays hold
        the clean samples followed by the poisoned copies. The validation split
        is returned as loaded.
    """
    # load the requested dataset (previously the global ``_dataset`` was used
    # here even though ``dataset`` was passed to blend_backdoor below)
    (x_train, y_train), (x_valid, y_valid) = load_dataset(dataset)

    # randomly choose the train samples for backdooring
    num_samples = x_train.shape[0]
    num_poison  = int(bratio * num_samples / 100.)
    bdr_indexes = np.sort(np.random.choice(num_samples, num_poison, replace=False))

    # compose the backdoor data (fancy indexing copies, originals untouched)
    X_btrain = blend_backdoor(
        x_train[bdr_indexes], dataset=dataset, network=_network,
        shape=_bd_shape, size=_bd_size, intensity=_bd_intense)
    # keep the label dtype of the clean data so concatenation does not upcast
    Y_btrain = np.full(y_train[bdr_indexes].shape, _bd_label, dtype=y_train.dtype)

    # blend: clean samples first, poisoned samples appended at the end
    x_train  = np.concatenate((x_train, X_btrain))
    y_train  = np.concatenate((y_train, Y_btrain))

    return (x_train, y_train), (x_valid, y_valid)


"""
    Run the backdoor attack from the base model
"""

# receive arguments
aparser = argparse.ArgumentParser()
aparser.add_argument('--multiple', action='store_true',
                     help='store the model under an index-suffixed filename (see --multiidx)')
aparser.add_argument('--multiidx', type=int, default=0,
                     help='index appended to the stored model filename when --multiple is set')
args = aparser.parse_args()
print (' : run multiple times [{}], index [{}]'.format(args.multiple, args.multiidx))


# loop over the backdoor ratios (percent of the clean training set)
for each_ratio in _bd_ratios:

    """
        Set-up dataset/network/parameters
    """
    # : compose the backdoor data
    (x_train, y_train), (x_valid, y_valid) = _compose_backdoor(_dataset, each_ratio)
    print (' : [{}] Train: {} / Test: {}'.format(each_ratio, x_train.shape, x_valid.shape))

    # : compose the backdoor validation data (every test sample carries the
    #   trigger and is labeled with the target class)
    X_bdoor = blend_backdoor(
        np.copy(x_valid), dataset=_dataset, network=_network,
        shape=_bd_shape, size=_bd_size, intensity=_bd_intense)
    Y_bdoor = np.full(y_valid.shape, _bd_label)
    print (' : [{}] create the backdoor dataset, based on the test data'.format(each_ratio))


    # : load the network
    model = load_network(_dataset, _network)
    print (' : [{}] Use the network - {}'.format(each_ratio, type(model).__name__))

    # : load the clean, trained model
    model = load_network_parameters(model, _bd_netfile)
    print (' : [{}] Load the model from {}'.format(each_ratio, _bd_netfile))


    """
        Set-up store location, loss function, trainer...
    """
    # : set the save location
    storedir = os.path.join('models', _dataset, _network)
    os.makedirs(storedir, exist_ok=True)

    # : classify the mode
    if not args.multiple:
        storefile = os.path.join(storedir,
            'best_model_backdoor_{}_{}_{}_{}.npz'.format(_bd_shape, _bd_size, _bd_intense, each_ratio))
    else:
        storefile = os.path.join(storedir,
            'best_model_backdoor_{}_{}_{}_{}.{}.npz'.format(
                _bd_shape, _bd_size, _bd_intense, each_ratio, args.multiidx))
    print (' : [{}] store the model to [{}]'.format(each_ratio, storefile))

    # : trainable parameters (the pretrained face networks only fine-tune a
    #   tail of the variable list; the cut-offs depend on variable ordering)
    train_vars = model.vars()
    if _network in ['VGGFace']:
        train_vars = objax.VarCollection((k, v)
            for i, (k, v) in enumerate(model.vars().items()) if i > 23)     # last 4-layers (conv + 3-fc)
    elif _network in ['InceptionResNetV1']:
        train_vars = objax.VarCollection((k, v)
            for i, (k, v) in enumerate(model.vars().items()) if i > 602)    # the last layer

    # : objective function
    opt = make_optimizer(train_vars, _optimizer)
    # jit over *all* model variables, not only train_vars: variables left out
    # of the collection may be captured at trace time, so any state updated by
    # the (possibly non-jitted) training step outside the trainable subset
    # would otherwise not be reflected in the evaluation
    predict = objax.Jit(lambda x: model(x, training=False), model.vars())
    print (' : [{}] Use the optimizer - {}'.format(each_ratio, type(opt).__name__))

    # : loss function
    def loss(x, y):
        logits = model(x, training=True)
        return objax.functional.loss.cross_entropy_logits_sparse(logits, y).mean()

    # : update
    gv = objax.GradValues(loss, train_vars)

    def train_op(x, y):
        g, v = gv(x, y)
        opt(lr=_learn_rate, grads=g)
        return v

    # : Jit the update step, except for InceptionResNetV1 which runs eagerly
    if _network not in ['InceptionResNetV1']:
        train_op = objax.Jit(train_op, gv.vars() + opt.vars())


    """
        Run re-training
    """
    # : check the sanity of the clean model
    clean_acc  = valid('N/A', x_valid, y_valid, _num_batchs, predict, silient=True)
    bdoor_acc  = valid('N/A', X_bdoor, Y_bdoor, _num_batchs, predict, silient=True)
    print (' : [{}] acc. clean {:.2f} / bdoor {:.2f}'.format(each_ratio, clean_acc, bdoor_acc))

    # : do training...
    for epoch in range(_num_epochs):
        sel = np.random.permutation(len(x_train))
        for it in tqdm(range(0, x_train.shape[0], _num_batchs),
            desc=' : [{}][train-{}]'.format(each_ratio, epoch)):

            # load a batch
            batch_idx = sel[it:it + _num_batchs]
            x_batch = x_train[batch_idx]
            y_batch = y_train[batch_idx].flatten()

            # update the network w. labels
            train_op(x_batch, y_batch)

        # :: evaluate
        clean_acc   = valid(epoch, x_valid, y_valid, _num_batchs, predict)
        bdoor_acc   = valid(epoch, X_bdoor, Y_bdoor, _num_batchs, predict)
        # '{}' (not '%d') so fractional ratios such as 0.5 are not truncated
        print (' : [{}] clean acc. {:.2f} / bdoor acc. {:.2f}'.format(each_ratio, clean_acc, bdoor_acc))

        # :: store the model whenever the mean of clean and backdoor accuracy improves
        avg_acc = (clean_acc + bdoor_acc) / 2.
        if avg_acc > _best_acc:
            _best_acc = avg_acc
            save_network_parameters(model, storefile)
            print (' : [{}] Store the model, to [{}]'.format(each_ratio, storefile))

    # : end for epoch ...

    # : reset the best_acc so the next ratio starts its own selection
    _best_acc = 0.

# end
