import numpy as np
import pandas as pd

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import rc

import seaborn as sns

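# The plots in the paper were created with this script.
# Before running it, run 'run_coupled_system.cc' with nx = ny = 50 and with nx = ny = 100,
# for both the convex and the non-convex losses, so that the log_nx=XXXX_ny=XXXX[_cv].dat
# files read below exist in the working directory.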

# Use the non-interactive PDF backend and typeset all figure text with LaTeX.
matplotlib.rcParams['backend'] = 'pdf'
matplotlib.rcParams['text.usetex'] = True

# Problem dimensions used as the x-axis values of both panels (powers of two, 1..32).
dim_list = [2 ** k for k in range(6)]

# Load the NI-error logs for every (particle count, loss type) combination into one
# long-format table, then aggregate it for plotting.
#
# Row layout of each log file, as implied by the slicing below:
#   for each dimension index (one group per entry of dim_list)
#     for each run (nruns of them)
#       for each method, in LOG_METHODS order: N_EVALS consecutive rows
# Column NI_COLUMN holds the NI error.
LOG_METHODS = ['LGV', 'MD', 'WFR']
nruns = 20
N_EVALS = 10  # consecutive rows (evaluation index t) per method and run
NI_COLUMN = 2
rows_per_run = N_EVALS * len(LOG_METHODS)
rows_per_dim = rows_per_run * nruns
n_dims = len(dim_list)

res = []

for nx, ny in [(50, 50), (100, 100)]:
    for convex in [True, False]:
        cv_string = '_cv' if convex else ''
        filestring = "nx={:04d}_ny={:04d}{}".format(nx, ny, cv_string)
        log_path = "log_" + filestring + ".dat"

        log_data = np.loadtxt(log_path)
        # Fail loudly on a truncated log instead of silently producing short series
        # that only break later when plotted against dim_list.
        expected_rows = rows_per_dim * n_dims
        if len(log_data) < expected_rows:
            raise ValueError(
                f"{log_path} has {len(log_data)} rows, expected at least {expected_rows}"
            )

        for row in range(n_dims):
            # All rows belonging to dimension dim_list[row].
            this_list = log_data[rows_per_dim * row:rows_per_dim * (row + 1), :]
            for m_idx, method in enumerate(LOG_METHODS):
                offset = m_idx * N_EVALS
                for nrun in range(nruns):
                    start = rows_per_run * nrun + offset
                    data = this_list[start:start + N_EVALS, NI_COLUMN]
                    for t, ni in enumerate(data):
                        # `dim` stores the index into dim_list, not the dimension itself.
                        res.append(dict(dim=row, method=method, ni=ni, nx=nx, convex=convex, run=nrun, t=t))
res = pd.DataFrame(res)
res.set_index(['convex', 'nx', 'method', 'dim', 'run', 't'], inplace=True)
# Small epsilon so an exactly-zero error does not become -inf under the log.
res['ni'] = np.log(res['ni'] + 1e-10)
# Average the log-error over the evaluation index t within each run...
res = res.groupby(['convex', 'nx', 'method', 'dim', 'run']).mean()
# ...then take mean and std across runs -> columns ('ni', 'mean') and ('ni', 'std').
res = res.groupby(['convex', 'nx', 'method', 'dim']).aggregate(['mean', 'std'])

# One row of two panels sharing both axes; the panels are filled per `convex` group below.
fig, axes = plt.subplots(
    nrows=1, ncols=2, figsize=(4, 2),
    sharex=True, sharey=True, constrained_layout=False,
)
# Leave extra room at the bottom for the legends placed under the left panel.
fig.subplots_adjust(left=0.14, right=0.98, bottom=0.28, top=0.95)

# Line style encodes the particle count (nx); colour encodes the method.
linestyles = {50: '--', 100: '-'}
colors = dict(MD='C0', LGV='C1', WFR='C2')
labels = dict(MD='Mirror DA', LGV='Langevin DA', WFR='WFR DA')

# Handles and labels collected while plotting, used to build the custom legends.
method_labels, method_handles = [], []
nx_labels, nx_handles = [], []
# One panel per loss type. groupby sorts its keys, so the left panel gets
# convex=False and the right panel gets convex=True.
for ax, (convex, this_res) in zip(axes, res.groupby('convex')):
    for (nx, method), data in this_res.groupby(['nx', 'method']):
        label = labels[method]
        color = colors[method]
        mean = data[('ni', 'mean')]
        std = data[('ni', 'std')]
        handle, = ax.plot(dim_list, mean, marker='o', color=color,
                          label=label, linestyle=linestyles[nx],
                          )
        # On the convex panel only, keep one handle per method (from the nx=100
        # curves) and one per particle count (from the WFR curves) for the legends.
        if nx == 100 and convex:
            method_labels.append(label)
            method_handles.append(handle)
        if method == 'WFR' and convex:
            nx_labels.append(f'{nx}x2 particles')
            nx_handles.append(handle)
        # Shade +/- one std across runs in the same colour as the line. Without an
        # explicit colour, fill_between draws from the axes' own patch colour cycle,
        # which does not follow the per-method colours.
        ax.fill_between(dim_list, mean + std, mean - std, color=color, alpha=0.2)
    ax.set_xscale('log')
    ax.set_ylim([-10, 5])
    ax.grid(axis='y')
    # Put ticks exactly at the tested dimensions and label them with the plain
    # dimension values instead of the log-scale formatter's labels.
    ax.set_xticks(dim_list)
    ax.minorticks_off()
    ax.set_xticklabels(dim_list)
axes[0].set_ylabel('Log NI error')
# Two legends below the left panel: particle counts (line style) and methods (colour).
# Axes.legend() replaces any existing legend on the axes, so the first one is kept in
# `legend1` and re-attached with add_artist() after the second legend() call.
legend1 = axes[0].legend(nx_handles, nx_labels, frameon=False,
                     loc='upper left', bbox_to_anchor=(-0.05, -0.22), ncol=2, columnspacing=0.5)
axes[0].legend(method_handles, method_labels, loc='upper left', bbox_to_anchor=(-0.05, -0.1), frameon=False, ncol=3, columnspacing=0.5)
# Shared x-axis label, drawn as an annotation right-aligned at the left panel's
# lower-left corner and offset 14 points downward.
axes[0].annotate('Dimension', xy=(0, 0), xytext=(0, -14), xycoords='axes fraction',
                 textcoords='offset points', ha='right')
axes[0].add_artist(legend1)
# Remove the top and right spines of every axes in the figure.
sns.despine(fig)
plt.savefig('figure.pdf')
# NOTE(review): with the non-GUI 'pdf' backend set at the top of the script, show()
# probably has no visible effect — confirm whether it is still needed.
plt.show()
