import matplotlib.pyplot as plt
import numpy as np
from analysis import fit_loss_with_saturation # TODO: move the part of plot that uses this to notebooks
from configs import *
from utils import *

def _plot_reference_scaling_laws(flops_grid, key, kaplan_adjusted):
    """Overlay the Hoffmann and Kaplan reference laws for ``key`` on the current axes.

    Both laws give a compute-optimal parameter count N(C); they are converted to the
    quantity selected by ``key`` through C = 6 N D: ``'n'`` plots N, ``'t'`` plots
    D = C / (6 N), and any other key plots the ratio D / N = C / (6 N^2).
    """
    if kaplan_adjusted:
        kaplan_coef, kaplan_exponent, kaplan_lw, kaplan_label = 1.3e9, 0.73, 4, 'Adjusted Kaplan law'
    else:
        kaplan_coef, kaplan_exponent, kaplan_lw, kaplan_label = 1.6e9, 0.88, 2, 'Kaplan law'
    # Kaplan's law is in units of 1e15 FLOP/s sustained for one day (1e15 * 86400 = 8.64e19 FLOPs).
    pf_day = 1e15 * 24 * 60 * 60
    n_kaplan = kaplan_coef * (flops_grid / pf_day) ** kaplan_exponent

    if key == 'n':
        # Hoffmann: D = 20 N, hence C = 6 N D = 120 N^2.
        hoffmann = (flops_grid / (6 * 20)) ** 0.5
        kaplan = n_kaplan
    elif key == 't':
        # Hoffmann: N = D / 20, hence C = 6 D^2 / 20.
        hoffmann = (flops_grid / (6 / 20)) ** 0.5
        kaplan = flops_grid / n_kaplan / 6
    else:
        hoffmann = 20 * np.ones_like(flops_grid)
        kaplan = flops_grid / n_kaplan ** 2 / 6

    plt.plot(flops_grid, hoffmann, '-.', color='k', lw=2, label='Hoffmann law')
    plt.plot(flops_grid, kaplan, ':', color='gray', lw=kaplan_lw, label=kaplan_label)


def _fit_summary_text(key, fit_dict, exponent_ci, value_ci, big_font_conf_int):
    """Build the two-line summary of the fitted exponent and the fitted value at CHINCHILLA_FLOPS.

    Each line is followed by its (lower, upper) bootstrap interval set in a smaller font
    via the LaTeX fontsize/selectfont commands (presumably needs ``text.usetex`` — confirm).
    ``key`` must be 'n', 't' or 'multiplier' (KeyError otherwise).
    """
    quantity = {'n': r'N^\star', 't': r'D^\star'}.get(key, r'\rho^\star')
    exponent_symbol = {'n': '$a$', 't': '$b$', 'multiplier': '$r$'}[key]
    exponent_pt, value_pt = (18, 17) if big_font_conf_int else (12, 11)
    coef, exponent = fit_dict[f'{key}_coef'], fit_dict[f'{key}_exponent']

    first_line = r'%s = $%.3g$ ' % (exponent_symbol, exponent)
    exponent_conf = r'{\fontsize{%dpt}{3em}\selectfont {$(%.2f, %.2f)$}}' % (
        exponent_pt, exponent_ci[0], exponent_ci[1])
    value_at_chinchilla = coef * CHINCHILLA_FLOPS ** exponent
    second_line = r'$%s(%s)$ = %s ' % (quantity, CHINCHILLA_STR, fmt_model_size(value_at_chinchilla, key=key))
    value_conf = r'{\fontsize{%dpt}{3em}\selectfont {(%s, %s)}}' % (
        value_pt, fmt_model_size(value_ci[0], key=key), fmt_model_size(value_ci[1], key=key))
    return first_line + exponent_conf + '\n' + second_line + value_conf


