import numpy as np
from itertools import chain, combinations
import scipy
import matplotlib.pyplot as plt
# import multiprocessing as mp
import concurrent.futures
import time
import multiprocessing

def calculate_category_averages(losses, categories):
    """Return the mean loss of every category.

    ``losses[i]`` is attributed to ``categories[i]``; the number of entries
    processed is ``len(losses)``.

    Returns:
        dict mapping each category to the arithmetic mean of its losses, with
        keys in order of first appearance.
    """
    grouped = {}
    for position in range(len(losses)):
        value = losses[position]
        grouped.setdefault(categories[position], []).append(value)

    return {label: sum(values) / len(values) for label, values in grouped.items()}

def calculate_category_averages_multiple(losses_trials, categories):
    """Aggregate per-category mean losses over several trials.

    Each element of ``losses_trials`` is one trial's sequence of losses and is
    reduced with ``calculate_category_averages``. The per-trial averages of a
    category are then summarised across trials.

    Returns:
        (category_averages, category_stddevs): two dicts keyed by category,
        holding ``np.mean`` and ``np.std`` of that category's per-trial
        averages.
    """
    per_category = {}
    for trial_losses in losses_trials:
        trial_averages = calculate_category_averages(trial_losses, categories)
        for label, value in trial_averages.items():
            per_category.setdefault(label, []).append(value)

    category_averages = {label: np.mean(values) for label, values in per_category.items()}
    category_stddevs = {label: np.std(values) for label, values in per_category.items()}
    return category_averages, category_stddevs

def calculate_fair_objective_multiple(losses_trials, categories):
    """Return, for each trial, the largest per-category average loss.

    Every trial in ``losses_trials`` is reduced with
    ``calculate_category_averages`` and the maximum of the resulting averages
    is kept.

    Returns:
        1-D numpy array with one worst-category average per trial.
    """
    worst_per_trial = [
        max(calculate_category_averages(trial_losses, categories).values())
        for trial_losses in losses_trials
    ]
    return np.array(worst_per_trial)

def parallelize_category_losses(losses_trials, categories):
    """Compute per-category loss statistics across trials using a thread pool.

    Each trial in ``losses_trials`` is reduced to per-category averages with
    ``calculate_category_averages`` (one task per trial). The per-trial
    averages of each category are then summarised across trials.

    The previous implementation unpacked each task result as an
    ``(averages, stddevs)`` pair although ``calculate_category_averages``
    returns a single dict, so it could not run successfully. The standard
    deviation reported here is taken across trials, matching
    ``calculate_category_averages_multiple``.

    Returns:
        (overall_averages, overall_stddevs): dicts keyed by category holding
        ``np.mean`` and ``np.std`` of the per-trial averages.
    """
    def one_trial(trial_losses):
        return calculate_category_averages(trial_losses, categories)

    # executor.map keeps the results in trial order, so the output does not
    # depend on thread scheduling.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        per_trial_averages = list(executor.map(one_trial, losses_trials))

    collected = {}
    for trial_averages in per_trial_averages:
        for category, average in trial_averages.items():
            collected.setdefault(category, []).append(average)

    overall_averages = {category: np.mean(values) for category, values in collected.items()}
    overall_stddevs = {category: np.std(values) for category, values in collected.items()}
    return overall_averages, overall_stddevs

def is_pos_def(x):
    """Return whether every eigenvalue of ``x`` (via ``np.linalg.eigvals``) is > 0."""
    eigenvalues = np.linalg.eigvals(x)
    strictly_positive = eigenvalues > 0
    return np.all(strictly_positive)

def generate_points(n, p, m):
    """Draw ``n`` points in ``p`` dimensions grouped around ``m`` random centres.

    Each centre is ``500.0 * uniform(size=p)``; every point is its centre plus
    ``10.0 * randn(p)``. The first ``m - 1`` groups hold ``int(n / m)`` points
    and the last group takes the remainder.

    Returns:
        numpy array of the generated points, stacked group after group.
    """
    per_group = int(n / m)
    points = []
    for group in range(m):
        centre = 500.0 * np.random.uniform(size=p)
        is_last = group == m - 1
        count = n - per_group * (m - 1) if is_last else per_group
        for _ in range(count):
            points.append(centre + 10.0 * np.random.randn(p))
    return np.array(points)

# def generate_skewed_cov_matrices(n_dims, n_points):
#     # Define the base covariance matrix
#     base_cov = np.random.rand(n_dims, n_dims)
#     base_cov = np.dot(base_cov, base_cov.T)

#     # Generate n_points skewed covariance matrices
#     cov_list = []
#     for i in range(n_points):
#         # Generate a random positive definite skewness matrix using Cholesky decomposition
#         L = np.linalg.cholesky(base_cov)
#         skewness = np.dot(L, np.random.normal(size=(n_dims, n_dims)))
#         skewness = np.dot(skewness, skewness.T)

#         # Add the skewness to the base covariance matrix
#         cov = base_cov + skewness
#         cov = skewness

#         # Append the resulting covariance matrix to the list
#         cov_list.append(cov)
        
#     return np.array(cov_list)

