import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
import pandas as pd
import numpy as np

def compute_graph_acc(oracl_dag, dag_estimate, dag_estimate_iid, graph_acc):
    """Record, per estimator, whether its DAG exactly equals the oracle DAG.

    Appends ``1`` (equal) or ``0`` (not equal) to ``graph_acc['iid']`` for
    ``dag_estimate_iid`` and then to ``graph_acc['CdF']`` for ``dag_estimate``.
    ``graph_acc`` is mutated in place and also returned for convenience.
    """
    # The IID comparison/append happens before the CdF one, as in the original.
    for bucket, candidate in (('iid', dag_estimate_iid), ('CdF', dag_estimate)):
        graph_acc[bucket].append(int(oracl_dag == candidate))
    return graph_acc

def mse_causal_effect(oracle_trunc_fac, generalized_trunc_fac, iid_trunc_fac,
                      generalized_trunc_fac_given_oracle_dag,
                      iid_trunc_fac_given_oracle_dag,
                      effect_mse, intervention_desc):
    """Score four causal-effect estimators against the oracle for one intervention.

    Each ``*_trunc_fac`` argument is indexed by ``intervention_desc``. The
    result is a mapping whose entries are compared key-by-key with the
    oracle's entries, iterating over the oracle's ``keys()``. For each
    estimator the squared differences are accumulated and the total is
    appended to ``effect_mse``:

    - ``'doF'``          <- ``generalized_trunc_fac``
    - ``'doF-ablation'`` <- ``generalized_trunc_fac_given_oracle_dag``
    - ``'iid'``          <- ``iid_trunc_fac``
    - ``'iid-ablation'`` <- ``iid_trunc_fac_given_oracle_dag``

    NOTE: despite the name, the appended value is a *sum* of squared errors
    over the oracle's keys; it is not divided by the number of keys.

    Returns ``effect_mse`` (mutated in place).
    """
    oracle = oracle_trunc_fac[intervention_desc]
    # Looked up in the same order as before so the first failing KeyError is unchanged.
    estimates = {
        'doF': generalized_trunc_fac[intervention_desc],
        'doF-ablation': generalized_trunc_fac_given_oracle_dag[intervention_desc],
        'iid': iid_trunc_fac[intervention_desc],
        'iid-ablation': iid_trunc_fac_given_oracle_dag[intervention_desc],
    }
    # All totals are computed before anything is appended.
    # sum() starts from 0 and adds in key order, matching the original running total.
    totals = {
        label: sum((oracle[k] - estimate[k]) ** 2 for k in oracle.keys())
        for label, estimate in estimates.items()
    }
    for label, total in totals.items():
        effect_mse[label].append(total)
    return effect_mse

def format_one_decimal(x, pos):
    """Tick-label callback: render ``x`` with exactly one digit after the decimal point.

    ``pos`` is accepted to match the ``FuncFormatter`` callback signature
    ``(value, position)`` and is ignored.
    """
    return format(x, '.1f')

# Matplotlib tick formatter that prints tick values with one decimal place.
# NOTE(review): not referenced anywhere in the visible part of this file (plot()
# never installs it on an axis) — confirm whether it is used elsewhere or can go.
formatter = FuncFormatter(format_one_decimal)

def _to_env_frame(per_env):
    """Turn ``{x: {column: value}}`` into a DataFrame with an ``x`` column, rows sorted by x."""
    # Sorting matches seaborn's default sort=True so fill_between bands follow the line.
    return pd.DataFrame(per_env).T.sort_index().reset_index().rename(columns={'index': 'x'})


def _aligned_std_frame(per_env_std, x_values):
    """Build the std frame with rows aligned *by x label* to ``x_values`` (RangeIndex)."""
    return pd.DataFrame(per_env_std).T.reindex(x_values).reset_index(drop=True)


def _line_with_band(ax, df, df_sd, column, label, color, linestyle='-'):
    """Draw ``df[column]`` against ``df['x']`` and shade +/- ``df_sd[column]`` around it."""
    sns.lineplot(data=df, x='x', y=column, ax=ax, label=label, marker='o',
                 color=color, linewidth=3, markersize=10, linestyle=linestyle)
    ax.fill_between(df['x'], df[column] - df_sd[column], df[column] + df_sd[column],
                    alpha=0.4, color=color)


