# Copyright 2021 The Handcrafted Backdoors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Profile a network (CNN) to inject backdoors """
# basics
import os, re, gc
from tqdm import tqdm
from ast import literal_eval
from itertools import product

# silence the UserWarning and FutureWarning messages
import warnings
warnings.simplefilter(action='ignore', category=UserWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)

# numpy
import numpy as np
np.set_printoptions(suppress=True)

# objax/jax
import jax, objax

# matplotlib (select the 'Agg' backend)
import matplotlib
matplotlib.use('Agg')

# utils
from utils.io import write_to_csv, load_from_csv
from utils.datasets import load_dataset, blend_backdoor, compose_backdoor_filter
from utils.models import load_network, load_network_parameters
from utils.learner import valid
from utils.profiler import load_activations, \
    run_activation_ablations, run_filter_ablations_old


# ------------------------------------------------------------------------------
#   Globals
# ------------------------------------------------------------------------------
_seed    = 215          # seed given to np.random.seed() by the main routine
_dataset = 'cifar10'    # selects one of the configurations below
_lnum_re = r'\[\d+\]'   # matches a bracketed integer (e.g. '[3]') in a variable name
_finject = False        # True: the main routine also runs the filter injections


# ------------------------------------------------------------------------------
#   Dataset specific configurations
#
#   Every branch defines the same set of names:
#     _network / _netfile : network name and the parameter file to load
#     _input_shape        : passed to the ablation routines as 'indim'
#     _num_batchs         : batch argument passed to valid() and the profilers
#     _bdr_*              : backdoor label, intensity, shape and size
#     _num_olayer         : index into the activations returned by the profiler
#     _num_valids         : number of validation samples to keep
# ------------------------------------------------------------------------------
## SVHN
if 'svhn' == _dataset:
    # -------- (ConvNet) --------
    _network     = 'ConvNet'
    _netfile     = 'models/{}/{}/best_model_base.npz'.format(_dataset, _network)
    _input_shape = (3, 32, 32)
    _num_batchs  = 50

    # : backdoor (square/checkerboard/random pattern)
    _bdr_label   = 0
    _bdr_intense = 0.0
    _bdr_shape   = 'square'
    _bdr_size    = 4

    # : output layer to inspect
    _num_olayer  = 5

    # : use subset of the data (26k is total)
    _num_valids  = 100

## CIFAR-10
elif 'cifar10' == _dataset:
    # -------- (ResNet18; 'ConvNet' is the alternative choice) --------
    _network     = 'ResNet18'
    _netfile     = 'models/{}/{}/best_model_base.npz'.format(_dataset, _network)
    _input_shape = (3, 32, 32)
    _num_batchs  = 50

    # : backdoor (checkerboard/random pattern)
    _bdr_label   = 0
    _bdr_intense = 1.0
    _bdr_shape   = 'checkerboard'
    _bdr_size    = 4

    # : output layer to inspect
    _num_olayer  = 5

    # : use subset of the data (10k is total)
    _num_valids  = 100

## PubFig
elif 'pubfig' == _dataset:
    # -------- (VGGFace) --------
    _network     = 'VGGFace'
    _netfile     = 'models/{}/{}/best_model_base.npz'.format(_dataset, _network)
    _input_shape = (3, 224, 224)
    _num_batchs  = 16

    # : backdoor (checkerboard/random/trojan pattern)
    _bdr_label   = 0
    _bdr_intense = 1.0
    _bdr_shape   = 'trojan'
    _bdr_size    = 24

    # : output layer to inspect
    _num_olayer  = 32

    # : use subset of the data (650 is total)
    _num_valids  = 325      # 650

## Others (fail here; otherwise the names above stay undefined and the
##   script stops much later with an unrelated NameError)
else:
    raise ValueError('Error: unsupported dataset - {}'.format(_dataset))


# ------------------------------------------------------------------------------
#   Support functions
# ------------------------------------------------------------------------------
def _compute_activation_loss(clean, bdoor, batch, profiler, loss='mae'):
    """ Measure how far the backdoor activations are from the clean ones

    Both inputs go through load_activations() with the same 'profiler' and
    'nbatch=batch'; only the entry at index '_num_olayer' of each result is
    compared.

    Args:
        clean:    the clean samples
        bdoor:    the backdoored samples
        batch:    forwarded to load_activations() as 'nbatch'
        profiler: forwarded to load_activations()
        loss:     the metric name; only 'mae' is implemented

    Returns:
        the mean of objax's mean_absolute_error(backdoor, clean) when
        loss == 'mae'; 0. for any other value of 'loss'.
    """
    # the activations are loaded even when 'loss' is not supported
    acts_clean = load_activations(clean, profiler, nbatch=batch)[_num_olayer]
    acts_bdoor = load_activations(bdoor, profiler, nbatch=batch)[_num_olayer]

    # unsupported metric: report no difference
    if loss != 'mae':
        return 0.

    each_error = objax.functional.loss.mean_absolute_error(acts_bdoor, acts_clean)
    return each_error.mean()

