# make a bar plot of p values after aggregating across various seeds. each bar should correspond to one dataset

#structure
# aggregated_results/
#     p_values/
#         {outlier_method}/
#             {normalization_method}/
#                 {model_name}/
#                     {dataset_name}.csv
# each csv has columns: seed, p_2, p_5, ..., p_{j} where j refers to the number of samples used to calculate the p value
# for now only use the largest j

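# create one plot per model_name x normalization_method x outlier_method (currently only for models whose name contains "12b");
# plots are saved flat in plots_p_w/ as {model_name}_{outlier_method}_{normalization_method}.pdf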

import sys
from matplotlib import pyplot as plt
import pandas as pd
import os
import numpy as np
from scipy import stats

def aggregate_p_values(p_vals_all_seeds, averaging="arithmeticx2"):
    """Combine the p-values obtained for several seeds into one p-value.

    Args:
        p_vals_all_seeds: Non-empty 1-D sequence (list, tuple, NumPy array,
            ...) of p-values, one per seed.
        averaging: Name of the combination rule:
            "harmonic"     - harmonic mean of the p-values
            "arithmetic"   - arithmetic mean of the p-values
            "arithmeticx2" - twice the arithmetic mean (not capped at 1)
            "fisher"       - chi-squared upper tail of -2 * sum(log p),
                             with 2n degrees of freedom
            "stouffer"     - normal upper tail of the mean z-score
            "wilkinson"    - 1 - prod(1 - p)
            "edgington"    - sum_{k=1..n} (-1)^(k-1) * s^k / k!, s = sum(p)
            "bonferroni"   - min(p) * n, capped at 1

    Returns:
        The combined p-value as a scalar.

    Raises:
        ValueError: If `p_vals_all_seeds` is empty or `averaging` is not one
            of the names listed above.
    """
    # Function-scope import: `np.math` was only an alias of the stdlib module
    # and is no longer available in recent NumPy releases.
    import math

    # Work on a float array so lists/tuples behave like NumPy input
    # (e.g. `1 / p_vals` is not defined for a plain list).
    p_vals = np.asarray(p_vals_all_seeds, dtype=float)
    n = len(p_vals)
    if n == 0:
        raise ValueError("p_vals_all_seeds must contain at least one p-value")

    if averaging == "harmonic":
        p_val = n / np.sum(1 / p_vals)
    elif averaging == "arithmetic":
        p_val = np.sum(p_vals) / n
    elif averaging == "arithmeticx2":
        p_val = 2 * np.sum(p_vals) / n
    elif averaging == "fisher":
        test_statistic = -2 * np.sum(np.log(p_vals))
        # sf(x) equals 1 - cdf(x) but keeps precision for tiny tail values.
        p_val = stats.chi2.sf(test_statistic, 2 * n)
    elif averaging == "stouffer":
        # isf(p) equals ppf(1 - p) without the cancellation in `1 - p`.
        z_scores = stats.norm.isf(p_vals)
        # NOTE(review): the textbook Stouffer statistic is sum(z) / sqrt(n);
        # the plain mean is kept here as in the original - confirm intended.
        p_val = stats.norm.sf(np.mean(z_scores))
    elif averaging == "wilkinson":
        # 1 - prod(1 - p), evaluated through log1p/expm1 so that very small
        # p-values do not round to exactly 0.
        p_val = -np.expm1(np.sum(np.log1p(-p_vals)))
    elif averaging == "edgington":
        s = np.sum(p_vals)
        # NOTE(review): this alternating series is the degree-n truncation of
        # 1 - exp(-s); the textbook Edgington formula is the Irwin-Hall CDF -
        # confirm intended. The formula itself is unchanged.
        p_val = sum(
            (-1) ** (k - 1) * s ** k / math.factorial(k) for k in range(1, n + 1)
        )
    elif averaging == "bonferroni":
        p_val = min(np.min(p_vals) * n, 1)
    else:
        raise ValueError(f"unknown averaging method: {averaging!r}")

    return p_val


# Datasets whose result files are skipped. The match is a plain substring test
# against the whole file path, so a directory name containing one of these
# strings is skipped as well.
datasets_to_ignore = ["enron", "nih", "pubmed_abstracts"]

# Nested mapping:
#   model_name -> normalization_method -> outlier_method -> dataset_name
#   -> aggregated p-value
all_p_values = {}
for root, _dirs, files in os.walk("aggregated_results/p_values"):
    for file in files:
        # Only CSV result files carry p-values; stray files (editor backups,
        # OS metadata, ...) must not be fed to read_csv.
        if not file.endswith(".csv"):
            continue
        path = os.path.join(root, file)
        if any(dataset in path for dataset in datasets_to_ignore):
            continue

        # Layout: .../{outlier_method}/{normalization_method}/{model_name}/{dataset_name}.csv
        # Split on the platform separator instead of a hard-coded "/", since
        # os.walk/os.path.join build paths with os.sep.
        outlier_method, normalization_method, model_name = (
            os.path.normpath(root).split(os.sep)[-3:]
        )
        dataset_name = file.split(".csv")[0]

        df = pd.read_csv(path)
        # The last column is used; per the file header it is the p-value
        # computed with the largest number of samples.
        column_to_select = df.columns[-1]
        p_vals_all_seeds = df[column_to_select].values

        per_dataset = (
            all_p_values.setdefault(model_name, {})
            .setdefault(normalization_method, {})
            .setdefault(outlier_method, {})
        )
        per_dataset[dataset_name] = aggregate_p_values(p_vals_all_seeds, averaging="wilkinson")

# Draw one bar chart (one bar per dataset) for every
# model x normalization x outlier-method combination and save it as a PDF.
for model_name, by_normalization in all_p_values.items():
    # Only models whose name contains "12b" are plotted.
    if "12b" not in model_name:
        continue
    for normalization_method, by_outlier in by_normalization.items():
        for outlier_method, p_value_by_dataset in by_outlier.items():
            df = pd.DataFrame.from_dict(p_value_by_dataset, orient='index', columns=['p_value'])
            # sort the datasets by name
            df = df.sort_index()

            # Use an explicit figure/axes pair so nothing depends on (or leaks
            # into) pyplot's implicit current figure.
            fig, ax = plt.subplots()
            ax.bar(df.index, df['p_value'])
            # horizontal reference line at 0.05
            ax.axhline(y=0.05, color='r', linestyle='-')
            ax.set_xlabel('Dataset')
            # rotate the x-axis labels
            ax.tick_params(axis='x', labelrotation=90)
            ax.set_ylabel('p-value')
            ax.set_title(f'p-values for {model_name} with {normalization_method} and {outlier_method}')

            # All plots go into one flat directory; it is created once per
            # plot (the original created the same directory twice).
            directory = "plots_p_w"
            os.makedirs(directory, exist_ok=True)
            path = os.path.join(
                directory,
                f"{model_name}_{outlier_method}_{normalization_method}.pdf",
            )
            fig.savefig(path, bbox_inches='tight')
            plt.close(fig)