def generate_skewed_cov_matrices(c, n_dims, n_points):
    """Generate ``c * n_points`` perturbed covariance matrices.

    ``c`` base matrices ``B @ B.T`` are drawn first (``B`` uniform in [0, 1)).
    For every base matrix, ``n_points`` samples are produced as
    ``base + 0.1 * S @ S.T`` with ``S = chol(base) @ N`` and ``N`` standard
    normal.

    Args:
        c: number of base covariance matrices.
        n_dims: side length of every matrix.
        n_points: number of perturbed samples per base matrix.

    Returns:
        numpy array of shape ``(c * n_points, n_dims, n_dims)`` ordered base by
        base.

    Raises:
        numpy.linalg.LinAlgError: if a base matrix is not numerically positive
            definite (raised by ``np.linalg.cholesky``).
    """
    # All base matrices are drawn before any perturbation so the order of
    # random draws is unchanged.
    base_cov_list = []
    for _ in range(c):
        factor = np.random.rand(n_dims, n_dims)
        base_cov_list.append(np.dot(factor, factor.T))

    cov_list = []
    for base_cov in base_cov_list:
        # The Cholesky factor depends only on the base matrix: compute it once
        # per base instead of once per sample.
        chol = np.linalg.cholesky(base_cov)
        for _ in range(n_points):
            skewness = np.dot(chol, np.random.normal(size=(n_dims, n_dims)))
            skewness = np.dot(skewness, skewness.T)
            cov_list.append(base_cov + 0.1 * skewness)

    return np.array(cov_list)