def _unroll_indices(matrix):
    """ Enumerate every index tuple of 'matrix'

    Args:
        matrix: any object with a 'shape' attribute (a sequence of ints)

    Returns:
        an itertools.product iterator over range(n) for each extent n in
        matrix.shape; the tuples come out with the last axis varying fastest.
    """
    axis_ranges = (range(extent) for extent in matrix.shape)
    return product(*axis_ranges)

# ------------------------------------------------------------------------------
#   Misc. functions
# ------------------------------------------------------------------------------
def _load_csvfile(filename):
    """ Load a profiling result (csv) and convert its fields from strings

    The rows come from load_from_csv(filename). The number of fields in the
    first row (3, 4 or 5) decides how every row is read:

        field 0  -> int
        field 1  -> literal_eval() (the location, written as a tuple)
        field 2+ -> float

    Fields beyond that count are ignored in the other rows.

    Args:
        filename: the csv file to read

    Returns:
        a list of tuples (int, <literal>, float[, float[, float]]); an empty
        list when the file holds no rows.

    Raises:
        ValueError: the first row has neither 3, 4 nor 5 fields
    """
    datalines = load_from_csv(filename)

    # nothing stored: there is no first row to inspect
    if not datalines:
        return []

    # raise explicitly (an assert would be stripped when running with -O)
    num_fields = len(datalines[0])
    if num_fields not in (3, 4, 5):
        raise ValueError('Error: unsupported data format - len: {}'.format(num_fields))

    # (int, literal) + one float per remaining field
    return [
        (int(eachdata[0]), literal_eval(eachdata[1]))
        + tuple(float(eachdata[fidx]) for fidx in range(2, num_fields))
        for eachdata in datalines]

def _store_csvfile(filename, datalines, mode='w'):
    """ Store a profiling result to a csv file

    The first two fields of every row are written as they are; each remaining
    field is formatted with '{:.6f}'. The number of fields in the first row
    (3, 4 or 5 - the same layouts _load_csvfile() reads) applies to all rows.

    Args:
        filename:  the csv file to write
        datalines: a sequence of rows
        mode:      forwarded to write_to_csv()

    Returns:
        None. When 'datalines' is empty nothing is written (the file is
        neither created nor modified), so a later run does not find a file
        without any row in it.

    Raises:
        ValueError: the first row has neither 3, 4 nor 5 fields
    """
    # nothing to store (len() also works when 'datalines' is an array)
    if len(datalines) == 0:
        return

    # raise explicitly (an assert would be stripped when running with -O)
    num_fields = len(datalines[0])
    if num_fields not in (3, 4, 5):
        raise ValueError('Error: unsupported data format - len: {}'.format(num_fields))

    # reformat
    datalines = [
        [eachdata[0], eachdata[1]]
        + ['{:.6f}'.format(eachdata[fidx]) for fidx in range(2, num_fields)]
        for eachdata in datalines]

    # store
    write_to_csv(filename, datalines, mode=mode)
    # done.

def _compose_store_suffix(filename):
    """ Derive the store-location suffix from a model filename

    Only the part after the last '/' is considered.

    Args:
        filename: the path to the model file

    Returns:
        'base' when that part does not contain 'ftune'; otherwise its 2nd and
        3rd dot-separated tokens joined by '.' (fewer, if they do not exist).
    """
    basename = filename.rsplit('/', 1)[-1]

    # not a fine-tuned model
    if 'ftune' not in basename:
        return 'base'

    return '.'.join(basename.split('.')[1:3])



