# Copyright 2021 The Handcrafted Backdoors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Run the fine-pruning defenses """
# basics
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
from ast import literal_eval

# to disable future warnings
import warnings
warnings.simplefilter(action='ignore', category=UserWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)

# numpy
import numpy as np
np.set_printoptions(suppress=True)

# jax/objax
import objax

# matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(color_codes=True)

# utils
from utils.io import write_to_csv, load_from_csv
from utils.datasets import load_dataset, blend_backdoor
from utils.models import load_network, load_network_parameters
from utils.learner import train, valid
from utils.profiler import run_filter_ablations_old



# ------------------------------------------------------------------------------
#   Plot configurations
# ------------------------------------------------------------------------------
# seaborn / matplotlib rc overrides shared by every figure in this script:
#   16pt fonts, white backgrounds with black frames, light dotted grid lines
_sns_configs = dict([
    # font sizes
    ('font.size',        16),
    ('xtick.labelsize',  16),
    ('ytick.labelsize',  16),
    # axes frame and labels
    ('axes.facecolor',   'white'),
    ('axes.edgecolor',   'black'),
    ('axes.linewidth',   1.0),
    ('axes.labelsize',   16),
    # legend box
    ('legend.facecolor', 'white'),
    ('legend.edgecolor', 'black'),
    ('legend.fontsize',  16),
    # grid lines
    ('grid.color',       '#c0c0c0'),
    ('grid.linestyle',   ':'),
    ('grid.linewidth',   0.8),
])


# ------------------------------------------------------------------------------
#   Globals
# ------------------------------------------------------------------------------
_seed    = 215
_dataset = 'svhn'
_accdrop = 4            # tolerated clean-accuracy drop while pruning (4 -> 4%)
_overide = False        # when True, ignore any cached results and re-run


# ------------------------------------------------------------------------------
#   Dataset specific configurations
# ------------------------------------------------------------------------------
## SVHN
if 'svhn' == _dataset:
    # ----------------------- (Convolutional Networks) -------------------------
    _network     = 'ConvNet'
    _input_shape = (3, 32, 32)
    _netconv     = [3]
    _num_classes = 10

    # : backdoor setup
    _bdr_shape   = 'random'
    _bdr_label   = 0
    _bdr_intense = 0.0
    _bdr_size    = 4

    # : netfile
    # _netpref = 'standard.'
    # _netfile = 'models/{}/{}/best_model_backdoor_{}_{}_{}_5.npz'.format( \
    #                     _dataset, _network, _bdr_shape, _bdr_size, _bdr_intense)

    # : netfile (handtune) - each trigger shape has its own stored model file
    _netpref = 'handcraft.bdoor'
    if 'checkerboard' == _bdr_shape:
        _netfile = 'models/{}/{}/{}/best_model_handcraft_{}_{}_{}_30.npz'.format( \
                    _dataset, _network, _netpref, _bdr_shape, _bdr_size, _bdr_intense)
    elif 'random' == _bdr_shape:
        _netfile = 'models/{}/{}/{}/best_model_handcraft_{}_{}_{}_36.npz'.format( \
                    _dataset, _network, _netpref, _bdr_shape, _bdr_size, _bdr_intense)
    else:
        # fail here instead of with a NameError on _netfile much later
        raise ValueError('Error: unsupported backdoor shape - {}'.format(_bdr_shape))

    # : use only a subset (6500 samples) of the test data
    _num_valids  = 6500

    # : fine-tuning related configurations
    #   (_num_batchs is also the batch setting used for pruning / evaluation)
    _num_batchs  = 50
    _num_tunes   = 5
    _learn_rate  = 0.004 if _network == 'FFNet' else 0.025

else:
    # no other dataset is configured; fail fast instead of leaving the
    #  network / backdoor / netfile settings undefined
    raise ValueError('Error: unsupported dataset - {}'.format(_dataset))



# ------------------------------------------------------------------------------
#   Misc. functions
# ------------------------------------------------------------------------------
def _load_csvfile(filename):
    """Load a pruning-results CSV and convert each row to typed values.

    Rows are converted as ``(int, literal, float[, float[, float]])``:
    column 0 with ``int``, column 1 with ``ast.literal_eval`` (a Python
    literal such as a tuple), and every remaining column with ``float``.
    The number of columns (3, 4, or 5) is taken from the first row.

    Args:
        filename: path of the CSV file, read with ``load_from_csv``.

    Returns:
        A list with one converted tuple per row, or an empty list when the
        file has no rows.

    Raises:
        ValueError: if the first row does not have 3, 4, or 5 columns.
    """
    datalines = load_from_csv(filename)

    # len() rather than truthiness: works whether rows come back as a list or an array
    if len(datalines) == 0:
        return []

    num_cols = len(datalines[0])
    if num_cols not in (3, 4, 5):
        # raise (not assert) so the check survives `python -O`
        raise ValueError('Error: unsupported data format - len: {}'.format(num_cols))

    # every row is read with the first row's width (a shorter row raises IndexError)
    return [
        (int(eachdata[0]), literal_eval(eachdata[1])) \
            + tuple(float(eachdata[cidx]) for cidx in range(2, num_cols))
        for eachdata in datalines]


