# Copyright 2021 The Handcrafted Backdoors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Plot the NC evasion results..."""
# basics
import os

# silence UserWarning and FutureWarning messages
import warnings
warnings.simplefilter(action='ignore', category=UserWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)

# numpy (print floats without scientific notation)
import numpy as np
np.set_printoptions(suppress=True)

# matplotlib / pandas / seaborn
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
sns.set(color_codes=True)



# ------------------------------------------------------------------------------
#   Plot configurations
# ------------------------------------------------------------------------------
_sns_configs  = {
    'font.size'  : 16,
    'xtick.labelsize' : 16,
    'ytick.labelsize' : 16,
    'axes.facecolor': 'white',
    'axes.edgecolor': 'black',
    'axes.linewidth': 1.0,
    'axes.labelsize': 16,
    'legend.facecolor': 'white',
    'legend.edgecolor': 'black',
    'legend.fontsize' : 16,
    'grid.color': '#c0c0c0',
    'grid.linestyle': ':',
    'grid.linewidth': 0.8,
}


# ------------------------------------------------------------------------------
#   Results from our analysis with NC, manually obtained by running NC.
# ------------------------------------------------------------------------------
_size_xaxis = [4, 8, 12, 16, 20, 24, 28]
# _size_xaxis = [100. * (each/28)**2 for each in _size_xaxis]
_size_ncres = [100, 100, 10,   0,   0,   0,   0]
_size_succs = [100, 100, 98, 100, 100, 100, 100]
# mention the accuracy is > 95% in all the cases that we examine (use 4 neurons in MMNIST)

_size4_accs = [
    [  0,  10,  20,  30,  40,  52,  61,  72,  89, 100],
    [  0,   0,   0,   0,   0,   0,  60, 100, 100, 100],
]
_size8_accs = [
    [  0, 10, 20, 30, 40, 50, 60, 70, 80, 93, 96, 98, 100],
    [  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 20, 60, 100],
]
# mention the accuracy is > 95% in all the cases that we examine (use 4 neurons in MMNIST)

_markers     = ['o', 'v', 'p', '*', 'x']
_linestyles  = [':', '-.', '--', '-']


"""
    Main (Run the perturbations on parameters)
"""
if __name__ == '__main__':

    # set the store locations
    print (' : [load] set the store locations')
    save_rdir = os.path.join('analysis', 'NC')
    if not os.path.exists(save_rdir): os.makedirs(save_rdir)
    print ('   [neurons] - {}'.format(save_rdir))


    """
        Do the first plot... (size vs detection rate...)
    """
    # plotting
    plt.figure(figsize=(9,3.4))
    sns.set_theme(rc=_sns_configs)

    # loop over the data
    mcounter = 0
    sns.lineplot( \
        x=np.array(_size_xaxis), y=np.array(_size_ncres), \
        marker=_markers[mcounter], markeredgecolor='r', \
        linestyle=_linestyles[mcounter], color='r', label='NC success')

    mcounter += 1
    sns.lineplot( \
        x=np.array(_size_xaxis), y=np.array(_size_succs), \
        marker=_markers[mcounter], markeredgecolor='k', \
        linestyle=_linestyles[mcounter], color='k', label='Backdoor success')

    plt.xlim(min(_size_xaxis), max(_size_xaxis))
    plt.xticks(_size_xaxis+[4.])
    plt.xlabel('The width of a trigger pattern (in pixels)')
    plt.ylim(-1., 101.)
    plt.yticks(list(range(0, 120, 20)))
    plt.ylabel('NC detection rate (%)')
    plt.legend()

    tot_filename = 'mnist.checkerboard.ncdetect.size.eps'
    tot_filename = os.path.join(save_rdir, tot_filename)
    plt.tight_layout()
    plt.savefig(tot_filename)
    plt.clf()
    print (' : [defense][NC] first plot, done!')


    """
        Do the second plot...
    """
    # plotting
    plt.figure(figsize=(9,3.4))
    sns.set_theme(rc=_sns_configs)

    # loop over the data
    mcounter = 0
    sns.lineplot( \
        x=np.array(_size4_accs[0]), y=np.array(_size4_accs[1]), \
        marker=_markers[mcounter], markeredgecolor='b', \
        linestyle=_linestyles[mcounter], color='b', label='4x4 trigger')

    mcounter += 1
    sns.lineplot( \
        x=np.array(_size8_accs[0]), y=np.array(_size8_accs[1]), \
        marker=_markers[mcounter], markeredgecolor='k', \
        linestyle=_linestyles[mcounter], color='k', label='8x8 trigger')

    plt.xlim(0., 100.)
    plt.xlabel('Backdoor success rate (%)')
    plt.ylim(-1., 101.)
    plt.yticks(list(range(20, 120, 20)))
    plt.ylabel('NC detection rate (%)')
    plt.legend(loc='upper left')

    tot_filename = 'mnist.checkerboard.ncdetect.acc.eps'
    tot_filename = os.path.join(save_rdir, tot_filename)
    plt.tight_layout()
    plt.savefig(tot_filename)
    plt.clf()
    print (' : [defense][NC] second plot, done!')

    # done.