"""
    Main (Profile/identify a set of convolutional filters to compromise)

    In order: (1) sample the validation data and craft its backdoored copy,
    (2) load the network and its parameters, (3) run the filter ablations,
    (4) run the filter injections (only when _finject is set), and (5) run the
    neuron (activation) ablations. Each result is written to a csv file under
    'profile/activations/...'; a step whose csv file exists is not re-run.
"""
if __name__ == '__main__':

    # set the random seed (for the reproducible experiments)
    np.random.seed(_seed)

    # load data (do not use the train data)
    (X_train, Y_train), (X_valid, Y_valid) = load_dataset(_dataset)
    print (' : [Profile] load the dataset [{}]'.format(_dataset))

    # remove the training data
    del X_train, Y_train; gc.collect()
    print ('   [Profile] delete unused training data')

    # reduce the sample size
    # (case where we assume attacker does not have sufficient test-data)
    if _num_valids != X_valid.shape[0]:
        num_indexes = np.random.choice(range(X_valid.shape[0]), size=_num_valids, replace=False)
        print ('   [Profile] sample the valid dataset [{} -> {}]'.format(X_valid.shape[0], _num_valids))
        X_valid = X_valid[num_indexes]
        Y_valid = Y_valid[num_indexes]

    # craft the backdoor datasets (use test-data, since training data is too many...)
    # (work on a copy of the samples; every label becomes the backdoor label)
    X_bdoor = blend_backdoor(
        np.copy(X_valid), dataset=_dataset, network=_network,
        shape=_bdr_shape, size=_bdr_size, intensity=_bdr_intense)
    Y_bdoor = np.full(Y_valid.shape, _bdr_label)
    print (' : [Profile] create the backdoor dataset, based on the test data')

    # sanity check, stop here unless the network is one of the listed ones
    if _network not in ['ConvNet', 'ConvNetDeep', 'VGGFace', 'ResNet18']: exit()

    # load model
    model = load_network(_dataset, _network)
    print (' : [Profile] use the network [{}]'.format(type(model).__name__))

    # load the model parameters
    load_network_parameters(model, _netfile)
    print (' : [Profile] load the model from [{}]'.format(_netfile))


    # set the locations to store
    print (' : [Profile] set the store locations')
    save_pref = _compose_store_suffix(_netfile)
    save_pdir = os.path.join('profile', 'activations', _dataset, type(model).__name__, save_pref)
    if not os.path.exists(save_pdir): os.makedirs(save_pdir)
    print ('   (activation) store to {}'.format(save_pdir))


    """
        Profile (filter ablations)
    """
    # forward pass functions
    #   predictor: the network outputs
    #   aprofiler: the forward pass with 'activations=True'
    #   fprofiler: the model's filter_activations()
    predictor = objax.Jit(lambda x: model(x, training=False), model.vars())
    aprofiler = objax.Jit(lambda x: model(x, training=False, activations=True), model.vars())
    fprofiler = objax.Jit(lambda x: model.filter_activations(x), model.vars())
    print (' : [Profile] set-up the Jit profilers')

    # run the filter ablations (or load the result stored by a previous run)
    result_csvfile = os.path.join(save_pdir, 'filter_ablations.{}.csv'.format(_bdr_shape))
    if not os.path.exists(result_csvfile):
        result_filters = run_filter_ablations_old(
            model, X_valid, Y_valid, _num_batchs, predictor, fprofiler,
            indim=_input_shape, lnums=model.findex,
            bdoor=True, x_bdoor=X_bdoor, y_bdoor=Y_bdoor)
        _store_csvfile(result_csvfile, result_filters, mode='w')
    else:
        result_filters = _load_csvfile(result_csvfile)
    print (' : [Profile] run filter ablations, for [{}] filters'.format(len(result_filters)))


    """
        Profile (injection of the filters to conv. layers)
    """
    # run the filter injections (only when enabled and not stored yet)
    result_csvfile = os.path.join(save_pdir, 'filter_injections.{}.csv'.format(_bdr_shape))
    if _finject and not os.path.exists(result_csvfile):
        # : data-holder, rows of [layer number, location, acc. drop, activation loss]
        result_injects = []

        # : load the baselines
        basecacc = valid('N/A', X_valid, Y_valid, _num_batchs, predictor, silient=True)
        baseloss = _compute_activation_loss(X_valid, X_bdoor, _num_batchs, aprofiler)
        print (' : [Profile] the baseline [acc. {:.3f} / loss {:.3f}]'.format(basecacc, baseloss))

        # : loop over the layers and inject
        for lname, lparams in model.vars().items():
            # :: skip condition (keep the 'Conv2D' variables whose name has no '.b')
            if ('Conv2D' not in lname) or ('.b' in lname): continue

            # :: lname number (the first bracketed integer in the name)
            each_lnum  = re.findall(_lnum_re, lname)
            each_lnum  = each_lnum[0].replace('[', '').replace(']', '')
            each_lnum  = int(each_lnum)

            # :: read the params
            each_lparams = lparams.value

            # :: compute min./max.
            #    (the max. is never below -100, the min. never above 100)
            each_maxval  = -100.
            each_minval  =  100.
            if each_maxval < each_lparams.max(): each_maxval = each_lparams.max()
            if each_minval > each_lparams.min(): each_minval = each_lparams.min()

            # :: compose the injected filter (the same for both cases below)
            each_filter = compose_backdoor_filter(
                _bdr_shape, each_lparams, min=each_minval, max=each_maxval)

            # NOTE(review): jax.ops.index_update/index (used below) are not
            #   available in recent JAX releases, where 'x.at[idx].set(v)'
            #   replaces them - confirm against the pinned JAX version.

            # :: case of the first layer (the 3rd axis of the params has size 3)
            if each_lparams.shape[2] == 3:

                # > loop over each channels and see how it works
                for each_outch in tqdm(range(each_lparams.shape[-1]), desc=' : [Profile][lnum-{}]'.format(each_lnum)):

                    # >> load the network parameters (undo the previous injection)
                    load_network_parameters(model, _netfile)

                    # >> substitute the filter (each channel, one by one)
                    each_nparam = model.vars()[lname].value
                    for each_inch in range(3):
                        each_nparam = jax.ops.index_update(
                            each_nparam, jax.ops.index[:, :, each_inch, each_outch], each_filter)

                    # >> write back to the variable itself (looked up by name,
                    #    so no attribute path has to be rebuilt from 'lname')
                    model.vars()[lname].assign(each_nparam)

                    # >> check the accuracy and the activation loss
                    each_cacc = valid('N/A', X_valid, Y_valid, _num_batchs, predictor, silient=True)
                    each_loss = _compute_activation_loss(X_valid, X_bdoor, _num_batchs, aprofiler)

                    # >> store
                    result_injects.append([each_lnum, (each_outch,), basecacc - each_cacc, float(each_loss)])

                # > end for each_outch...

            # :: case of the intermediate layers
            else:

                # > loop over every location of the last two axes
                each_lidxs = _unroll_indices(each_lparams[0, 0])
                for eachloc in tqdm(each_lidxs, desc=' : [Profile][lnum-{}]'.format(each_lnum)):

                    # >> load the network parameters (undo the previous injection)
                    load_network_parameters(model, _netfile)

                    # >> substitute the filter
                    each_nparam = jax.ops.index_update(
                        model.vars()[lname].value, jax.ops.index[:, :, eachloc[0], eachloc[1]], each_filter)
                    model.vars()[lname].assign(each_nparam)

                    # >> check the accuracy and the activation loss
                    #    (same batch setting as the baseline, so the drop is comparable)
                    each_cacc = valid('N/A', X_valid, Y_valid, _num_batchs, predictor, silient=True)
                    each_loss = _compute_activation_loss(X_valid, X_bdoor, _num_batchs, aprofiler)

                    # >> store
                    result_injects.append([each_lnum, eachloc, basecacc - each_cacc, float(each_loss)])
                # > end for eachloc...

        # : store to the csvfile (the largest activation loss first)
        result_injects = sorted(result_injects, key=lambda each: each[3], reverse=True)
        _store_csvfile(result_csvfile, result_injects, mode='w')
        print ('   [Profile] run filter injections, for [{}] locations'.format(len(result_injects)))


    """
        Profile (neuron ablations)
    """
    # forward pass functions
    # NOTE(review): 'aprofnone' is not passed to anything below - confirm
    #   whether run_activation_ablations() is supposed to receive it.
    aprofnone = objax.Jit(lambda x: model(x, activations=True, worelu=True), model.vars())
    print (' : [Profile] set-up the Jit profilers')

    # run the neuron ablations (or load the result stored by a previous run)
    result_csvfile = os.path.join(save_pdir, 'neuron_ablations.{}.csv'.format(_bdr_shape))
    if not os.path.exists(result_csvfile):
        result_neurons = run_activation_ablations(
            model, X_valid, Y_valid, _num_batchs, predictor,
            indim=_input_shape, jit=False if _dataset == 'pubfig' else True)
        _store_csvfile(result_csvfile, result_neurons, mode='w')
    else:
        result_neurons = _load_csvfile(result_csvfile)
    print (' : [Profile] run activation ablations, for [{}] neurons'.format(len(result_neurons)))
    print (' : [Profile] done.')
    # done.
