# Copyright 2021 The Handcrafted Backdoors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Compute the parameter distributions """
# basics
import os

# to disable future warnings
import warnings
warnings.simplefilter(action='ignore', category=UserWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)

# numpy / tensorflow
import numpy as np
np.set_printoptions(suppress=True)

# utils
from utils.datasets import load_dataset, blend_backdoor
from utils.models import load_network, load_network_parameters

# matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
sns.set(color_codes=True)



# ------------------------------------------------------------------------------
#   Plot configurations
# ------------------------------------------------------------------------------
# rc overrides handed to `sns.set_theme(rc=...)` before each plot is drawn;
# assembled group by group (insertion order: text, axes, legend, grid)
_sns_configs = {}

# : text sizes
_sns_configs.update({
    'font.size': 16,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
})

# : axes (white background with a thin black frame)
_sns_configs.update({
    'axes.facecolor': 'white',
    'axes.edgecolor': 'black',
    'axes.linewidth': 1.0,
    'axes.labelsize': 16,
})

# : legend box
_sns_configs.update({
    'legend.facecolor': 'white',
    'legend.edgecolor': 'black',
    'legend.fontsize': 16,
})

# : grid lines (light gray, dotted)
_sns_configs.update({
    'grid.color': '#c0c0c0',
    'grid.linestyle': ':',
    'grid.linewidth': 0.8,
})


# ------------------------------------------------------------------------------
#   General attack configurations
# ------------------------------------------------------------------------------
_seed    = 215
_dataset = 'svhn'
_combine = True     # to plot in one figure

# per-dataset choices:
#   (backdoor shape, backdoor intensity, network, number of test samples to use)
_dataset_setups = {
    'mnist'  : ('checkerboard', 1.0, 'FFNet',   10000),
    'svhn'   : ('checkerboard', 0.0, 'FFNet',   26000),
    'cifar10': ('random',       1.0, 'ConvNet', 10000),
}

# trailing number in the handcrafted checkpoint's filename,
#   keyed by (dataset, network, backdoor shape)
_handcraft_ids = {
    ('mnist',   'FFNet',   'square')      : 4,
    ('mnist',   'FFNet',   'checkerboard'): 6,
    ('svhn',    'FFNet',   'square')      : 32,
    ('svhn',    'FFNet',   'checkerboard'): 14,
    ('svhn',    'FFNet',   'random')      : 30,
    ('svhn',    'ConvNet', 'checkerboard'): 30,
    ('svhn',    'ConvNet', 'random')      : 36,
    ('cifar10', 'FFNet',   'square')      : 12,
    ('cifar10', 'FFNet',   'checkerboard'): 12,
    ('cifar10', 'FFNet',   'random')      : 24,
    ('cifar10', 'ConvNet', 'checkerboard'): 12,
    ('cifar10', 'ConvNet', 'random')      : 12,
}

# nothing below is defined when the dataset has no entry in the table above
if _dataset in _dataset_setups:
    _bdr_shape, _bdr_intense, _network, _num_valids = _dataset_setups[_dataset]

    # : experimentation-related parameters
    _num_batchs  = 50

    # : backdoor configs (label and size are the same for every dataset)
    _bdr_label   = 0
    _bdr_size    = 4

    # : networks to load
    _nettype     = 'handcraft.bdoor'

    # : pieces shared by the checkpoint locations
    _model_dir   = 'models/{}/{}'.format(_dataset, _network)
    _trigger_tag = '{}_{}_{}'.format(_bdr_shape, _bdr_size, _bdr_intense)

    # (standard network)
    _netbaseln   = '{}/best_model_base.npz'.format(_model_dir)
    _netsbdoor   = '{}/best_model_backdoor_{}_5.npz'.format(_model_dir, _trigger_tag)

    # (handcrafted models; stays undefined when the combination is not listed)
    _handcraft_key = (_dataset, _network, _bdr_shape)
    if _handcraft_key in _handcraft_ids:
        _nethbdoor = '{}/{}/best_model_handcraft_{}_{}.npz'.format( \
                        _model_dir, _nettype, _trigger_tag, _handcraft_ids[_handcraft_key])