def opt_param_vs_compute_plot(data, optimal_pairs, fit_dict, key='n', fit_dict_weighted=None,
                              fit_dicts_bootstrap=None, plot_error_bars=False, conf_level=0.05,
                              plot_bootstrap_obvs=False, label_fit=True, return_legend=False,
                              print_fit_as_text=False, obs_color='r', flop_grid_endpoints=None,
                              big_font_conf_int=False, kaplan_adjusted=False):
    """Plot the compute-optimal value of ``key`` against compute C on log-log axes.

    Draws the per-budget observations from ``optimal_pairs``, the power-law fits
    ``coef * C ** exponent`` from ``fit_dict`` (and ``fit_dict_weighted``), an optional
    bootstrap confidence band, and the Hoffmann / Kaplan reference laws.

    Args:
        data: frame with a ``flops`` column (and ``{key}_stars`` when
            ``plot_bootstrap_obvs``); its flop range is the default fit grid.
        optimal_pairs: frame with ``flops``, ``n`` and ``key`` columns, plus
            ``{key}_star_std`` when error bars are requested.
        fit_dict: mapping with ``{key}_coef`` and ``{key}_exponent``.
        key: 'n', 't' or another derived quantity (e.g. 'multiplier').
        fit_dict_weighted: optional second fit, drawn in green.
        fit_dicts_bootstrap: optional list of bootstrap fit dicts for the confidence band.
        plot_error_bars: draw multiplicative error bars value * exp(+-std).
        conf_level: two-sided level; the band spans quantiles conf_level/2 .. 1-conf_level/2.
        plot_bootstrap_obvs: scatter 25 bootstrap optima per budget.
        label_fit: prefix fit labels with 'Basic Fit:' / 'Weighted Fit:'.
        return_legend: return legend entries instead of drawing the legend.
        print_fit_as_text: add a text box with the exponent and the value at
            CHINCHILLA_FLOPS (requires a fit and bootstrap dicts).
        obs_color: color of the observations and of the basic fit.
        flop_grid_endpoints: (min, max) compute for the fit curves.
        big_font_conf_int: larger font for the confidence intervals in the text box.
        kaplan_adjusted: plot the adjusted Kaplan law instead of the original.

    Returns:
        ``(handles, labels)`` of the current axes when ``return_legend``; otherwise None.
    """
    key_coef, key_exponent, key_std = f'{key}_coef', f'{key}_exponent', f'{key}_star_std'
    flop_vals = data.flops.unique()
    # NOTE(review): fit availability is judged on the 'n' column for every `key`.
    enough_for_fit = len(optimal_pairs.query("~n.isna()")) >= 2

    if flop_grid_endpoints is None:
        flop_grid_endpoints = (np.min(flop_vals), np.max(flop_vals))
    flops_grid = np.geomspace(*flop_grid_endpoints, 20)

    # Observations; the std column is applied multiplicatively: value * exp(+-std).
    if key_std in optimal_pairs.columns and plot_error_bars:
        values = optimal_pairs[key]
        lower_bounds = values - values * np.exp(-optimal_pairs[key_std])
        upper_bounds = values * np.exp(optimal_pairs[key_std]) - values
        plt.errorbar(optimal_pairs.set_index('flops').index, values, yerr=[lower_bounds, upper_bounds],
                     fmt='o', color=obs_color, alpha=1, ms=6, label='Observations', capsize=6,
                     markeredgecolor='k', markeredgewidth=0.5)
    else:
        optimal_pairs.set_index('flops')[key].plot(style='x', color=obs_color, alpha=1, ms=12,
                                                   label='Observations')

    exponents_ci = values_ci = None
    if enough_for_fit:
        if fit_dicts_bootstrap is not None:
            lo_q, hi_q = conf_level / 2, 1 - conf_level / 2
            fit_vals_bootstrap = [fd[key_coef] * flops_grid ** fd[key_exponent] for fd in fit_dicts_bootstrap]
            exponents_bootstrap = [fd[key_exponent] for fd in fit_dicts_bootstrap]
            values_bootstrap = [fd[key_coef] * CHINCHILLA_FLOPS ** fd[key_exponent] for fd in fit_dicts_bootstrap]

            conf_int_lower = np.quantile(fit_vals_bootstrap, lo_q, axis=0)
            conf_int_upper = np.quantile(fit_vals_bootstrap, hi_q, axis=0)
            exponents_ci = (np.quantile(exponents_bootstrap, lo_q, axis=0),
                            np.quantile(exponents_bootstrap, hi_q, axis=0))
            values_ci = (np.quantile(values_bootstrap, lo_q, axis=0),
                         np.quantile(values_bootstrap, hi_q, axis=0))

            plt.fill_between(flops_grid, conf_int_lower, conf_int_upper, color='gray', alpha=0.2,
                             label=f'{100*(1-conf_level):n}\\% confidence region')
            if plot_bootstrap_obvs:
                # NOTE(review): GroupBy.sample(25) samples without replacement and raises
                # if any budget has fewer than 25 bootstrap values.
                (data.set_index('flops')[f'{key}_stars'].dropna().explode()
                 .groupby(level=0).sample(25)
                 .plot(style='xk', markersize=3, alpha=0.5))

        fit_colors = [obs_color, 'g']
        named_fits = [('Basic', fit_dict), ('Weighted', fit_dict_weighted)]
        for color, (name, d) in zip(fit_colors, named_fits):
            if d is None:
                continue
            law = f'${key} =$ {d[key_coef]:.3g} $C^{{{d[key_exponent]:.4g}}}$'
            fit_label = f'{name} Fit: {law}' if label_fit else law
            plt.plot(flops_grid, d[key_coef] * flops_grid ** (d[key_exponent]), '--', color=color, lw=2,
                     label=fit_label)

    _plot_reference_scaling_laws(flops_grid, key, kaplan_adjusted)

    if print_fit_as_text and enough_for_fit and fit_dicts_bootstrap is not None:
        summary = _fit_summary_text(key, fit_dict, exponents_ci, values_ci, big_font_conf_int)
        ax = plt.gca()
        ax.text(0.05, 0.95, summary, transform=ax.transAxes, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.7, linewidth=0))

    plt.ylim([optimal_pairs[key].min() * 0.7, optimal_pairs[key].max() * 1.3])
    plt.xscale('log')
    plt.yscale('log')
    plt.grid(True)  # major gridlines on both axes
    if return_legend:
        return plt.gca().get_legend_handles_labels()
    plt.legend(loc='upper left', bbox_to_anchor=[0, 1])


