"""
Visualize matrices of positives and negatives in the Appendix
"""

import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np

from losses import get_negative_pair


def main(output_dir='/path/to/save'):
    """Render the positive/negative relation matrices and save them as PDFs.

    For ``symmetric=True`` and ``symmetric=False`` a square matrix over
    ``elements`` is filled with relation codes, drawn with ``plt.matshow``,
    annotated cell-by-cell with the relation label and along the top/left
    edges with each element's short label in its own colour, and written to
    ``matrix_sym{symmetric}.pdf``.

    Relation codes stored in the matrix:
        1 = HP, 2 = SP, 3 = SN, 4 = HN, 5 = the element paired with itself
        (drawn white, no text).

    Args:
        output_dir: Directory the PDFs are written into. Defaults to the
            placeholder path used by the original script; point it at an
            existing directory before running.
    """
    # Each entry: (key passed to get_negative_pair, label colour, short label).
    elements = [
        ('tra_inp_p', '#ef476f', 'P'),
        ('tra_inp_f', '#118aff', 'F'),
        ('tra_inp_a', '#ffd166', 'C'),
        ('tra_dec_f_inp_p', '#000000', 'FP'),
        ('tra_dec_p_inp_f', '#603813', 'PF'),
        ('tra_dec_a_inp_p', '#00ffff', 'CP'),
        ('tra_dec_a_inp_f', '#662d91', 'CF'),
        ('tra_dec_p_inp_p', '#aa00ff', 'PP'),
        ('tra_dec_f_inp_f', '#00b300', 'FF'),
        ('tra_dec_a_inp_a', '#0000b3', 'CC'),
        ('tra_inp_p__inter__inp_f', '#b36b00', ' I'),
        ('tra_dec_p_inp_p__inter__inp_f', '#bfcfff', 'PI'),
        ('tra_dec_f_inp_p__inter__inp_f', '#b30000', 'FI'),
        ('tra_dec_p_inp_a', '#ff9900', 'PC'),
        ('tra_dec_f_inp_a', '#00ff00', 'FC'),
        ('other', '#06d6a0', 'O'),
    ]
    size = len(elements)

    # Follow these instructions to install HelveticaNeue font
    # https://fowlerlab.org/2019/01/03/changing-the-sans-serif-font-to-helvetica/
    # Step 3 is important one. Font can be downloaded from Internet. Remember to add the Bold option
    font_name = 'Helvetica Neue'

    # Colours are listed in relation-code order (1..5); both are loop-invariant.
    cmap = colors.ListedColormap(['#77dd76', '#CCF0CC', '#F0B5B4', '#ff6962', '#ffffff'])
    val_to_label = {1: 'HP', 2: 'SP', 3: 'SN', 4: 'HN', 5: ''}

    for symmetric in [True, False]:
        matrix = np.zeros((size, size))
        for i, (row_key, _, _) in enumerate(elements):
            for j, (col_key, _, _) in enumerate(elements):
                if 'other' in (row_key, col_key):
                    matrix[i, j] = 3  # soft negative
                    continue
                try:
                    matrix[i, j] = get_negative_pair(row_key, col_key, symmetric=symmetric) % 10
                except KeyError:
                    # The original treats a KeyError as "element paired with
                    # itself". NOTE(review): a KeyError for any other missing
                    # pair would be coloured the same way - confirm in losses.
                    matrix[i, j] = 5  # itself

        # plt.matshow opens its own figure, so no separate plt.figure() call is
        # needed (the original one produced an extra, empty, never-closed
        # figure). vmin/vmax pin code k to colour k even when a matrix does not
        # contain every code; auto-scaling would otherwise shift the colours
        # away from the text labels.
        plt.matshow(matrix, cmap=cmap, vmin=1, vmax=len(val_to_label))
        ax = plt.gca()
        fig = ax.figure

        for (i, j), code in np.ndenumerate(matrix):
            ax.text(j, i, val_to_label[int(code)], ha='center', va='center', size=7, fontname=font_name)

        for i, (_, colour, label) in enumerate(elements):
            # Row label to the left of the matrix, column label above it.
            ax.text(-0.7, i, label, ha='right', va='center', fontname=font_name, weight='bold',
                    color=colour)
            ax.text(i, -1, label, ha='center', va='center', fontname=font_name, weight='bold',
                    color=colour)

        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.tick_params(left=False, bottom=False, top=False, right=False)

        fig.savefig(f'{output_dir}/matrix_sym{symmetric}.pdf', bbox_inches='tight')
        # Release the figure so repeated iterations/calls do not accumulate open figures.
        plt.close(fig)


# Generate the figures only when this file is run as a script, not when it is imported.
if __name__ == '__main__':
    main()