# ------------------------------------------------------------------------------
#   Blend Gaussian noise into the model parameters
# ------------------------------------------------------------------------------
def _blend_perturbations(model, ratio):
    """Add Gaussian noise to every variable of `model`, in place.

    For each entry of `model.vars()`, noise with the same shape as the
    variable's value is drawn from N(mean(value), (ratio * std(value))^2),
    added to the value, and written back through the variable's `assign()`.
    The draws come from NumPy's global random state.

    Args:
        model: object whose `vars()` returns a mapping from names to
            variables exposing `.value` and `.assign(...)`.
        ratio: multiplier applied to each variable's standard deviation to
            obtain the scale of its noise.

    Returns:
        A tuple `(model, tot_perturbs)`: the same model object that was
        passed in, and a flat Python list with every noise value that was
        added, in the iteration order of `model.vars()`.
    """
    # data-holder
    tot_perturbs = []

    # loop over the entire layers
    for lname, lparams in model.vars().items():

        # : blend the noise to the parameters
        oparams = lparams.value
        # NOTE(review): the noise is centered at the variable's mean rather
        #   than at zero, so the blend also shifts the values by that mean
        #   -- confirm this is intended
        nparams = np.random.normal(
            loc=np.mean(oparams), scale=ratio * np.std(oparams), size=oparams.shape)
        bparams = oparams + nparams

        # : write back with a direct call (no string evaluation needed)
        model.vars()[lname].assign(bparams)

        # : append the perturbations
        tot_perturbs += nparams.flatten().tolist()

    # end for lname...
    return model, tot_perturbs