def isoflop_curves_plot(data, optimal_pairs, return_min_max_loss=False, min_multiplier=None):
    """Draw one loss-vs-model-size (IsoFLOP) curve per compute budget in ``data``.

    For every distinct ``data.flops`` value, the first matching row supplies the
    observed points (``orig_n``, ``orig_loss``), drawn as markers, and the
    interpolated curve (``n_interp``, ``loss_interp``), drawn dashed. Budgets whose
    ``n_interp`` is None are skipped. When ``optimal_pairs`` has a row for the
    budget, its ``n`` is marked with a star at height ``loss_interp[opt_ind]``.

    Args:
        data: frame with one row per budget holding array-valued curve columns.
        optimal_pairs: frame with ``flops`` and ``n`` columns.
        return_min_max_loss: return the y-limits instead of the legend entries.
        min_multiplier: if given, keep only sizes N with C / (6 N^2) >= min_multiplier.

    Returns:
        The current axes' ``get_ylim()`` when ``return_min_max_loss``, otherwise
        its ``get_legend_handles_labels()``.
    """
    budgets = data.flops.unique()
    palette = plt.cm.cool(np.linspace(0.1, 1, len(budgets)))

    for idx, budget in enumerate(budgets):
        curve = data.loc[data.flops == budget].iloc[0]
        budget_optima = optimal_pairs.loc[optimal_pairs.flops == budget]
        if curve.n_interp is None:
            continue
        shade = palette[idx]

        if min_multiplier is None:
            keep_obs = np.ones_like(curve['orig_n'], dtype=bool)
            keep_interp = np.ones_like(curve['n_interp'], dtype=bool)
        else:
            # Implied tokens-per-parameter multiplier C / (6 N^2) must reach the threshold.
            keep_obs = budget / (curve['orig_n'] ** 2) / 6 >= min_multiplier
            keep_interp = budget / (curve['n_interp'] ** 2) / 6 >= min_multiplier

        plt.scatter(curve.orig_n[keep_obs], curve.orig_loss[keep_obs], color=shade,
                    marker=get_marker(idx), label=f'C={budget:.4g}', s=15)
        plt.plot(curve.n_interp[keep_interp], curve.loss_interp[keep_interp], '--', color=shade)

        if len(budget_optima) == 0:
            continue
        best_n = budget_optima.iloc[0]['n']
        best_loss = curve['loss_interp'][int(curve['opt_ind'])]
        plt.plot(best_n, best_loss, '*', ms=12, color=shade, markeredgecolor='k', alpha=0.5)

    plt.xscale('log')
    plt.yscale('log')
    plt.grid(axis='x', which='major')
    plt.grid(axis='y', which='minor')

    plt.tick_params(axis='x', labelsize=12)
    plt.tick_params(axis='y', which='minor', labelsize=12)

    ax = plt.gca()
    return ax.get_ylim() if return_min_max_loss else ax.get_legend_handles_labels()