def plot(graph_acc_env, mse_acc_env, std_graph_acc_env, std_mse_acc_env, *,
         mse_path='doF-bivariate-mse.pdf', graph_path='doF-bivariate-graph.pdf',
         close_figures=False):
    """Save two figures comparing Do-Finetti with the IID baseline.

    Figure 1 (``mse_path``): causal-effect MSE per x for the columns ``doF``,
    ``iid``, ``doF-ablation`` and ``iid-ablation``. Each line has a shaded
    +/- band taken from ``std_mse_acc_env``.
    Figure 2 (``graph_path``): graph accuracy per x for ``CdF`` and ``iid``.

    Args:
        graph_acc_env: mapping ``x -> {'CdF': value, 'iid': value}``. Presumably
            x is the number of environments (the axis is labelled so) — verify
            against the caller.
        mse_acc_env: mapping ``x -> {'doF', 'iid', 'doF-ablation', 'iid-ablation': value}``.
        std_graph_acc_env: same shape as ``graph_acc_env``. Currently unused
            (no bands are drawn on the accuracy figure); kept for interface
            compatibility.
        std_mse_acc_env: same shape as ``mse_acc_env``; band half-widths,
            matched to ``mse_acc_env`` by x label rather than by position.
        mse_path: output path for the MSE figure.
        graph_path: output path for the graph-accuracy figure.
        close_figures: if True, close both figures after saving. Use this when
            calling repeatedly, to avoid accumulating open figures.
    """
    df_mse = _to_env_frame(mse_acc_env)
    df_mse_sd = _aligned_std_frame(std_mse_acc_env, df_mse['x'])
    df_dag = _to_env_frame(graph_acc_env)

    sns.set(style='whitegrid')  # seaborn's whitegrid theme for both figures

    # --- Figure 1: causal-effect MSE (solid = learned DAG, dashed = true DAG) ---
    fig_mse, ax1 = plt.subplots(1, 1, figsize=(8, 6))
    for column, label, color, linestyle in (
        ('doF', 'Do-Finetti', 'r', '-'),
        ('iid', 'IID', 'b', '-'),
        ('doF-ablation', 'Do-Finetti-w-true-dag', 'r', '--'),
        ('iid-ablation', 'IID-w-true-dag', 'b', '--'),
    ):
        _line_with_band(ax1, df_mse, df_mse_sd, column, label, color, linestyle)
    ax1.set_xlabel('Number of environments', fontsize=15, fontweight='bold')
    ax1.set_ylabel('Causal Effect MSE (the lower the better)', fontsize=15, fontweight='bold')
    ax1.tick_params(axis='y', labelsize=12)
    ax1.tick_params(axis='x', labelsize=12)
    ax1.legend()
    ax1.grid(True)
    fig_mse.tight_layout()
    fig_mse.savefig(mse_path)

    # --- Figure 2: graph accuracy (a separate figure, not a twin axis) ---
    fig_dag, ax2 = plt.subplots(1, 1, figsize=(8, 6))
    sns.lineplot(data=df_dag, x='x', y='CdF', ax=ax2, label='Do-Finetti', marker='s',
                 color='r', linewidth=5, markersize=10)
    sns.lineplot(data=df_dag, x='x', y='iid', ax=ax2, label='IID', marker='v',
                 color='b', linewidth=5, markersize=10)
    ax2.set_ylabel('Graph Accuracy (the higher the better)', fontsize=15, fontweight='bold')
    ax2.set_xlabel('Number of environments', fontsize=15, fontweight='bold')
    ax2.tick_params(axis='y', labelsize=12)
    ax2.tick_params(axis='x', labelsize=12)
    ax2.legend()
    ax2.grid(True)
    fig_dag.tight_layout()
    fig_dag.savefig(graph_path)

    if close_figures:
        plt.close(fig_mse)
        plt.close(fig_dag)
#