"""
    Main (Run the perturbations on parameters)
"""
if __name__ == '__main__':

    # set the random seed (for the reproducible experiments)
    np.random.seed(_seed)

    # data
    (_, _), (X_valid, Y_valid) = load_dataset(_dataset)
    print (' : [load] load the dataset [{}]'.format(_dataset))

    # craft the backdoor datasets (use test-data, since training data is too many...)
    X_bdoor = blend_backdoor( \
        np.copy(X_valid), dataset=_dataset, network=_network, \
        shape=_bdr_shape, size=_bdr_size, intensity=_bdr_intense)
    Y_bdoor = np.full(Y_valid.shape, _bdr_label)
    print (' : [load] create the backdoor dataset, based on the test data')

    # choose samples (when using entire test-time data is too large)
    if _num_valids != X_valid.shape[0]:
        num_indexes = np.random.choice(range(X_valid.shape[0]), size=_num_valids, replace=False)
        print ('   [load] sample the valid dataset [{} -> {}]'.format(X_valid.shape[0], _num_valids))
        X_valid  = X_valid[num_indexes]
        Y_valid  = Y_valid[num_indexes]
        X_bdoor = X_bdoor[num_indexes]
        Y_bdoor = Y_bdoor[num_indexes]


    """
        Load the standard backdoor (poisoning) and the handcrafted version (ours)
    """
    # load models
    blmodel = load_network(_dataset, _network)
    sbmodel = load_network(_dataset, _network)
    hbmodel = load_network(_dataset, _network)
    print (' : [load] the standard/handcrafted b-door models')

    # load parameters
    load_network_parameters(blmodel, _netbaseln)
    load_network_parameters(sbmodel, _netsbdoor)
    load_network_parameters(hbmodel, _nethbdoor)
    print (' : [load] load the model parameters')
    print ('   - baseline  b-door [{}]'.format(_netbaseln))
    print ('   - Standard  b-door [{}]'.format(_netsbdoor))
    print ('   - Handcraft b-door [{}]'.format(_nethbdoor))

    # sanity check if we use the same networks
    assert (set(blmodel.vars().keys()) == set(hbmodel.vars().keys())), \
        ('Error: the three networks are different, abort.')

    assert (set(sbmodel.vars().keys()) == set(hbmodel.vars().keys())), \
        ('Error: the three networks are different, abort.')


    # set the store locations
    print (' : [load] set the store locations')
    save_rdir = os.path.join('analysis', 'statistics', _dataset, _network)
    if not os.path.exists(save_rdir): os.makedirs(save_rdir)
    print ('   [neurons] - {}'.format(save_rdir))


    """
        Compute the statistics (per-layer + total) and plot!
    """
    # data-holders
    tot_lparams = {
        'baseline': [],
        'standard': [],
        'handtune': []
    }

    # loop over the lnames
    for lname, _ in sbmodel.vars().items():
        # : skip the bias
        if '.b' in lname: continue

        # : lname - short
        cur_lname = lname.replace('({}).layers(Sequential)'.format(_network), '')
        cur_lname = cur_lname.replace('(Linear)', 'Linear')
        cur_lname = cur_lname.replace('[', '')
        cur_lname = cur_lname.replace(']', '.')

        # : load the parameters
        each_blparams = blmodel.vars()[lname].value.flatten()
        each_sbparams = sbmodel.vars()[lname].value.flatten()
        each_hbparams = hbmodel.vars()[lname].value.flatten()

        # : compute other stats
        each_bmean, each_bstds, each_bmin, each_bmax = \
            each_blparams.mean(), each_blparams.std(), abs(each_blparams).min(), abs(each_blparams).max()
        each_smean, each_sstds, each_smin, each_smax = \
            each_sbparams.mean(), each_sbparams.std(), abs(each_sbparams).min(), abs(each_sbparams).max()
        each_hmean, each_hstds, each_hmin, each_hmax = \
            each_hbparams.mean(), each_hbparams.std(), abs(each_hbparams).min(), abs(each_hbparams).max()

        # : store the data
        tot_lparams['baseline'] += each_blparams.tolist()
        tot_lparams['standard'] += each_sbparams.tolist()
        tot_lparams['handtune'] += each_hbparams.tolist()

        """
            Draw the plots
        """
        if _combine:
            # : set the max
            each_tmax = each_smax
            if each_tmax < each_hmax:
                each_tmax = each_hmax

            # : set the layer name
            use_lname = cur_lname.replace('.Linear.w', '')
            if use_lname == '1':
                use_lname = '(Fully-connected:{}st)'.format(use_lname)
            elif use_lname == '3':
                use_lname = '(Fully-connected:{}rd)'.format(use_lname)
            else:
                use_lname = '(Fully-connected:{}th)'.format(use_lname)

            # : standard
            plt.figure(figsize=(9,4.4))
            sns.set_theme(rc=_sns_configs)

            sns.distplot(each_blparams, hist=True, color='k', \
                label='Baseline: N({:.3f}, {:.3f})'.format(each_bmean, each_bstds))
            sns.distplot(each_sbparams, hist=True, color='b', \
                label='Standard: N({:.3f}, {:.3f})'.format(each_smean, each_sstds))
            sns.distplot(each_hbparams, hist=True, color='r', \
                label='Handtune: N({:.3f}, {:.3f})'.format(each_hmean, each_hstds))

            plt.xlim(-each_tmax, each_tmax)
            plt.xlabel('Parameter values {}'.format(use_lname))
            plt.yticks([])
            # plt.ylabel('Density (~ Probability)')
            plt.ylabel('')
            plt.legend(loc='upper left')

            each_filename = '{}_{}_{}_{}_stats.combine.png'.format(_bdr_shape, _bdr_size, _bdr_intense, cur_lname)
            each_filename = os.path.join(save_rdir, each_filename)
            plt.tight_layout()
            plt.savefig(each_filename)
            plt.clf()

        else:
            # : baseline
            plt.figure(figsize=(9,5))
            sns.set_theme(rc=_sns_configs)

            sns.distplot(each_blparams, hist=True, color='g', \
                label='Baseline: N({:.3f}, {:.3f})'.format(each_bmean, each_bstds))

            plt.xlim(-each_bmax, each_bmax)
            plt.xlabel('Parameter values')
            plt.yticks([])
            plt.ylabel('Density (~ Probability)')
            plt.legend()

            each_filename = '{}_{}_{}_{}_stats.baseline.eps'.format(_bdr_shape, _bdr_size, _bdr_intense, cur_lname)
            each_filename = os.path.join(save_rdir, each_filename)
            plt.tight_layout()
            plt.savefig(each_filename)
            plt.clf()

            # : standard
            plt.figure(figsize=(9,5))
            sns.set_theme(rc=_sns_configs)

            sns.distplot(each_sbparams, hist=True, color='b', \
                label='Standard: N({:.3f}, {:.3f})'.format(each_smean, each_sstds))

            plt.xlim(-each_smax, each_smax)
            plt.xlabel('Parameter values')
            plt.yticks([])
            plt.ylabel('Density (~ Probability)')
            plt.legend()

            each_filename = '{}_{}_{}_{}_stats.standard.eps'.format(_bdr_shape, _bdr_size, _bdr_intense, cur_lname)
            each_filename = os.path.join(save_rdir, each_filename)
            plt.tight_layout()
            plt.savefig(each_filename)
            plt.clf()

            # : handtune
            plt.figure(figsize=(9,5))
            sns.set_theme(rc=_sns_configs)

            sns.distplot(each_hbparams, hist=True, color='r', \
                label='Handtune: N({:.3f}, {:.3f})'.format(each_hmean, each_hstds))

            plt.xlim(-each_hmax, each_hmax)
            plt.xlabel('Parameter values')
            plt.yticks([])
            plt.ylabel('Density (~ Probability)')
            plt.legend()

            each_filename = '{}_{}_{}_{}_stats.handtune.eps'.format(_bdr_shape, _bdr_size, _bdr_intense, cur_lname)
            each_filename = os.path.join(save_rdir, each_filename)
            plt.tight_layout()
            plt.savefig(each_filename)
            plt.clf()
        print (' : [defense][stats] the params in {}'.format(lname))

    # end for lname...


    """
        Draw the plots with the total params
    """
    # convert
    tot_lparams['baseline'] = np.array(tot_lparams['baseline'])
    tot_lparams['standard'] = np.array(tot_lparams['standard'])
    tot_lparams['handtune'] = np.array(tot_lparams['handtune'])

    # compute other stats
    tot_bmean, tot_bstds, tot_bmin, tot_bmax = \
        tot_lparams['baseline'].mean(), tot_lparams['baseline'].std(), \
        abs(tot_lparams['baseline']).min(), abs(tot_lparams['baseline']).max()
    tot_smean, tot_sstds, tot_smin, tot_smax = \
        tot_lparams['standard'].mean(), tot_lparams['standard'].std(), \
        abs(tot_lparams['standard']).min(), abs(tot_lparams['standard']).max()
    tot_hmean, tot_hstds, tot_hmin, tot_hmax = \
        tot_lparams['handtune'].mean(), tot_lparams['handtune'].std(), \
        abs(tot_lparams['handtune']).min(), abs(tot_lparams['handtune']).max()

    # standard
    plt.figure(figsize=(9,5))
    sns.set_theme(rc=_sns_configs)

    sns.distplot(tot_lparams['baseline'], hist=True, color='g', \
        label='Baseline: N({:.3f}, {:.3f})'.format(tot_bmean, tot_bstds))

    plt.xlim(-tot_bmax, tot_bmax)
    plt.xlabel('Parameter values')
    plt.yticks([])
    plt.ylabel('Density (~ Probability)')
    plt.legend()

    tot_filename = '{}_{}_{}_tot_stats.baseline.eps'.format(_bdr_shape, _bdr_size, _bdr_intense)
    tot_filename = os.path.join(save_rdir, tot_filename)
    plt.tight_layout()
    plt.savefig(tot_filename)
    plt.clf()

    # standard
    plt.figure(figsize=(9,5))
    sns.set_theme(rc=_sns_configs)

    sns.distplot(tot_lparams['standard'], hist=True, color='b', \
        label='Standard: N({:.3f}, {:.3f})'.format(tot_smean, tot_sstds))

    plt.xlim(-tot_smax, tot_smax)
    plt.xlabel('Parameter values')
    plt.yticks([])
    plt.ylabel('Density (~ Probability)')
    plt.legend()

    tot_filename = '{}_{}_{}_tot_stats.standard.eps'.format(_bdr_shape, _bdr_size, _bdr_intense)
    tot_filename = os.path.join(save_rdir, tot_filename)
    plt.tight_layout()
    plt.savefig(tot_filename)
    plt.clf()

    # : handtune
    plt.figure(figsize=(9,5))
    sns.set_theme(rc=_sns_configs)

    sns.distplot(tot_lparams['handtune'], hist=True, color='r', \
        label='Handtune: N({:.3f}, {:.3f})'.format(tot_hmean, tot_hstds))

    plt.xlim(-tot_hmax, tot_hmax)
    plt.xlabel('Parameter values')
    plt.yticks([])
    plt.ylabel('Density (~ Probability)')
    plt.legend()

    tot_filename = '{}_{}_{}_tot_stats.handtune.eps'.format(_bdr_shape, _bdr_size, _bdr_intense)
    tot_filename = os.path.join(save_rdir, tot_filename)
    plt.tight_layout()
    plt.savefig(tot_filename)
    plt.clf()
    print (' : [defense][stats] Done!')
    # done.