def _loss_bootstrap_population(data, bootstrap_num, fit_min_flop, fit_max_flop):
    """Return one Series of bootstrap optimal losses per resample index, within the fit window.

    ``data['loss_stars']`` presumably holds a sequence of bootstrap estimates per compute
    budget; resample ``k`` takes ``maybe_get_item(stars, k)`` from each budget. When
    ``bootstrap_num`` is None, the count is the length of the shortest ``loss_stars``
    entry of *this* frame (0 when there are none).
    """
    if bootstrap_num is None:
        lengths = data['loss_stars'].dropna().map(len)
        bootstrap_num = int(lengths.min()) if len(lengths) else 0
    # Drop budgets with a missing value in any column, then keep only the fit window.
    loss_stars = data.set_index('flops').dropna()['loss_stars'].truncate(before=fit_min_flop, after=fit_max_flop)
    return [loss_stars.map(lambda stars, k=k: maybe_get_item(stars, k)) for k in range(bootstrap_num)]


def opt_loss_vs_compute_plot(summary_df, configs_to_fit=tuple(),
                             fit_min_flop=1e16, fit_max_flop=5e17, conf_level=0.05,
                             return_legend=True, print_fit_as_text=False, bootstrap_num=None):
    """Plot the estimated optimal loss L*(C) against compute for each configuration.

    Each row of ``summary_df`` is one configuration, identified by its
    (dataset, hparams, warmup, decay, param_count, val) fields. For configurations
    listed in ``configs_to_fit``, ``L = E + A * C ** -alpha`` is fitted on the budgets
    in [fit_min_flop, fit_max_flop]. The fit, its extrapolation and a bootstrap
    confidence band are then drawn.

    Args:
        summary_df: frame with the config fields plus ``optimal_pairs`` (frame with
            ``flops`` and ``loss``) and ``data`` (frame with ``flops`` and ``loss_stars``).
        configs_to_fit: config tuples to fit.
        fit_min_flop, fit_max_flop: compute window used for fitting.
        conf_level: two-sided level of the bootstrap band.
        return_legend: return legend entries instead of drawing the legend.
        print_fit_as_text: add a text box with the last fitted law (only if a fit ran).
        bootstrap_num: number of bootstrap resamples. None means it is derived
            separately for each row from that row's data.

    Returns:
        ``(handles, labels)`` of the current axes when ``return_legend``; otherwise None.
    """
    last_fit = None  # (A, E, alpha) of the most recent fit, used by the text box
    for _, row in summary_df.iterrows():
        config = tuple(row[field] for field in ['dataset', 'hparams', 'warmup', 'decay', 'param_count', 'val'])
        color = CONFIG_DICT_COLOR[config]
        optimal_pairs = row['optimal_pairs'].set_index('flops').dropna()

        row['optimal_pairs'].set_index('flops').loss.dropna().plot(
            logx=True, logy=True, style='-' + CONFIG_DICT_MARKER[config],
            label=CONFIG_DICT_LABEL[config],
            lw=0.5, color=color, markersize=7, markerfacecolor='none'
        )

        if config not in configs_to_fit:
            continue

        fit_window = optimal_pairs.truncate(before=fit_min_flop, after=fit_max_flop)
        flops_all = optimal_pairs.index.values
        flops = fit_window.index.values
        loss = fit_window.loss.values

        # NOTE: the fit result's values are unpacked positionally, assuming the order (A, E, alpha).
        A, E, alpha = fit_loss_with_saturation(flops, loss).values()
        last_fit = (A, E, alpha)
        plt.plot(flops, E + A*(flops**-alpha), lw=1, color=color, linestyle='-',
                 label=f'Fit, $L={A:.2f}C^{{-{alpha:.3f}}} + {E:.2f}$', marker=CONFIG_DICT_MARKER[config])
        extrap_flops = np.geomspace(0.005 * optimal_pairs.index.min(), 200 * optimal_pairs.index.max(), 100)
        plt.plot(extrap_flops, E + A*(extrap_flops**-alpha), label='Extrapolation', lw=1,
                 color=color, linestyle='--')

        loss_bootstrap_pop = _loss_bootstrap_population(row['data'], bootstrap_num, fit_min_flop, fit_max_flop)
        if loss_bootstrap_pop:
            fit_dicts_bootstrap = fit_loss_with_saturation(flops, loss_bootstrap_pop)
            fit_vals_bootstrap = [fd['E'] + fd['A'] * (flops_all ** -fd['alpha']) for fd in fit_dicts_bootstrap]
            conf_int_lower = np.quantile(fit_vals_bootstrap, conf_level / 2, axis=0)
            conf_int_upper = np.quantile(fit_vals_bootstrap, 1 - conf_level / 2, axis=0)
            plt.fill_between(flops_all, conf_int_lower, conf_int_upper, color=color, alpha=0.2,
                             label=f'{100*(1-conf_level):n}% confidence region')

    if print_fit_as_text and last_fit is not None:
        A, E, alpha = last_fit
        A_fmt = f'{A:.2e}'.replace('+0', '').replace('e', r'\mathrm{e}')
        fit_text = f'Fit: $L={A_fmt}\\cdot C^{{-{alpha:.3f}}} + {E:.2f}$'
        props = dict(boxstyle='round', facecolor='white', alpha=0.7, linewidth=0)
        ax = plt.gca()
        ax.text(0.95, 0.95, fit_text, transform=ax.transAxes, fontsize=12,
                verticalalignment='top', horizontalalignment='right', bbox=props)

    x_lo = round_down_to_first_decimal(summary_df.optimal_pairs.apply(lambda x: x.flops.min()).min() * 0.8)
    x_hi = round_up_to_first_decimal(summary_df.optimal_pairs.apply(lambda x: x.flops.max()).max() * 1.2)
    plt.xlim(x_lo, x_hi)
    plt.ylabel(r'Estimated optimal loss $L^\star(C)$')
    plt.xlabel('Compute $C$ [FLOPs]')
    plt.grid(axis='y', which='minor')
    plt.grid(axis='x', which='major')
    if return_legend:
        return plt.gca().get_legend_handles_labels()
    plt.legend(loc='upper left', bbox_to_anchor=[1, 1.01])