def generate_diagonal_matrices(n, p, a=1e3, b=1e-1):
    """
    Generates n diagonal matrices of size p x p drawn from p distinct templates,
    such that the number of copies of each template is roughly n/p.

    Template ``i`` has ``a`` at diagonal position ``i`` and ``b`` at every other
    diagonal position. The first ``n % p`` templates receive one extra copy.

    NOTE: the previous implementation ignored the template index and always
    placed ``a`` at position 0, so all n matrices were identical despite the
    per-template copy counts.

    Args:
    - n (int): total number of matrices to generate
    - p (int): size of each matrix (and number of distinct templates)
    - a (float): the distinguished diagonal entry (default: 1e3)
    - b (float): the remaining diagonal entries (default: 1e-1)

    Returns:
    - list of numpy arrays: list of n diagonal matrices of size p x p
    """
    # compute the number of copies of each template; these sum to exactly n
    num_copies = np.full(p, n // p, dtype=int)
    num_copies[:n % p] += 1

    matrices = []
    for i in range(p):
        diagonal = np.full(p, b, dtype=float)
        diagonal[i] = a
        template = np.diag(diagonal)
        # independent copies so callers may modify one matrix safely
        matrices.extend(template.copy() for _ in range(num_copies[i]))

    return matrices


# from scipy.linalg import eigh, cholesky

# def generate_skewed_cov_matrices(n, m, min_eigval):
#     matrices = []
#     for i in range(n):
#         # Generate a random positive definite matrix of size m x m
#         A = np.random.rand(m, m)
#         A = np.dot(A, A.T)
        
#         # Compute the Cholesky decomposition of A
#         L = cholesky(A, lower=True)
        
#         # Generate a diagonal matrix with random positive entries
#         D = np.diag(np.random.rand(m))
        
#         # Add a small positive multiple of the identity matrix to A
#         A = A + min_eigval * np.eye(m)
        
#         # Compute the minimum generalized eigenvalue of the pair (A, L @ D @ L.T)
#         eigs, _ = eigh(L @ D @ L.T, A)
#         min_eigval_pair = np.min(eigs)
        
#         # If the minimum generalized eigenvalue is not small enough, add a small positive multiple of the identity matrix to A
#         while min_eigval_pair >= min_eigval:
#             A = A + min_eigval * np.eye(m)
#             eigs, _ = eigh(L @ D @ L.T, A)
#             min_eigval_pair = np.min(eigs)
        
#         # Add the matrix to the list of matrices
#         matrices.append(A)
    
#     return matrices

def dist(theta, phi, A):
    """Return the quadratic form ``(theta - phi) A (theta - phi)``."""
    offset = theta - phi
    weighted = np.matmul(offset, A)
    return np.matmul(weighted.T, offset)

def D(Theta, phi, A):
    """Find the centre in ``Theta`` with the smallest ``dist`` to ``phi`` under ``A``.

    Returns:
        (min_loss, idx): the smallest loss and the position of its centre in
        ``Theta`` (first minimum, as chosen by ``np.argmin``).
    """
    costs = np.array([dist(centre, phi, A) for centre in Theta])
    nearest = np.argmin(costs)
    return costs[nearest], nearest

def Potential(Theta, Phi, A_all):
    """Assign every row of ``Phi`` to its cheapest centre and total the losses.

    Row ``i`` is evaluated with its own matrix ``A_all[i]`` through ``D``.

    Returns:
        (total_loss, clusters): the sum of the minimal losses and a list with
        the chosen centre index for every row.
    """
    assignments = [D(Theta, Phi[row], A_all[row]) for row in range(Phi.shape[0])]
    total_loss = sum(cost for cost, _ in assignments)
    clusters = [nearest for _, nearest in assignments]
    return total_loss, clusters

def losses(Theta, Phi, A_all):
    """Return each row's minimal loss over the centres in ``Theta``.

    Row ``i`` of ``Phi`` is scored with ``A_all[i]`` via ``D``; only the loss
    (not the centre index) is kept.
    """
    row_count = Phi.shape[0]
    return np.array([D(Theta, Phi[row], A_all[row])[0] for row in range(row_count)])

def losses_across_trials(Theta_list, Phi, A_all):
    """Evaluate ``losses`` for every centre set in ``Theta_list`` on a thread pool.

    Returns:
        numpy array whose ``i``-th entry is ``losses(Theta_list[i], Phi, A_all)``.
    """
    def losses_for(Theta):
        return losses(Theta, Phi, A_all)

    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
        # map() preserves the order of Theta_list in its results
        per_trial = list(executor.map(losses_for, Theta_list))
    return np.array(per_trial)

def Loss_list(Theta, Phi, A_all):
    """Return an array with the minimal loss of every row of ``Phi``.

    For row ``i`` the loss is the first element returned by
    ``D(Theta, Phi[i], A_all[i])``.
    """
    per_row = []
    for row, phi in enumerate(Phi[:Phi.shape[0]]):
        best_loss, _ = D(Theta, phi, A_all[row])
        per_row.append(best_loss)
    return np.array(per_row)

def one_center_optimal(Phi, A_all):
    """Return the single centre minimising the summed quadratic loss.

    The centre solves ``(sum_i A_i) theta = sum_i A_i phi_i``; its loss is
    evaluated with ``Potential``.

    Args:
        Phi: array whose rows are the points (``Phi.shape[0]`` points).
        A_all: per-point matrices, indexable by the same row index.

    Returns:
        (opt_theta, loss): the optimal centre and the total loss at it.

    Raises:
        ValueError: if ``Phi`` has no rows. The previous code printed a message
            and then failed with an IndexError on ``Phi[0]``.
        numpy.linalg.LinAlgError: if the summed matrix is singular.
    """
    n = Phi.shape[0]
    if n == 0:
        raise ValueError("one_center_optimal requires at least one point")

    A_sum = 0
    A_phi_sum = 0
    for i in range(n):
        A_sum += A_all[i]
        A_phi_sum += np.matmul(A_all[i], Phi[i])

    # Solving the linear system avoids forming an explicit inverse.
    opt_theta = np.linalg.solve(A_sum, A_phi_sum)
    return opt_theta, Potential([opt_theta], Phi, A_all)[0]

def subsets_k(collection, k):
    """Yield the partitions produced by ``partition_k(collection, k, k)``."""
    yield from partition_k(collection, k, k)

def partition_k(collection, min, k):
    """Recursively yield partitions of ``collection`` as lists of lists.

    A one-element collection yields the single partition containing it.
    Otherwise the first element is combined with every partition of the
    remaining elements (generated with ``min - 1``):

    * partitions that already have more than ``k`` subsets are skipped;
    * if the partition has at least ``min`` subsets, the first element is
      inserted into each subset in turn;
    * if the partition has fewer than ``k`` subsets, the first element is also
      emitted as a new singleton subset placed in front.

    The parameter name ``min`` shadows the builtin; it is kept for interface
    compatibility.
    """
    if len(collection) == 1:
        yield [collection]
        return

    head, tail = collection[0], collection[1:]
    for partial in partition_k(tail, min - 1, k):
        group_count = len(partial)
        if group_count > k:
            continue
        # add `head` to each existing subset of the sub-partition
        if group_count >= min:
            for position, group in enumerate(partial):
                yield partial[:position] + [[head] + group] + partial[position + 1:]
        # open a new subset holding only `head`
        if group_count < k:
            yield [[head]] + partial

def optimal(Phi, A_all, m=2):
    """Brute-force the best clustering loss over the partitions from ``subsets_k``.

    Every partition of the row indices produced by ``subsets_k(indices, m)`` is
    scored by summing the optimal one-centre loss of each of its subsets; the
    smallest total is returned. The number of partitions grows exponentially
    with the number of points.

    Fixes relative to the previous version:
    * ``one_center_optimal`` returns ``(theta, loss)``; only the loss is summed
      (summing the tuples raised a TypeError).
    * the subsets actually present in a partition are iterated, instead of
      assuming exactly ``m`` of them.

    Args:
        Phi: array whose rows are the points.
        A_all: per-point matrices; converted with ``np.asarray`` so it can be
            indexed by a list of row indices.
        m: number of clusters passed to ``subsets_k``.

    Returns:
        The minimal total loss found (``np.inf`` if no partition is produced).
    """
    n = Phi.shape[0]
    A_all = np.asarray(A_all)
    min_loss = np.inf
    for split in subsets_k(list(range(n)), m):
        curr_loss = sum(
            one_center_optimal(Phi[subset], A_all[subset])[1] for subset in split
        )
        if curr_loss < min_loss:
            min_loss = curr_loss
    return min_loss

# def initialize(Phi, A_all, m):
#     n = Phi.shape[0]
#     i = np.random.randint(low=0, high=n)
#     Theta = [Phi[i]]
#     for t in range(m-1):
#         loss_and_idx = [D(Theta, Phi[i], A_all[i]) for i in range(n)]
#         prob = [loss_and_idx[i][0] for i in range(n)]
#         prob /= sum(prob)
#         i = np.random.choice([i for i in range(n)], p=prob)
#         Theta.append(Phi[i])
#     return Theta


def find_hypercube(points):
    """
    Describe the axis-aligned box spanned by the points.

    Returns ``(L, mins, maxs)``: the largest per-coordinate extent, and the
    per-coordinate minima and maxima taken over axis 0.
    """
    lower = np.min(points, axis=0)
    upper = np.max(points, axis=0)
    side = np.max(upper - lower)
    return side, lower, upper

def sample_points(points, m):
    """
    Draw m points uniformly from the cube anchored at the per-coordinate minima
    of ``points`` with side length ``L`` from ``find_hypercube``.
    """
    dimension = points.shape[1]
    side, lower, _ = find_hypercube(points)
    unit_draws = np.random.rand(m, dimension)
    return unit_draws * side + lower

# Optimized Min
def initialize(Phi, A_all, m, sizes=None):
    """Pick ``m`` rows of ``Phi`` as initial centres by loss-proportional sampling.

    The first centre is a uniformly random row, or, when ``sizes`` is given, a
    row drawn with probability proportional to ``1 / sizes``. Each further
    centre is drawn with probability proportional to every row's current
    minimal loss to the centres chosen so far (divided by ``sizes[i]`` when
    ``sizes`` is given); rows already chosen are redrawn.

    The first index is now recorded in ``idx_list`` in both branches, so the
    duplicate guard behaves the same with and without ``sizes``.

    Args:
        Phi: array whose rows are candidate centres.
        A_all: per-row matrices used by ``dist``.
        m: number of centres to return.
        sizes: optional array of per-row sizes (must support ``1 / sizes``).

    Returns:
        list of ``m`` rows of ``Phi``.
    """
    n = Phi.shape[0]
    weighted = sizes is not None

    if weighted:
        # probs = sizes / sum(sizes)
        probs = 1 / sizes
        first = np.random.choice(np.arange(n), size=1, p=probs / sum(probs), replace=False)[0]
    else:
        first = np.random.randint(low=0, high=n)
    Theta = [Phi[first]]
    idx_list = [first]

    def cost_to_latest(i):
        # loss of row i with respect to the most recently added centre
        cost = dist(Theta[-1], Phi[i], A_all[i])
        return cost / sizes[i] if weighted else cost

    prob = [cost_to_latest(i) for i in range(n)]
    for _ in range(m - 1):
        # observe losses: keep each row's best loss over all centres so far
        prob = [min(cost_to_latest(i), prob[i]) for i in range(n)]
        # normalise once per round; the retry loop reuses the same distribution
        selection_probs = prob / sum(prob)
        i = np.random.choice(n, p=selection_probs)  # select user
        while i in idx_list:
            i = np.random.choice(n, p=selection_probs)  # select user
        Theta.append(Phi[i])  # query preference and update
        idx_list.append(i)
    return Theta

# Greedy
def initialize_greedy(Phi, A_all, m, sizes=None):
    """Pick ``m`` initial centres, greedily adding the row with the largest loss.

    The first centre is a uniformly random row, or a row drawn with
    probability proportional to ``1 / sizes`` when ``sizes`` is given. Every
    further centre is the row whose current minimal loss to the chosen centres
    (divided by ``sizes[i]`` when ``sizes`` is given) is largest.

    Returns:
        list of ``m`` rows of ``Phi``.
    """
    n = Phi.shape[0]
    use_sizes = sizes is not None

    if use_sizes:
        inverse = 1 / sizes
        seed_row = np.random.choice(np.arange(n), size=1, p=inverse / sum(inverse), replace=False)[0]
    else:
        seed_row = np.random.randint(low=0, high=n)
    Theta = [Phi[seed_row]]

    def scaled_cost(user):
        # loss of `user` with respect to the most recently added centre
        cost = dist(Theta[-1], Phi[user], A_all[user])
        return cost / sizes[user] if use_sizes else cost

    prob = [scaled_cost(user) for user in range(n)]
    for _ in range(m - 1):
        # observe losses: best loss per row over all centres so far
        prob = [min(scaled_cost(user), prob[user]) for user in range(n)]
        worst_user = np.argmax(prob)
        Theta.append(Phi[worst_user])  # query preference and update
    return Theta

# Epsilon Greedy
def initialize_epsilon_greedy(Phi, A_all, m, sizes=None):
    """Pick ``m`` initial centres with a noise-perturbed greedy rule.

    The first centre is a uniformly random row, or a row drawn with
    probability proportional to ``1 / sizes`` when ``sizes`` is given. Every
    further centre is the argmax of the rows' current minimal losses (divided
    by ``sizes[i]`` when ``sizes`` is given) plus zero-mean Gaussian noise whose
    standard deviation is 0.1 times the largest of those losses.

    Returns:
        list of ``m`` rows of ``Phi``.
    """
    n = Phi.shape[0]
    use_sizes = sizes is not None

    if use_sizes:
        inverse = 1 / sizes
        seed_row = np.random.choice(np.arange(n), size=1, p=inverse / sum(inverse), replace=False)[0]
    else:
        seed_row = np.random.randint(low=0, high=n)
    Theta = [Phi[seed_row]]

    def scaled_cost(user):
        # loss of `user` with respect to the most recently added centre
        cost = dist(Theta[-1], Phi[user], A_all[user])
        return cost / sizes[user] if use_sizes else cost

    noise_scale = 0.1  # noise standard deviation as a fraction of the max loss
    prob = [scaled_cost(user) for user in range(n)]
    for _ in range(m - 1):
        # observe losses: best loss per row over all centres so far
        prob = np.array([min(scaled_cost(user), prob[user]) for user in range(n)])
        std_dev = noise_scale * np.max(prob)
        noise = np.random.normal(0, std_dev, prob.shape)
        picked = np.argmax(prob + noise)
        Theta.append(Phi[picked])  # query preference and update
    return Theta

# Random users (centres sampled without replacement)
def initialize_random_users(Phi, A_all, m, sizes=None):
    """Pick ``m`` distinct rows of ``Phi`` at random as initial centres.

    Rows are drawn without replacement: uniformly when ``sizes`` is None,
    otherwise with probability proportional to ``1 / sizes``. ``A_all`` is
    accepted for signature parity with the other initialisers and is not used.

    Returns:
        array ``Phi[idx]`` of the chosen rows.
    """
    candidates = np.arange(Phi.shape[0])
    if sizes is not None:
        inverse = 1 / sizes
        chosen = np.random.choice(candidates, size=m, p=inverse / sum(inverse), replace=False)
    else:
        chosen = np.random.choice(candidates, size=m, replace=False)
    return Phi[chosen]

# Function to call initialize in parallel
def initialize_multiple(num_trials, Phi, A_all, m, sizes=None):
    """Run ``initialize`` ``num_trials`` times on a thread pool.

    Returns:
        numpy array stacking the ``num_trials`` results.
    """
    def one_trial(_trial):
        return initialize(Phi, A_all, m, sizes)

    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
        trial_results = list(executor.map(one_trial, range(num_trials)))
    return np.array(trial_results)

# Function to call initialize_greedy in parallel
def initialize_multiple_greedy(num_trials, Phi, A_all, m, sizes=None):
    """Run ``initialize_greedy`` ``num_trials`` times on a thread pool.

    Returns:
        numpy array stacking the ``num_trials`` results.
    """
    def one_trial(_trial):
        return initialize_greedy(Phi, A_all, m, sizes)

    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
        trial_results = list(executor.map(one_trial, range(num_trials)))
    return np.array(trial_results)

# Function to call initialize_epsilon_greedy in parallel
def initialize_multiple_epsilon_greedy(num_trials, Phi, A_all, m, sizes=None):
    """Run ``initialize_epsilon_greedy`` ``num_trials`` times on a thread pool.

    Returns:
        numpy array stacking the ``num_trials`` results.
    """
    def one_trial(_trial):
        return initialize_epsilon_greedy(Phi, A_all, m, sizes)

    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
        trial_results = list(executor.map(one_trial, range(num_trials)))
    return np.array(trial_results)

def initialize_multiple_random_users(num_trials, Phi, A_all, m, sizes=None):
    """Run ``initialize_random_users`` ``num_trials`` times on a thread pool.

    Returns:
        numpy array stacking the ``num_trials`` results.
    """
    def one_trial(_trial):
        return initialize_random_users(Phi, A_all, m, sizes)

    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
        trial_results = list(executor.map(one_trial, range(num_trials)))
    return np.array(trial_results)

def update(Phi, A_all, clusters, m):
    """Recompute the optimal centre of each of the ``m`` clusters.

    Cluster ``k`` consists of the rows ``i`` with ``clusters[i] == k``; its new
    centre and loss come from ``one_center_optimal``. An empty cluster is
    re-seeded with a uniformly random row of ``Phi`` and contributes zero loss.
    Clusters are processed on a thread pool.

    Results are collected in cluster order, so row ``k`` of the returned array
    is the centre of cluster ``k``. The previous version gathered futures with
    ``as_completed``, which returns them in no particular order.

    Args:
        Phi: array whose rows are the points.
        A_all: per-row matrices, indexable by row.
        clusters: per-row cluster index.
        m: number of clusters.

    Returns:
        (Theta_new, new_loss): array of the ``m`` new centres and the sum of
        the per-cluster losses.
    """
    n = Phi.shape[0]

    def worker(k):
        members = [i for i in range(n) if clusters[i] == k]
        if not members:
            # empty cluster: restart it from a random point
            idx = np.random.choice(np.arange(n), size=1, replace=False)
            return Phi[idx].reshape(-1), 0
        Phi_k = np.array([Phi[i] for i in members])
        A_all_k = np.array([A_all[i] for i in members])
        return one_center_optimal(Phi_k, A_all_k)

    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:  # change max_workers as needed
        # map() yields results in the order of range(m)
        per_cluster = list(executor.map(worker, range(m)))

    Theta_new = []
    new_loss = 0
    for opt_theta_k, loss_k in per_cluster:
        Theta_new.append(opt_theta_k)
        new_loss += loss_k
    return np.array(Theta_new), new_loss

# def update(Phi, A_all, clusters, m):
#     Theta_new = []
#     new_loss = 0
#     n = Phi.shape[0]
#     for k in range(m):
#         Phi_k = np.array([Phi[i] for i in range(n) if k == clusters[i]])
#         A_all_k = np.array([A_all[i] for i in range(n) if k == clusters[i]])
#         if (len(Phi_k) == 0):
#             idx = np.random.choice(np.arange(n), size=1, replace=False)
#             # print(idx)
#             opt_theta_k = Phi[idx].reshape(-1)
#             loss_k = 0
#             # print("This hapenned")
#             # print(opt_theta_k.shape)
#         else:
#             opt_theta_k, loss_k = one_center_optimal(Phi_k, A_all_k)
#             # print(opt_theta_k.shape)
#         Theta_new.append(opt_theta_k)
#         new_loss += loss_k
#     return np.array(Theta_new), new_loss

def dynamics(Theta, Phi, A_all, m, max_iter=50):
    """Alternate assignment (``Potential``) and centre updates (``update``).

    One assign/update round is performed first, followed by ``max_iter``
    further rounds; the loss after every round is printed.

    Args:
        Theta: initial centres.
        Phi: array whose rows are the points.
        A_all: per-row matrices.
        m: number of clusters.
        max_iter: number of rounds after the initial one (default 50, the
            previously hard-coded value).

    Returns:
        (new_loss, initial_loss, Theta_new): loss reported by the last update,
        loss of the initial centres, and the final centres.
    """
    initial_loss, clusters = Potential(Theta, Phi, A_all)
    Theta_new, new_loss = update(Phi, A_all, clusters, m)
    print("t = ", 0, " Loss : ", new_loss)
    for t in range(1, max_iter + 1):
        # reassign points to the latest centres, then re-optimise the centres
        _, clusters = Potential(Theta_new, Phi, A_all)
        Theta_new, new_loss = update(Phi, A_all, clusters, m)
        print("t = ", t, " Loss : ", new_loss)
    return new_loss, initial_loss, Theta_new

def log_dynamics(Theta, Phi, A_all, m, max_iter=50):
    """Run the assign/update iteration and record the loss after every round.

    One assign/update round is performed first, followed by ``max_iter``
    further rounds.

    Args:
        Theta: initial centres.
        Phi: array whose rows are the points.
        A_all: per-row matrices.
        m: number of clusters.
        max_iter: number of rounds after the initial one (default 50, the
            previously hard-coded value).

    Returns:
        numpy array of length ``max_iter + 1`` with the loss reported by
        ``update`` after each round.
    """
    _, clusters = Potential(Theta, Phi, A_all)
    Theta_new, new_loss = update(Phi, A_all, clusters, m)
    all_loss = [new_loss]
    for _ in range(max_iter):
        # reassign points to the latest centres, then re-optimise the centres
        _, clusters = Potential(Theta_new, Phi, A_all)
        Theta_new, new_loss = update(Phi, A_all, clusters, m)
        all_loss.append(new_loss)
    return np.array(all_loss)
        
def init_and_dynamics_multiple(Phi, A_all, m, num_trial, sizes=None):
    """Run ``num_trial`` trials of ``initialize`` followed by ``dynamics``.

    Trials run on a thread pool and are gathered with ``as_completed``, so the
    order of trials in the outputs is not tied to submission order.

    Returns:
        (initial_losses, converged_losses, Theta_all) as numpy arrays with one
        entry per trial.
    """
    def run_trial():
        start = initialize(Phi, A_all, m, sizes)
        return dynamics(start, Phi, A_all, m)

    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:  # change max_workers as needed
        pending = [executor.submit(run_trial) for _ in range(num_trial)]

    # each outcome is (converged_loss, initial_loss, Theta)
    outcomes = [done.result() for done in concurrent.futures.as_completed(pending)]
    initial = np.array([outcome[1] for outcome in outcomes])
    converged = np.array([outcome[0] for outcome in outcomes])
    centres = np.array([outcome[2] for outcome in outcomes])
    return initial, converged, centres

def log_init_and_dynamics_multiple(Phi, A_all, m, num_trial, sizes=None):
    """Run ``num_trial`` trials of ``initialize`` followed by ``log_dynamics``.

    Trials run on a thread pool and are gathered with ``as_completed``, so the
    order of trials in the output is not tied to submission order.

    Returns:
        numpy array stacking each trial's loss trajectory.
    """
    def run_trial():
        start = initialize(Phi, A_all, m, sizes)
        return log_dynamics(start, Phi, A_all, m)

    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:  # change max_workers as needed
        pending = [executor.submit(run_trial) for _ in range(num_trial)]

    trajectories = [done.result() for done in concurrent.futures.as_completed(pending)]
    return np.array(trajectories)

def log_random_init_and_dynamics_multiple(Phi, A_all, m, num_trial, sizes=None):
    """Run ``num_trial`` trials of random seeding followed by ``log_dynamics``.

    Initial centres are ``m`` rows of ``Phi`` drawn with replacement, uniformly
    or with probability proportional to ``sizes`` when given. Trials run on a
    thread pool and are gathered with ``as_completed``.

    Returns:
        numpy array stacking each trial's loss trajectory.
    """
    def run_trial():
        rows = np.arange(Phi.shape[0])
        if sizes is not None:
            weights = sizes / sum(sizes)
            picked = np.random.choice(rows, size=m, p=weights, replace=True)
        else:
            picked = np.random.choice(rows, size=m, replace=True)
        return log_dynamics(Phi[picked], Phi, A_all, m)

    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:  # change max_workers as needed
        pending = [executor.submit(run_trial) for _ in range(num_trial)]

    trajectories = [done.result() for done in concurrent.futures.as_completed(pending)]
    return np.array(trajectories)

def random_init_and_dynamics_multiple(Phi, A_all, m, num_trial, sizes=None):
    """Run ``num_trial`` trials of random seeding followed by ``dynamics``.

    Initial centres are ``m`` rows of ``Phi`` drawn with replacement, uniformly
    or with probability proportional to ``sizes`` when given. Trials run on a
    thread pool and are gathered with ``as_completed``.

    Returns:
        (initial_losses, converged_losses, Theta_all) as numpy arrays with one
        entry per trial.
    """
    def run_trial():
        rows = np.arange(Phi.shape[0])
        if sizes is not None:
            weights = sizes / sum(sizes)
            picked = np.random.choice(rows, size=m, p=weights, replace=True)
        else:
            picked = np.random.choice(rows, size=m, replace=True)
        return dynamics(Phi[picked], Phi, A_all, m)

    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:  # change max_workers as needed
        pending = [executor.submit(run_trial) for _ in range(num_trial)]

    # each outcome is (converged_loss, initial_loss, Theta)
    outcomes = [done.result() for done in concurrent.futures.as_completed(pending)]
    initial = np.array([outcome[1] for outcome in outcomes])
    converged = np.array([outcome[0] for outcome in outcomes])
    centres = np.array([outcome[2] for outcome in outcomes])
    return initial, converged, centres

def random_hypercube_init_and_dynamics_multiple(Phi, A_all, m, num_trial):
    """Run ``num_trial`` trials seeded by ``sample_points`` followed by ``dynamics``.

    Trials run on a thread pool and are gathered with ``as_completed``.

    Returns:
        (initial_losses, converged_losses, Theta_all) as numpy arrays with one
        entry per trial.
    """
    def run_trial():
        start = sample_points(Phi, m)
        return dynamics(start, Phi, A_all, m)

    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:  # change max_workers as needed
        pending = [executor.submit(run_trial) for _ in range(num_trial)]

    # each outcome is (converged_loss, initial_loss, Theta)
    outcomes = [done.result() for done in concurrent.futures.as_completed(pending)]
    initial = np.array([outcome[1] for outcome in outcomes])
    converged = np.array([outcome[0] for outcome in outcomes])
    centres = np.array([outcome[2] for outcome in outcomes])
    return initial, converged, centres

def centroid_init_and_dynamics_multiple(Phi, A_all, m, num_trial, sizes=None):
    """Run ``num_trial`` trials seeded with identity-metric ``initialize``.

    Seeding calls ``initialize`` with identity matrices (scaled by
    ``sizes[i] / sum(sizes)`` when ``sizes`` is given) in place of ``A_all``;
    the subsequent ``dynamics`` uses the real ``A_all``. Trials run on a thread
    pool and are gathered with ``as_completed``.

    The seeding matrices are built once and shared read-only by all trials;
    previously ``sum(sizes)`` was re-evaluated for every point in every trial.

    Returns:
        (initial_losses, converged_losses, Theta_all) as numpy arrays with one
        entry per trial.
    """
    n = Phi.shape[0]
    p = Phi.shape[1]
    if sizes is None:
        seeding_metrics = [np.eye(p) for _ in range(n)]
    else:
        total_size = sum(sizes)
        seeding_metrics = [np.eye(p) * sizes[i] / total_size for i in range(n)]

    def init_and_dynamics():
        Theta = initialize(Phi, seeding_metrics, m, sizes)
        return dynamics(Theta, Phi, A_all, m)

    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:  # change max_workers as needed
        results = [executor.submit(init_and_dynamics) for _ in range(num_trial)]

    converged_losses = []
    initial_losses = []
    Theta_all = []
    for future in concurrent.futures.as_completed(results):
        converged_loss, initial_loss, Theta = future.result()
        initial_losses.append(initial_loss)
        converged_losses.append(converged_loss)
        Theta_all.append(Theta)
    return np.array(initial_losses), np.array(converged_losses), np.array(Theta_all)

def compute_individual_losses(Phi, A_all, Theta, num_trials, sizes):
    """Average each row's loss over trials and divide by ``sizes``.

    For trial ``i`` the per-row losses are ``Loss_list(Theta[i], Phi, A_all)``.

    Args:
        Phi: array whose rows are the points.
        A_all: per-row matrices.
        Theta: indexable collection with one centre set per trial.
        num_trials: number of trials to read from ``Theta`` (must be >= 1).
        sizes: divisor applied element-wise to the mean losses.

    Returns:
        array of mean per-row losses divided by ``sizes``.

    Raises:
        ValueError: if ``num_trials`` is less than 1.
    """
    if num_trials < 1:
        raise ValueError("num_trials must be at least 1")
    # Stack once instead of growing the matrix with repeated concatenation.
    per_trial = np.stack(
        [Loss_list(Theta[i], Phi, A_all).reshape(-1) for i in range(num_trials)]
    )
    return np.mean(per_trial, axis=0) / sizes

# compare_and_plot(Phi, A_all, init_Theta, random_Theta, num_trials)
def plot_quantiles(arrays, n, num_centers, legend_names, file_name):
    """Bar-plot percentage improvement of ``arrays[0]`` over the other arrays, per quantile.

    Each array is reduced to ``n`` quantiles (levels ``1/n .. 1``). For every
    array after the first, ``100 * (1 - q_0 / q_i)`` is plotted as one bar
    series. The figure is saved as ``quantile_plot_<file_name>_<num_centers>.png``,
    shown and closed.
    """
    # quantile levels 1/n, 2/n, ..., 1
    levels = np.linspace(0, 1, n + 1)[1:]
    ordered = [np.sort(values) for values in arrays]
    quantile_groups = [np.quantile(values, levels) for values in ordered]

    series_count = len(arrays)
    baseline = quantile_groups[0]
    for k in range(1, series_count):
        quantile_groups[k] = 100 * (1 - baseline / quantile_groups[k])

    width = 1.0 / (series_count + 1)
    positions = np.arange(n)
    for k in range(1, series_count):
        plt.bar(positions + k * width, quantile_groups[k], width=width, align='center')

    # add axis labels and legend
    plt.xlabel('Quantiles (Increasing Risk)')
    plt.ylabel('Risk Improvement Percentage')
    plt.legend(legend_names)
    plt.title("{}_centers".format(num_centers))

    plt.savefig('quantile_plot_{}_{}.png'.format(file_name, num_centers))
    # show the plot
    plt.show()
    plt.close()

import math

def print_array_stats(arr):
    """
    Print the minimum of ``arr`` and its mean +- standard deviation (numpy).
    """
    smallest = min(arr)
    print("Minimum :", smallest)

    mean_value = np.mean(arr)
    spread = np.std(arr)
    print(f"Mean: {mean_value:.2e} +- {spread:.2e}")


def plot_sorted_arrays(arr1, arr2, arr3, arr4, arr5, m):
    """Grouped bar chart of five risk arrays, all ordered by ascending ``arr1``.

    The figure is saved as ``wealth_redistribution_<m>.png``, shown and closed.
    """
    # permutation that sorts arr1; applied to every array
    order = np.argsort(arr1)
    ranked = [values[order] for values in (arr1, arr2, arr3, arr4, arr5)]

    # width of each bar; series k is shifted right by k bar widths
    bar_width = 0.15
    styles = [
        ('Our Initialization', 'red', 1.0),
        ('Uniform Random', 'green', 0.8),
        ('Weighted Random', 'blue', 0.5),
        ('Kmeans++', 'purple', 0.5),
        ('Weighted Kmeans++', 'orange', 0.5),
    ]
    for slot, (values, (label, color, alpha)) in enumerate(zip(ranked, styles)):
        positions = np.arange(len(values))
        if slot:
            positions = positions + slot * bar_width
        plt.bar(positions, values, width=bar_width, label=label, color=color, alpha=alpha)

    plt.xlabel("Subpopulations sorted by Increasing risk")
    plt.ylabel("Subpopulation Risk")
    plt.title("Risk Profiles for {} centers".format(m))
    plt.legend()
    plt.savefig('wealth_redistribution_{}.png'.format(m))
    plt.show()
    plt.close()


# def plot_sorted_arrays(arr1, arr2, arr3, arr4, arr5, m):
#     # sort the arrays in ascending order
#     sorted_arr1 = np.sort(arr1)
#     sorted_arr2 = np.sort(arr2)
#     sorted_arr3 = np.sort(arr3)
#     sorted_arr4 = np.sort(arr4)
#     sorted_arr5 = np.sort(arr5)

#     # set the width of each bar
#     bar_width = 0.15

#     # set the x coordinates for each set of bars
#     x_coords1 = np.arange(len(sorted_arr1))
#     x_coords2 = np.arange(len(sorted_arr2)) + bar_width
#     x_coords3 = np.arange(len(sorted_arr3)) + 2*bar_width
#     x_coords4 = np.arange(len(sorted_arr4)) + 3*bar_width
#     x_coords5 = np.arange(len(sorted_arr5)) + 4*bar_width

#     # create a bar chart of the sorted arrays
#     plt.bar(x_coords1, sorted_arr1, width=bar_width, label='Our Initialization', color='red', alpha=1.0)
#     plt.bar(x_coords2, sorted_arr2, width=bar_width, label='Uniform Random', color='green', alpha=0.8)
#     plt.bar(x_coords3, sorted_arr3, width=bar_width, label='Weighted Random', color='blue', alpha=0.5)
#     plt.bar(x_coords4, sorted_arr4, width=bar_width, label='Kmeans++', color='purple', alpha=0.5)
#     plt.bar(x_coords5, sorted_arr5, width=bar_width, label='Weighted Kmeans++', color='orange', alpha=0.5)
#     plt.xlabel("Subpopulations sorted by Increasing risk")
#     plt.ylabel("Subpopulation Risk")
#     plt.title("Risk Profiles for {} centers".format(m))
#     plt.legend()
#     plt.savefig('wealth_redistribution_{}.png'.format(m))
#     plt.show()
#     plt.close()

def plot_sorted_arrays_shaded(arr1, arr2, arr3, arr4, arr5, m):
    """Filled-area plot of five risk arrays, each sorted independently.

    The figure is saved as ``wealth_redistribution_<m>_shaded.png``, shown and
    closed.
    """
    # each array is sorted on its own, in ascending order
    ours, uniform, weighted, kmeanspp, weighted_kmeanspp = (
        np.sort(values) for values in (arr1, arr2, arr3, arr4, arr5)
    )

    fig, ax = plt.subplots()

    # drawing order determines stacking and legend order
    layers = [
        (uniform, 'Uniform Random'),
        (weighted, 'Weighted Random'),
        (weighted_kmeanspp, 'Weighted Kmeans++'),
        (kmeanspp, 'Kmeans++'),
        (ours, 'Our Initialization'),
    ]
    for values, label in layers:
        ax.fill_between(np.arange(len(values)), values, alpha=0.3, label=label)

    # axis labels, title and legend
    ax.set_xlabel("Subpopulations sorted by Increasing risk")
    ax.set_ylabel("Subpopulation Risk")
    ax.set_title("Risk Profiles for {} centers".format(m))
    ax.legend()

    # save the plot to a file and show it
    plt.savefig('wealth_redistribution_{}_shaded.png'.format(m))
    plt.show()
    plt.close()


def plot_means_with_std(init_mean, centroid_init_mean, random_mean, init_std, centroid_init_std, random_std, X, fig_name):
    """Plot three mean curves against ``X`` with error bars and save the figure.

    The curves are labelled 'Our Initialization', 'Kmeans++ Initialization'
    and 'Random Initialization'. The figure is written to ``fig_name`` and
    then shown.
    """
    fig, ax = plt.subplots()

    curves = [
        (init_mean, init_std, 'Our Initialization'),
        (centroid_init_mean, centroid_init_std, 'Kmeans++ Initialization'),
        (random_mean, random_std, 'Random Initialization'),
    ]
    for mean_values, std_values, label in curves:
        ax.errorbar(X, mean_values, yerr=std_values, label=label, fmt='o-', capsize=3)

    # axis labels, legend and one tick per X value
    ax.set_xlabel('Number of Centers')
    ax.set_ylabel('Average Risk upon Convergence')
    ax.legend()
    ax.set_xticks(X)

    fig.savefig(fig_name)
    # Show the plot
    plt.show()