# ------------------------------------------------------------------------------
#   Pruning...
# ------------------------------------------------------------------------------
def run_finepruning():
    """Run the fine-pruning defense on the configured (backdoored) model.

    The steps are:
      1. load the dataset and craft a backdoored copy of the test split
         (every label set to ``_bdr_label``); optionally sub-sample both to
         ``_num_valids`` samples;
      2. load the model parameters from ``_netfile`` and measure the
         baseline clean / backdoor accuracy;
      3. for each conv layer in ``_netconv``, profile filter ablations with
         ``run_filter_ablations_old``, or re-load cached CSV results unless
         ``_overide`` is set;
      4. zero out filters in the profiled order, stopping at the first one
         whose profiled clean accuracy differs from the baseline by more
         than ``_accdrop``;
      5. fine-tune the pruned model on the training split for
         ``_num_tunes`` epochs, measuring clean / backdoor accuracy after
         every epoch.

    Returns:
        list: a header row ``['Epoch', 'Loss', 'Acc. (clean)',
        'Acc. (bdoor)']`` followed by one ``[epoch, loss, clean, bdoor]``
        row per fine-tuning epoch.

    Raises:
        ValueError: if ``_network`` is not one of the supported conv models.
    """
    # set the taskname
    task_name = 'fine-pruning'

    # guard against running this with a non-convolutional model; checked
    #  before any data is loaded (raise, since asserts vanish under -O)
    if _network not in ['ConvNet', 'ConvNetDeep', 'VGGFace']:
        raise ValueError('Error: can\'t run this script with {}'.format(_network))

    # set the random seed (for the reproducible experiments)
    np.random.seed(_seed)

    # data: the test split is used for pruning / evaluation,
    #  the train split only for the fine-tuning step below
    (x_train, y_train), (x_valid, y_valid) = load_dataset(_dataset)
    print (' : [load] load the dataset [{}]'.format(_dataset))

    # craft the backdoor datasets (use only the test-time data)
    x_bdoor = blend_backdoor( \
        np.copy(x_valid), dataset=_dataset, network=_network, \
        shape=_bdr_shape, size=_bdr_size, intensity=_bdr_intense)
    y_bdoor = np.full(y_valid.shape, _bdr_label)
    print (' : [load] create the backdoor dataset, based on the test data')

    # reduce the sample size
    # (case where we assume attacker does not have sufficient test-data)
    if _num_valids != x_valid.shape[0]:
        num_indexes = np.random.choice(range(x_valid.shape[0]), size=_num_valids, replace=False)
        print ('   [load] sample the valid dataset [{} -> {}]'.format(x_valid.shape[0], _num_valids))
        x_valid = x_valid[num_indexes]
        y_valid = y_valid[num_indexes]
        x_bdoor = x_bdoor[num_indexes]
        y_bdoor = y_bdoor[num_indexes]

    # model
    model = load_network(_dataset, _network)
    print (' : [load] use the network [{}]'.format(type(model).__name__))

    # load the model parameters
    load_network_parameters(model, _netfile)
    print (' : [load] load the model from [{}]'.format(_netfile))

    # forward pass functions
    predictor = objax.Jit(lambda x: model(x, training=False), model.vars())
    fprofiler = objax.Jit(lambda x: model.filter_activations(x), model.vars())

    # set the store locations
    print (' : [load] set the store locations')
    save_adir = os.path.join('analysis', task_name, _dataset, _network)
    os.makedirs(save_adir, exist_ok=True)
    print ('   [neurons] - {}'.format(save_adir))

    # --------------------------------------------------------------------------
    #   Profile the filter ablations (or load cached results)
    # --------------------------------------------------------------------------
    # check the baseline accuracy
    base_clean = valid('N/A', x_valid, y_valid, _num_batchs, predictor, silient=True)
    base_bdoor = valid('N/A', x_bdoor, y_bdoor, _num_batchs, predictor, silient=True)
    print (' : [Prune] clean acc. [{:.3f}] / bdoor acc. [{:.3f}] (before)'.format(base_clean, base_bdoor))

    # data-holder: conv layer -> [(layer, filter location, clean acc., bdoor acc.), ...]
    conv_prune = {}

    # loop over the conv layers
    for lnum in _netconv:

        # : store location; the prefix is joined with exactly one '.', so
        #   'standard.' and '' give the same names as before, while a prefix
        #   without a trailing dot no longer produces a '...bdoorcsv' name
        each_csvname = '.'.join(filter(None, [ \
            'pruning_results', _bdr_shape, str(lnum), _netpref.rstrip('.'), 'csv']))
        each_csvfile = os.path.join(save_adir, each_csvname)

        # : profile unless cached results exist (or _overide asks to re-run)
        if _overide or not os.path.exists(each_csvfile):
            each_results = run_filter_ablations_old( \
                model, x_valid, y_valid, \
                _num_batchs, predictor, fprofiler, \
                indim=_input_shape, lnums=[lnum], \
                bdoor=True, x_bdoor=x_bdoor, y_bdoor=y_bdoor)
            # keep columns 0, 1, 3, 4 - the 4-column layout of the cached CSV
            each_results = [(each[0], each[1], each[3], each[4]) for each in each_results]
            write_to_csv(each_csvfile, each_results, mode='w')

        # : read the cached results (literal_eval, like _load_csvfile, not eval)
        else:
            each_results = [(int(each[0]), literal_eval(each[1]), \
                             float(each[2]), float(each[3])) \
                            for each in load_from_csv(each_csvfile)]

        # > store to the holder
        conv_prune[lnum] = each_results

        # : report!
        print ('   [Prune] done with {}-layer'.format(lnum))

    # end for lnum....

    # --------------------------------------------------------------------------
    #   Prune the filters until we see an `_accdrop` accuracy drop
    # --------------------------------------------------------------------------
    for lnum, ldata in conv_prune.items():

        # : loop over the pruning info (in the profiled order)
        for _, each_lloc, each_cacc, _ in ldata:

            # > stop once the profiled clean accuracy is off by more than _accdrop
            if abs(base_clean - each_cacc) > _accdrop: break

            # > zero out filter each_lloc[0] (weights and bias) of layers[lnum-1]
            layer  = model.layers[lnum - 1]
            findex = each_lloc[0]
            each_filter = np.copy(layer.w.value)
            each_bias   = np.copy(layer.b.value)

            each_filter[:, :, :, findex] = 0.
            each_bias[findex, :, :]      = 0.

            # > store to the model
            layer.w.assign(each_filter)
            layer.b.assign(each_bias)

        # : end for each....

    # end for lnum...

    # check the prune acc.
    prune_clean = valid('N/A', x_valid, y_valid, _num_batchs, predictor, silient=True)
    prune_bdoor = valid('N/A', x_bdoor, y_bdoor, _num_batchs, predictor, silient=True)
    print (' : [Prune] clean acc. [{:.3f}] / bdoor acc. [{:.3f}] (after)'.format(prune_clean, prune_bdoor))

    # --------------------------------------------------------------------------
    #   Run fine-tuning after that
    # --------------------------------------------------------------------------
    # NOTE(review): SGD below updates every variable in model.vars(), including
    #  the filters zeroed above; no pruning mask is re-applied, so those filters
    #  can be re-learned during fine-tuning - confirm this is intended.

    # define the loss functions
    def _loss(x, y):
        logits = model(x, training=True)
        return objax.functional.loss.cross_entropy_logits_sparse(logits, y).mean()

    gv  = objax.GradValues(_loss, model.vars())
    opt = objax.optimizer.SGD(model.vars())

    # update operations
    def finetune_op(x, y, lr):
        g, v = gv(x, y)
        opt(lr=lr, grads=g)
        return v
    finetune_op = objax.Jit(finetune_op, gv.vars() + opt.vars())
    print (' : [Finetune] run')

    # data-holder (header row first)
    tot_results = [['Epoch', 'Loss', 'Acc. (clean)', 'Acc. (bdoor)']]

    # loop over the epochs
    for epoch in range(_num_tunes):

        # : compute the losses
        cur_loss = train(epoch, x_train, y_train, _num_batchs, finetune_op, _learn_rate)

        # : compute the accuracy
        cur_clean = valid('N/A', x_valid, y_valid, _num_batchs, predictor, silient=True)
        cur_bdoor = valid('N/A', x_bdoor, y_bdoor, _num_batchs, predictor, silient=True)

        # : store the data
        tot_results.append([epoch+1, cur_loss, cur_clean, cur_bdoor])

        # : print out
        print ('   [Finetune] clean: {:.3f} / bdoor: {:.3f} (loss: {:.4f})'.format(cur_clean, cur_bdoor, cur_loss))
    # end for ...

    # check the fine-tune acc. (previously the bdoor acc. overwrote the clean one
    #  and the prune-stage numbers were printed instead)
    ftune_clean = valid('N/A', x_valid, y_valid, _num_batchs, predictor, silient=True)
    ftune_bdoor = valid('N/A', x_bdoor, y_bdoor, _num_batchs, predictor, silient=True)
    print (' : [Finetune] clean acc. [{:.3f}] / bdoor acc. [{:.3f}] (after)'.format(ftune_clean, ftune_bdoor))

    return tot_results