def compute_analysis_plot(results_df, key='exponent',
                          gt_value=None, last_is_gt=False, add_to_cost=0, conf_level=0.05, show_legend=False):
    """Plot an estimate and its confidence band against the compute index of ``results_df``.

    Args:
        results_df: frame indexed by compute, with columns ``key``, ``key + '_lo'`` and
            ``key + '_hi'``. ``bs_predictions`` is read when a reference value is used;
            its entries must be indexable by ``key``, presumably one set of bootstrap
            estimates per row. ``optimal_pairs`` / ``exponent`` are read when ``last_is_gt``.
        key: column to plot; keys starting with 'prediction' get a log y-axis.
        gt_value: reference value drawn as a horizontal line. It also enables the RMS
            relative bootstrap error curve on a blue twin (right) y-axis.
        last_is_gt: take the reference value from the last row (its last optimal 'n'
            for prediction keys, otherwise its ``exponent``).
        add_to_cost: offset added to the compute index.
        conf_level: the band is labelled as the 100 * (1 - conf_level)% region.
        show_legend: draw the legend instead of returning its entries.

    Returns:
        ``(handles, labels)`` from both y-axes when ``show_legend`` is False, else None.
    """
    show_df = results_df.copy()
    show_df.index = show_df.index + add_to_cost
    if last_is_gt:
        last_row = show_df.iloc[-1]
        gt_value = last_row.optimal_pairs.iloc[-1]['n'] if key.startswith('prediction') else last_row.exponent

    ax1 = plt.gca()
    ax1.fill_between(show_df.index, show_df[key + '_lo'], show_df[key + '_hi'], color='gray', alpha=0.2,
                     label=f'${100 * (1 - conf_level):g}\\%$ confidence region')
    ax1.plot(show_df[key], '-k', lw=2, label='Point estimate')
    if key.startswith('prediction'):
        ax1.set_yscale('log')
        ax1.set_ylim([round_down_to_first_decimal(show_df[key + '_lo'].min()),
                      round_up_to_first_decimal(show_df[key + '_hi'].max())])

    if gt_value is not None:
        ax1.axhline(gt_value, ls='--', color='k', lw=1, label='Nominal value')

    handles, labels = ax1.get_legend_handles_labels()
    ax1.set_xlabel('Compute $C$ [FLOPs]')

    if gt_value is not None:
        bootstrap_error = show_df.bs_predictions.apply(
            lambda x: (((x[key] - gt_value) / gt_value) ** 2).mean() ** 0.5)
        ax2 = ax1.twinx()  # shares the x-axis; also becomes the current axes
        ax2.plot(bootstrap_error, '-b', lw=1, label='RMS relative\n bootstrap error')
        ax2.set_yscale('log')
        # Decade-rounded limits from positive errors only: log10(0) would give -inf.
        positive_errors = bootstrap_error[bootstrap_error > 0]
        if len(positive_errors) > 0:
            ax2.set_ylim([10 ** np.floor(np.log10(positive_errors.min())),
                          10 ** np.ceil(np.log10(positive_errors.max()))])
        handles2, labels2 = ax2.get_legend_handles_labels()
        handles += handles2
        labels += labels2
        ax2.tick_params(axis='y', labelcolor='blue', which='both')
        ax2.spines['right'].set_color('blue')

    # The x-axis is shared with the twin (if any), so this applies to both.
    ax1.set_xscale('log')
    ax1.set_xlim([show_df.index.min(), show_df.index.max()])

    if show_legend:
        # Drawn on the current axes, i.e. on top of the twin axis when one exists.
        plt.legend(labels=labels, handles=handles)
    else:
        return handles, labels


