import os
import pickle as pk

import matplotlib.pyplot as plt

# Apply the "science" style sheet, given by file name; presumably resolved
# relative to the current working directory — confirm.
plt.style.use('science.mplstyle')

# Folder that the relative paths in RUNS_FILES and the output plot are resolved
# against; "" means the current working directory.
SAVE_FOLDER = ""  # TODO: specify path to folder with .pk run data files in them

# Run label -> pickled run-data file, relative to SAVE_FOLDER.
# The numeric suffix is the sigma value x 100 shown in the alignment legend
# (000 -> 0, 001 -> 0.01, ..., 020 -> 0.2). "pdfa_*" runs are read from
# "data_odfa_*" files; "*_vanilla" runs live in the "vanilla/" subfolder.
# Commented-out entries are runs currently excluded from the plot.
RUNS_FILES = {
    #"tdfa_000": "data_tdfa_000.pk",
    #"tdfa_001": "data_tdfa_001.pk",
    #"tdfa_003": "data_tdfa_003.pk",
    #"tdfa_005": "data_tdfa_005.pk",
    #"tdfa_010": "data_tdfa_010.pk",
    #"tdfa_020": "data_tdfa_020.pk",
    "dfa_000": "data_dfa_000.pk",
    "dfa_001": "data_dfa_001.pk",
    "dfa_003": "data_dfa_003.pk",
    "dfa_005": "data_dfa_005.pk",
    "dfa_010": "data_dfa_010.pk",
    "dfa_020": "data_dfa_020.pk",
    "pdfa_000": "data_odfa_000.pk",
    "pdfa_001": "data_odfa_001.pk",
    "pdfa_003": "data_odfa_003.pk",
    "pdfa_005": "data_odfa_005.pk",
    "pdfa_010": "data_odfa_010.pk.pk",  # NOTE(review): double ".pk" extension — confirm the file on disk is named this way
    "pdfa_020": "data_odfa_020.pk.pk",  # NOTE(review): double ".pk" extension — confirm the file on disk is named this way
    "dfa_vanilla": "vanilla/data_dfa.pk",
    "bp_vanilla": "vanilla/data_bp.pk",
    #"tdfa_vanilla": "vanilla/data_tdfa.pk",
    "pdfa_vanilla": "vanilla/data_odfa.pk",
}

# Load every configured run into memory, keyed by its run label.
runs = {}
for run_name, run_path in RUNS_FILES.items():
    # os.path.join is correct whether or not SAVE_FOLDER ends with a path
    # separator (plain concatenation silently glued the names together), and
    # returns run_path unchanged when SAVE_FOLDER is "".
    with open(os.path.join(SAVE_FOLDER, run_path), 'rb') as run_file:
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        runs[run_name] = pk.load(run_file)

# Flatten each run's records into per-metric series keyed by run label.
# run_data[0] is the list of per-epoch training records and run_data[1] the
# list of per-epoch validation records; each record is indexed by metric name.
train_loss_data = {}
validation_loss_data = {}
validation_acc_data = {}
alignment_data = {}
for name, data in runs.items():
    train_epochs = data[0]
    val_epochs = data[1]

    # Subsample each epoch's training losses, keeping indices 0, 10, 20, ...
    train_loss_data[name] = [
        value
        for epoch in train_epochs
        for idx, value in enumerate(epoch['loss'])
        if not idx % 10
    ]
    validation_loss_data[name] = [epoch["loss"] for epoch in val_epochs]
    validation_acc_data[name] = [epoch["accuracy"] for epoch in val_epochs]

    # Alignment (first entry of the "fc2" record per epoch) is not collected
    # for vanilla runs; those get an empty series.
    alignment_data[name] = (
        []
        if 'vanilla' in name
        else [epoch["alignment"]["fc2"][0] for epoch in train_epochs]
    )

# Noise-level suffixes that may appear in run labels, in lookup order.
_NOISE_LEVELS = ('000', '001', '003', '005', '010', '020')

# Line opacity (percent) per noise level: noisier runs are drawn fainter.
_NOISE_ALPHA = {'000': 100, '001': 75, '003': 50, '005': 30, '010': 20, '020': 10}

# Alignment-plot legend labels per noise level. Raw strings keep the LaTeX
# backslash literal ("\s" in a normal string is an invalid escape sequence).
_NOISE_SIGMA_LABEL = {
    '000': r"$\sigma$ = 0",
    '001': r"$\sigma$ = 0.01",
    '003': r"$\sigma$ = 0.03",
    '005': r"$\sigma$ = 0.05",
    '010': r"$\sigma$ = 0.1",
    '020': r"$\sigma$ = 0.2",
}

# Accuracy-plot legend labels: one representative run per method is labelled.
_ACCURACY_LABELS = {'dfa_000': 'DFA', 'pdfa_000': 'PDFA', 'bp_vanilla': 'BP'}


def _noise_key(run_name):
    """Return the first noise-level suffix contained in *run_name*, or None."""
    return next((level for level in _NOISE_LEVELS if level in run_name), None)


def _line_style(run_name):
    """Return the linestyle: dash-dot for BP, dashed for other vanilla runs, else solid."""
    if 'bp' in run_name:
        return 'dashdot'
    if 'vanilla' in run_name:
        return 'dashed'
    return 'solid'


def _line_color(run_name):
    """Return an RGB(A) color with channels in [0, 1] for *run_name*.

    The method family selects the hue; a noise-level suffix, when present,
    adds an alpha channel from _NOISE_ALPHA.
    """
    # Most specific first: 'pdfa' and 'tdfa' also contain 'dfa'.
    if 'pdfa' in run_name:
        rgb = (4.7, 36.5, 64.7)
    elif 'tdfa' in run_name:
        rgb = (100, 58.4, 0)
    elif 'dfa' in run_name:
        rgb = (100, 17.3, 0)
    elif 'bp' in run_name:
        rgb = (0, 0, 0)
    else:
        rgb = (4.7, 36.5, 64.7)
    noise = _noise_key(run_name)
    channels = rgb if noise is None else rgb + (_NOISE_ALPHA[noise],)
    # Channels are written as percentages above; matplotlib expects [0, 1].
    return [c / 100 for c in channels]


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
fig.subplots_adjust(wspace=0.25)
ax1.set_xlabel("Epochs")
ax1.set_ylabel("Validation acc. [%]")
ax2.set_xlabel("Epochs")
ax2.set_ylabel("Alignment (cos. sim.)")

for run_name, run_acc in validation_acc_data.items():
    color = _line_color(run_name)

    # Left panel: validation accuracy of every run, without its last entry.
    # NOTE(review): the reason for dropping the final value is not visible here.
    ax1.plot(run_acc[:-1], label=_ACCURACY_LABELS.get(run_name),
             linestyle=_line_style(run_name), color=color)

    # Right panel: alignment curves of PDFA runs, labelled by sigma.
    if 'pdfa' in run_name:
        ax2.plot(alignment_data[run_name],
                 label=_NOISE_SIGMA_LABEL.get(_noise_key(run_name)),
                 color=color)

# Build each legend once, after all labelled lines exist; skip axes without
# labelled lines to avoid matplotlib's "No artists with labels" warning.
for ax in (ax1, ax2):
    if ax.get_legend_handles_labels()[1]:
        ax.legend()

fig.savefig(os.path.join(SAVE_FOLDER, 'plot.png'))