def _load_pruning_curve(filename):
    """Load one pruning-results CSV as ``{'accuracy': [...], 'backdoor': [...]}``.

    Columns 2 and 3 of each row (as returned by ``_load_csvfile``) hold the
    clean accuracy and the backdoor accuracy. For SVHN the last row is
    dropped, as in the original script ("an exception, which is good, but
    hard to explain").
    """
    datalines = _load_csvfile(filename)
    if 'svhn' == _dataset:
        datalines = datalines[:-1]
    return {
        'accuracy': [each[2] for each in datalines],
        'backdoor': [each[3] for each in datalines],
    }


def _plot_pruning_curve(curve, filename):
    """Plot clean accuracy and attack success vs. # filters removed, save to `filename`."""
    # x axis follows the plotted series' own length (not another layer's)
    xaxis = np.arange(len(curve['accuracy']))

    plt.figure(figsize=(9,4.4))
    sns.set_theme(rc=_sns_configs)

    sns.lineplot( \
        x=xaxis, y=np.array(curve['accuracy']), \
        marker='x', markeredgecolor='k', linestyle='-.', color='k', \
        label='Handtune (accuracy)')
    sns.lineplot( \
        x=xaxis, y=np.array(curve['backdoor']), \
        marker='p', markeredgecolor='r', linestyle='--', color='r', \
        label='Handtune (att. success)')

    plt.xlim(0, len(xaxis))
    plt.xlabel('# Filters Removed')
    plt.ylim(0., 105.)
    plt.yticks(list(range(20, 120, 20)))
    plt.ylabel('Accuracy / Attack Success (%)')
    plt.legend()

    plt.tight_layout()
    plt.savefig(filename)
    # close (not just clear) so repeated calls do not accumulate open figures
    plt.close()