def plot_arrows(axes, arrows):
    """Draw black arrows between subplots, positioned in figure-fraction coordinates.

    Args:
        axes: array of Axes supporting ``.flatten()``; the indices in ``arrows``
            refer to its flattened order.
        arrows: iterable of ``(start_idx, end_idx, start_pos, end_pos)`` tuples, where
            ``start_pos`` / ``end_pos`` are ``(x, y)`` fractions of the start / end
            subplot's bounding box. Each arrow is annotated on the end subplot.
    """
    style = dict(facecolor='black', edgecolor='black', width=6, headwidth=20, shrink=0.1)
    flat_axes = axes.flatten()

    def to_figure_fraction(ax, frac):
        # Map a fraction of the subplot's box to a point in figure coordinates.
        box = ax.get_position()
        return (box.x0 + frac[0] * box.width, box.y0 + frac[1] * box.height)

    for src, dst, src_frac, dst_frac in arrows:
        src_ax, dst_ax = flat_axes[src], flat_axes[dst]
        tail = to_figure_fraction(src_ax, src_frac)
        head = to_figure_fraction(dst_ax, dst_frac)
        dst_ax.annotate('', xy=head, xytext=tail, arrowprops=style,
                        xycoords='figure fraction', textcoords='figure fraction')


def plot_sweep_key(show_df, reduced_df, show_key, fit_dict, x_ticks, y_ticks_dict, excess_loss_thresh, min_params_for_fit, max_params_for_fit, return_legend=False):
    """Plot the sweep of hyperparameter ``show_key`` (e.g. 'bs' or 'lr') against model size N.

    Scatters the sweep grid points of ``reduced_df`` for each beta2 in (0.95, 0.99, 0.999).
    Each point fades as ``excess_loss / excess_loss_thresh`` grows. Also draws the
    interpolated optimum from ``show_df``, highlights the optima with
    min_params_for_fit < params < max_params_for_fit, and overlays the power-law fit
    from ``fit_dict``. For 'bs' and 'lr' it adds a rounded fit and the DeepSeek
    compute-based law re-expressed in N.

    Args:
        show_df: frame with ``params`` and ``show_key`` columns (interpolated optima).
        reduced_df: frame (possibly indexed) with ``beta2``, ``params``, ``excess_loss``
            and ``show_key`` columns.
        show_key: column to plot; also the key into ``y_ticks_dict`` and KEYS_TO_TITLE_SWEEP.
        fit_dict: mapping with ``{show_key}_coef``, ``{show_key}_exponent`` and ``{show_key}_r2``.
        x_ticks: sorted model sizes used as x ticks; they also set the x-range.
        y_ticks_dict: per-key y ticks (array or list).
        excess_loss_thresh: excess loss at which grid points reach minimum opacity.
        min_params_for_fit, max_params_for_fit: exclusive bounds of the fitted range.
        return_legend: return ``(handles, labels)`` instead of drawing the legend.
    """
    beta2_values = (0.95, 0.99, 0.999)
    with plt.rc_context({'font.size': 20,  # For the text
                         'axes.titlesize': 22,  # For the subplot titles
                         'axes.labelsize': 20,  # For the x and y labels
                         'xtick.labelsize': 18,  # For the x tick labels
                         'ytick.labelsize': 18,  # For the y tick labels
                         'legend.fontsize': 16,  # For the legend
                         'figure.titlesize': 24}):  # For the figure title
        handles = []
        labels = []

        # Sweep grid points, one marker/color per beta2; opacity drops with excess loss.
        for j, beta2_ in enumerate(beta2_values):
            sample_ = reduced_df.reset_index().query(f'beta2 == {beta2_}').sort_values('params')
            if len(sample_) == 0:
                continue
            plt.scatter(sample_['params'], sample_[show_key],
                        alpha=np.maximum(0.01, 1 - sample_['excess_loss'].values / excess_loss_thresh),
                        marker=get_marker(j), label=r'data, $\beta_2$ = ' + str(beta2_), c=get_color(j),
                        s=8 * (12 - 3 * j), edgecolors='k')
            handles.append(plt.Line2D([0], [0], marker=get_marker(j), color='w', markerfacecolor=get_color(j),
                                      markersize=8, markeredgewidth=1, markeredgecolor='k'))
            labels.append(fr'Grid points, $\beta_2 = {beta2_}$')

        # Palette index 4, kept distinct from the beta2 colors (indices 0-2) used above.
        opt_color = get_color(4)

        show_df.set_index('params')[show_key].plot(
            logx=True, logy=True, marker='d', markersize=10, label='Interpolated optimal {}'.format(show_key),
            lw=1.25, markeredgewidth=2, markeredgecolor=opt_color, color=opt_color, markerfacecolor='none')
        handles.append(plt.Line2D([0], [0], marker='d', color='w', markerfacecolor='none', markersize=8,
                                  markeredgewidth=2, markeredgecolor=opt_color))
        labels.append('Interpolated optimal {}'.format(show_key.upper()))

        show_df_for_fit = show_df.query('params > @min_params_for_fit and params < @max_params_for_fit')
        plt.scatter(show_df_for_fit.params, show_df_for_fit[show_key], marker='d', s=200, c=opt_color,
                    label='Points used for fit', edgecolors=opt_color, linewidths=0.2)
        handles.append(plt.Line2D([0], [0], marker='d', markersize=10, c=opt_color, markeredgewidth=1,
                                  markeredgecolor=opt_color))
        labels.append('Points used for fit')

        x_vals = np.array([0.8 * x_ticks[0], 1.2 * x_ticks[-1]])
        fit_coef = fit_dict[show_key + '_coef']
        fit_exponent = fit_dict[show_key + '_exponent']
        fit_vals = fit_coef * (x_vals ** fit_exponent)
        label_computed_fit = f'Fit: {show_key.upper()} = ${fit_coef:.2g} N^{{{fit_exponent:.3g}}} (R^2={fit_dict[show_key + "_r2"]:.3g})$'
        plt.plot(x_vals, fit_vals, '--k', label=label_computed_fit)
        handles.append(plt.Line2D([], [], linestyle='--', color='k'))
        labels.append(label_computed_fit)

        # Hoffmann compute-optimal size N = (C / 120) ** 0.5, inverted to C = 120 N^2, so that
        # compute-based laws a * C ** b can be rewritten as a * 120 ** b * N ** (2 b).
        C_to_N_coef, C_to_N_exponent = 1 / (120 ** 0.5), 0.5
        N_to_C_coef = (1 / C_to_N_coef) ** (1 / C_to_N_exponent)
        N_to_C_exponent = 1 / C_to_N_exponent

        # Reference curves exist only for 'bs' and 'lr'; other keys show just the computed fit.
        fit_vals_manual = None
        if show_key == 'bs':
            fit_vals_manual = 160.0 * ((x_vals / 108e6) ** (2 / 3))
            label_manual_fit = r'Rounded fit: BS =  $160 (N / 108\mathrm{e}6)^{2/3}$'
            ds_C_coef, ds_C_exponent = 0.2920, 0.3271
            ds_N_coef = ds_C_coef * (N_to_C_coef ** ds_C_exponent)
            ds_N_exponent = N_to_C_exponent * ds_C_exponent
            ds_N_coef /= 2048  # to get BS in sequences instead of in tokens (presumably 2048-token sequences)
            label_deepseek_fit = f'DeepSeek fit: BS$ = {ds_N_coef:.2g} N^{{{ds_N_exponent:.3g}}}$'
        elif show_key == 'lr':
            fit_vals_manual = 0.0047 * ((x_vals / 108e6) ** (-1 / 3))
            label_manual_fit = r'Rounded fit: LR =  $0.0047 (N / 108\mathrm{e}6)^{-1/3}$'
            ds_C_coef, ds_C_exponent = 0.3118, -0.1250
            ds_N_coef = ds_C_coef * (N_to_C_coef ** ds_C_exponent)
            ds_N_exponent = N_to_C_exponent * ds_C_exponent
            label_deepseek_fit = f'DeepSeek fit: LR $= {ds_N_coef:.2g} N^{{{ds_N_exponent:.3g}}}$'

        if fit_vals_manual is not None:
            fit_deepseek = ds_N_coef * (x_vals ** ds_N_exponent)
            plt.plot(x_vals, fit_vals_manual, ':r', label=label_manual_fit)
            plt.plot(x_vals, fit_deepseek, linestyle='-.', color='orange', label=label_deepseek_fit)
            handles += [plt.Line2D([], [], linestyle=':', color='r'),
                        plt.Line2D([], [], linestyle='-.', color='orange')]
            labels += [label_manual_fit, label_deepseek_fit]

        y_ticks = y_ticks_dict[show_key]
        plt.xscale('log')
        plt.yscale('log')
        plt.xticks(x_ticks, labels=[f'{x//1e6:n}M' for x in x_ticks])
        plt.yticks(y_ticks, labels=y_ticks)
        ax = plt.gca()
        ax.set_yticks([], minor=True)
        ax.set_xticks([], minor=True)
        ax.set_axisbelow(True)
        plt.grid(True, which='major', color=[0.8, 0.8, 0.8, 1])
        plt.xlim(0.8 * x_ticks[0], 1.2 * x_ticks[-1])
        # np.max also accepts plain lists of ticks (a list cannot be multiplied by 1.2).
        plt.ylim(0.8 * min(y_ticks), max(1.2 * np.max(y_ticks), fit_vals.max()))
        plt.xlabel('$N$')
        plt.title(f'{KEYS_TO_TITLE_SWEEP[show_key]}')

        if return_legend:
            return handles, labels
        plt.legend(handles=handles, labels=labels)