import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import re
import matplotlib.colors as mcolors
import matplotlib.cm as cm
from scipy.optimize import curve_fit

def mse_model(data, A, B, alpha):
    """Evaluate the power-law curve ``A + B / data**alpha``.

    Args:
        data: Scalar or array-like of x values (used as the width axis by
            the fitting code in this file).
        A: Constant offset term.
        B: Coefficient of the power-law term.
        alpha: Exponent applied to ``data``.

    Returns:
        The model value(s), with the same shape as ``data``.
    """
    power_term = data ** alpha
    decaying_part = B / power_term
    return A + decaying_part

def parse_text(text, min_width=4):
    """Extract ``(width, percentage, mse)`` records from raw results text.

    A record is recognised as ``Width<digits>``, followed (lazily, and
    across newlines) by ``_p<digits>``, followed by ``mse:<decimal>``.

    Args:
        text: Full contents of a results file.
        min_width: Records whose width is below this value are dropped.
            Defaults to 4, the cut-off that used to be hard-coded.

    Returns:
        A list of ``(int, int, float)`` tuples in order of appearance.
    """
    # ``\d+`` rather than ``\d*``: a "Width" token with no digits must not
    # match, otherwise the empty capture makes ``int('')`` raise ValueError.
    # The optional exponent keeps values such as ``1.5e-03`` from being
    # silently truncated to ``1.5``.
    pattern = r"Width(\d+).*?_p(\d+).*?mse:(\d+\.\d+(?:[eE][-+]?\d+)?)"

    records = []
    for width_str, perc_str, mse_str in re.findall(pattern, text, re.DOTALL):
        width = int(width_str)
        if width >= min_width:
            records.append((width, int(perc_str), float(mse_str)))
    return records

def read_data(file_path):
    """Load a results file and return its parsed records as a DataFrame.

    Args:
        file_path: Path of the text file to read.

    Returns:
        A ``pandas.DataFrame`` with columns ``Width``, ``Percentage`` and
        ``MSE``, one row per record returned by ``parse_text``.
    """
    with open(file_path, 'r') as handle:
        records = parse_text(handle.read())
    column_names = ['Width', 'Percentage', 'MSE']
    return pd.DataFrame(records, columns=column_names)

def _fit_mse_curve(widths, mean_mse):
    """Fit ``mse_model`` to the per-width mean MSE of one group.

    Args:
        widths: 1-D float array of distinct width values.
        mean_mse: 1-D float array of mean MSE values, aligned with ``widths``.

    Returns:
        ``(popt, std_alpha, r_squared)`` on success, or ``None`` when the fit
        cannot be performed (too few points, or no convergence).
    """
    # The model has three free parameters (A, B, alpha); with fewer data
    # points curve_fit raises instead of fitting.
    if widths.size < 3:
        return None

    try:
        # NOTE(review): sigma is the mean MSE itself (relative weighting of
        # residuals), not the per-width standard deviation - confirm intent.
        popt, pcov = curve_fit(
            mse_model, widths, mean_mse,
            maxfev=10000, p0=[0.2, 2.0, 1.0], sigma=mean_mse,
        )
    except RuntimeError:
        # Raised by curve_fit when the least-squares minimisation fails.
        return None

    predicted = mse_model(widths, *popt)
    ssr = np.sum((mean_mse - predicted) ** 2)
    sst = np.sum((mean_mse - np.mean(mean_mse)) ** 2)
    # A constant series has zero total variance; R^2 is undefined there.
    r_squared = 1 - (ssr / sst) if sst > 0 else float('nan')

    std_alpha = np.sqrt(np.diag(pcov))[2]  # Standard deviation of alpha
    return popt, std_alpha, r_squared


def plot_data(df):
    """Plot mean MSE versus width for every percentage and overlay model fits.

    For each distinct ``Percentage`` value the mean and standard deviation of
    ``MSE`` are computed per ``Width`` and drawn as error bars. ``mse_model``
    is then fitted to the means and drawn as a smooth line whose legend entry
    shows the fitted alpha, its standard deviation and R^2. Groups that
    cannot be fitted keep their data points but get no fit line. The figure
    is saved to ``newresult_SingleMLP_Weather_width.png``.

    Args:
        df: DataFrame with ``Width``, ``Percentage`` and ``MSE`` columns.
    """
    norm = mcolors.LogNorm(vmin=df['Percentage'].min(), vmax=df['Percentage'].max())
    scalar_map = cm.ScalarMappable(norm=norm, cmap=cm.plasma)

    plt.figure(figsize=(12, 8))
    ax = plt.gca()

    for percentage, group in df.groupby('Percentage'):
        print(group)

        # groupby('Width') already yields ascending widths, so the group
        # itself does not need to be sorted (or mutated) beforehand.
        mse_by_width = group.groupby('Width')['MSE']
        mean_series = mse_by_width.mean()
        std_series = mse_by_width.std()

        # Plain float arrays keep the model and curve_fit away from
        # pandas Index arithmetic.
        widths = mean_series.index.to_numpy(dtype=float)
        mean_mse = mean_series.to_numpy(dtype=float)
        color = scalar_map.to_rgba(percentage)

        # NOTE(review): the label shows the smallest width of the group, not
        # the percentage, so every group gets a similar entry - confirm.
        plt.errorbar(
            widths, mean_mse, yerr=std_series.to_numpy(),
            label=f'Weather: {group["Width"].min()} to 192',
            fmt='o', capsize=5, capthick=2, color=color,
        )

        # Fit the model to the data
        fit = _fit_mse_curve(widths, mean_mse)
        if fit is None:
            print(f'Skipping fit for Percentage={percentage}: '
                  'too few widths or the fit did not converge')
            continue
        popt, std_alpha, R_squared = fit
        print(popt)
        alpha = popt[2]

        # Create a smooth line for the model
        smooth_data = np.linspace(group['Width'].min(), group['Width'].max(), 500)
        smooth_mse = mse_model(smooth_data, *popt)
        plt.plot(smooth_data, smooth_mse, color=color, label=f'Fit: α={alpha:.2f}±{std_alpha:.2f},R^2={R_squared:.3f}')

    plt.title('MSE by Percentage and Width')
    plt.xlabel('Width')
    plt.xscale('log')
    plt.ylabel('MSE')
    # The ScalarMappable is not attached to any Axes, so the Axes to take
    # space from must be given explicitly.
    cbar = plt.colorbar(scalar_map, ax=ax, label='Percentage of Data Used')
    tick_locs = np.unique(df['Percentage'])
    cbar.set_ticks(tick_locs)
    cbar.set_ticklabels(tick_locs)

    plt.legend()
    plt.grid(True)
    plt.savefig("newresult_SingleMLP_Weather_width.png")

if __name__ == "__main__":
    # Script entry point: parse the results log, then render and save the
    # MSE-versus-width figure.
    plot_data(read_data('newresult_SingleMLP_Weather_width.txt'))