def plot_finepruning(enabled=False):
    """Plot the pruning curves of the handcrafted model (conv1 and conv3).

    Plotting is switched off by default: unless ``enabled`` is True the
    function returns immediately without touching any files.

    When enabled, it loads the ``pruning_results.<shape>.<layer>[.standard].csv``
    files (checkerboard / random triggers, standard / handcrafted models,
    conv layers 1 and 3) from ``analysis/fine-pruning/<dataset>/<network>``,
    and saves ``<_bdr_shape>.samples.conv{1,3}.eps`` there, drawn from the
    handcrafted model's results for the configured trigger ``_bdr_shape``.

    Args:
        enabled: set True to actually load the results and draw the figures.
    """
    # set the taskname
    task_name = 'fine-pruning'

    # FIXME - later (we don't need plotting at the moment). Return rather than
    #  call exit(), which would terminate the caller's whole interpreter.
    if not enabled:
        return

    # load location
    load_adir = os.path.join('analysis', task_name, _dataset, _network)
    print (' : [load] load from [{}]'.format(load_adir))

    # load every (trigger shape, model kind, conv layer) result; the standard
    #  models' files carry a '.standard' suffix, the handcrafted ones none
    curves = {}
    for shape in ('checkerboard', 'random'):
        for kind, suffix in (('standard', '.standard'), ('handcraft', '')):
            for lnum in (1, 3):
                each_file = os.path.join( \
                    load_adir, 'pruning_results.{}.{}{}.csv'.format(shape, lnum, suffix))
                curves[(shape, kind, lnum)] = _load_pruning_curve(each_file)
    print (' : [load] from the csv files')

    # draw the handcrafted model's curves for the configured trigger shape,
    #  so the data matches the shape used in the output file name
    for lnum in (1, 3):
        cur_filename = os.path.join(load_adir, '{}.samples.conv{}.eps'.format(_bdr_shape, lnum))
        _plot_pruning_curve(curves[(_bdr_shape, 'handcraft', lnum)], cur_filename)
        print (' : [Prune][{}] conv{} store to [{}]'.format(_bdr_shape, lnum, cur_filename))
    # done.



# ------------------------------------------------------------------------------
#   Main (handcraft backdoor attacks)
# ------------------------------------------------------------------------------
if __name__ == '__main__':
    # run the fine-pruning defense; the per-epoch result rows are kept in
    #  `tot_result` but not used any further here
    tot_result = run_finepruning()

    # plotting is left for later:
    # plot_finepruning()
