

import sys
sys.path.insert(0, "../SingleSample")
import argparse
import os
import logging
import matplotlib.pyplot as plt
from datetime import datetime
import time

try:
    import cupy as cp
    try:
        if cp.cuda.runtime.getDeviceCount() > 0:
            np = cp
        else:
            raise ImportError("No CUDA-capable GPU found.")
    except cp.cuda.runtime.CUDARuntimeError as e:
        raise ImportError(f"CUDA runtime error: {e}")
except ImportError as e:
    print(f"Falling back to numpy due to error: {e}")
    import numpy as np


import json

from data_generation import *
from models import *
from plotting import plot_data_2d, generate_grid
from example_functions import *
from gradient_calculation import approximate_partial_derivative


#--------------------------------------------------------------------------
# Terminal Input
#--------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Use setting dictionary to run 2sls with finite basis function extension.')
parser.add_argument('--output_dir', type=str, default=None, help='Name of output directory.')
parser.add_argument("--experiment_name", type=str, default="ActiveLearning", help="Name of experiment.")
parser.add_argument("--seed", type=int, default=4, help="Seed for the random number generator.")
parser.add_argument('--n_seeds', type=int, default=25, help='Number of seeds.')

parser.add_argument('--dx', type=int, default=2, help='Number of treatment variables.')
parser.add_argument('--dz', type=int, default=2, help='Number of instrument variables.')

parser.add_argument("--strategy_type", type=str, default="single_exploration",
                    help="Sampling strategy: random_sampling, passive_sampling, explore_then_exploit, "
                         "continuous_exploration or adaptive_sampling.")
parser.add_argument('--n_exploration', type=int, default=250, help='Number of samples drawn per exploration experiment.')
parser.add_argument('--n_exploitation', type=int, default=250, help='Number of samples drawn per exploitation experiment.')
parser.add_argument('--T', type=int, default=16, help='Number of experiments.')
parser.add_argument('--T_exploration', type=int, default=10, help='Number of exploration experiments (rounds).')
parser.add_argument('--K_exploitation', type=int, default=250,
                    help='Number of explored samples closest to xbar used to fit the exploitation distribution.')
parser.add_argument("--sample_sigma", type=float, default=0.001,
                    help="Variance scale of the random instrument scenarios (covariance = sample_sigma * I).")
parser.add_argument("--gaussian_K", type=int, default=3, help="Number of components of the Gaussian mixture sampling policy.")

parser.add_argument("--lam_c", type=float, default=0.1, help="Regularization parameter for feasibility.")
parser.add_argument("--lam_first", type=float, default=0.01, help="Regularization parameter for smoothness.")
parser.add_argument("--lam_second", type=float, default=1.0, help="Regularization parameter for smoothness.")
parser.add_argument("--gamma_kernel", type=float, default=1.0, help="Parameter gamma of the RBF kernel.")
# ``type=list`` would split a CLI value into characters ("0.1" -> ['0', '.', '1']);
# accept one or more floats instead.
parser.add_argument("--step_size_list", type=float, nargs="+", default=[0.1],
                    help="Step sizes for the adaptive strategy (one learning run per value).")



# Strategies understood by ``main`` (values of ``--strategy_type``).
_STRATEGY_TYPES = (
    "random_sampling",
    "passive_sampling",
    "explore_then_exploit",
    "continuous_exploration",
    "adaptive_sampling",
)


def _gradient_bounds(xbar, X, Y, Z, gamma_kernel, comp, lam_c, lam_first, lam_second):
    """Return ``(grad_min, grad_max)``, the lower/upper gradient bounds at ``xbar``.

    Solves the minimisation and the maximisation problem with
    ``kernel_minmax_gradient`` on the sample ``(X, Y, Z)`` and maps both
    coefficient vectors through ``gradient_rbf_kernel(X, xbar, ...)``.
    """
    theta_min = kernel_minmax_gradient(xbar, X, Y, Z, gamma_kernel, comp, lam_c, lam_first, lam_second, minimum=True)
    theta_max = kernel_minmax_gradient(xbar, X, Y, Z, gamma_kernel, comp, lam_c, lam_first, lam_second, minimum=False)
    Kxxbar = gradient_rbf_kernel(X, xbar, gamma_kernel, comp)
    return Kxxbar @ theta_min, Kxxbar @ theta_max


def _sample_quality(x_samples, xbar, comp):
    """Return the Euclidean norm of the offsets ``(x_samples - xbar)[:, comp]``."""
    return np.linalg.norm((x_samples - xbar)[:, comp])


def _append_and_stack(pools, samples):
    """Append each array of ``samples`` to the matching list in ``pools``.

    Returns a tuple with the row-wise concatenation of every updated pool.
    """
    for pool, new in zip(pools, samples):
        pool.append(new)
    return tuple(np.concatenate(pool) for pool in pools)


def _closest_to(xbar, Z, X, Y, k):
    """Keep the ``k`` rows of ``(Z, X, Y)`` whose ``X`` row is Euclidean-closest to ``xbar``."""
    dist = np.linalg.norm(X - xbar, axis=1).squeeze()
    idx = np.argsort(dist)[:k]
    return Z[idx, :], X[idx, :], Y[idx]


def _diag_gaussian_fit(Z):
    """Return the column means of ``Z`` and a diagonal covariance built from its column variances."""
    return np.mean(Z, axis=0), np.diag(np.var(Z, axis=0))


def _create_batches(z, x, y, num_batches=10):
    """Randomly partition the rows of ``(z, x, y)`` into ``num_batches`` aligned batches.

    A single permutation is drawn; the original arrays are indexed with the
    shuffled index batches so rows of ``z``, ``x`` and ``y`` stay aligned and
    every row lands in exactly one batch.
    """
    indices = np.arange(z.shape[0])
    np.random.shuffle(indices)
    index_batches = np.array_split(indices, num_batches)
    return ([z[b] for b in index_batches],
            [x[b] for b in index_batches],
            [y[b] for b in index_batches])


def main(experiment_dir, args):
    """Run the gradient-bound sampling experiment and write its outputs.

    Parameters
    ----------
    experiment_dir : str
        Existing directory that receives ``logfile.log``, ``parameters.json``,
        ``results_<seed>.npy``, ``results_seed.npy`` and, when ``dx == 2``,
        the data plots.
    args : argparse.Namespace
        Parsed command-line options (see the module-level ``parser``). The
        object is used as given; ``sys.argv`` is not parsed again.

    Raises
    ------
    ValueError
        If ``args.strategy_type`` is unknown, or ``T``/``T_exploration`` do
        not allow the selected strategy to run.
    """
    # Log to the root logger, both into <experiment_dir>/logfile.log and to a stream.
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    file_handler = logging.FileHandler(os.path.join(experiment_dir, 'logfile.log'), delay=True)
    file_handler.setLevel(logging.INFO)
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    try:
        _run_experiment(experiment_dir, args, logger)
    finally:
        # Detach the handlers so repeated calls neither duplicate log lines
        # nor leak the log file handle.
        for handler in (file_handler, stream_handler):
            logger.removeHandler(handler)
            handler.close()


def _run_experiment(experiment_dir, args, logger):
    """Body of :func:`main`; expects ``logger`` to be configured already."""
    logger.info('Command line arguments: %s', args)

    T = args.T
    T_exploration = args.T_exploration
    strategy_type = args.strategy_type

    # Fail fast instead of crashing with a NameError after all the sampling work.
    if strategy_type not in _STRATEGY_TYPES:
        raise ValueError(f"Unknown strategy_type {strategy_type!r}; expected one of {_STRATEGY_TYPES}.")
    if T < 1:
        raise ValueError(f"T must be at least 1, got {T}.")
    if strategy_type == "explore_then_exploit" and (T < 2 or T_exploration < 1):
        raise ValueError("explore_then_exploit needs T >= 2 and T_exploration >= 1.")

    # Save the configuration actually used for this run.
    with open(os.path.join(experiment_dir, 'parameters.json'), 'w') as fh:
        json.dump(vars(args), fh, indent=4)

    seed = args.seed
    n_seeds = args.n_seeds
    n_exploration = args.n_exploration
    n_exploitation = args.n_exploitation
    K_exploitation = args.K_exploitation
    sigma = args.sample_sigma
    dz = args.dz
    dx = args.dx
    step_size_list = args.step_size_list
    lam_c = args.lam_c
    lam_first = args.lam_first
    gamma_kernel = args.gamma_kernel
    # NOTE(review): the second smoothness weight is tied to ``lam_c``, so the
    # ``--lam_second`` option is ignored. Kept as-is to preserve results.
    lam_second = lam_c

    np.random.seed(seed)

    n = 1000  # size of the confounded reference data set
    comp = 0  # index of the gradient component that is bounded

    #--------------------------------------------------------------------------
    # Experiment Details
    #--------------------------------------------------------------------------
    alpha = np.random.randn(dz, dx)
    # This draw is no longer used, but it advances the RNG and therefore fixes
    # all data generated below; keep it so results stay reproducible.
    np.random.randint(-10, 10, (dx, 1))

    def h(z, e_x):
        return h_nonlinear_multiplicative(z, e_x, interaction=False, coef=alpha)

    def f(x):
        return 20 * f_nonlinear(x, interaction=False)

    def grad_f_exact(x, component):
        return 20 * grad_f_nonlinear(x, component)

    #--------------------------------------------------------------------------
    # Confounded reference data (used for checks, xbar and plots only)
    #--------------------------------------------------------------------------
    z_data = np.random.uniform(-2, 2, (n, dz))
    x_data, y_data = generate_data(z_data, dx, h, f, do_confounding=True)

    assert x_data.shape == (n, dx), "Shape of x should be (n, dx)"
    assert y_data.shape == (n, 1), "Shape of y should be (n, 1)"

    # Point at which the gradient is bounded, shape (1, dx).
    xbar = x_data.mean(axis=0)[np.newaxis, :]

    res_seed = {"Grad_Min_Exact": [], "Grad_Max_Exact": [], "Time": []}

    # One RNG seed per repetition (for confidence bounds over seeds).
    list_seed = np.random.randint(0, 10000, n_seeds)

    # True value of the bounded gradient component (identical for every seed).
    grad_true = np.zeros((dx,))
    grad_true[comp] = grad_f_exact(xbar, comp).squeeze()

    if dx == 2:
        # The reference data does not depend on the seed, so plot it once.
        fig_2d, fig_surface = plot_data_2d(x_data, f, xbar=xbar.squeeze(), grad_true=grad_true.squeeze())
        fig_2d.savefig(os.path.join(experiment_dir, "data_2d.png"))
        plt.close(fig_2d)
        fig_surface.savefig(os.path.join(experiment_dir, "data_surface.png"))
        plt.close(fig_surface)

    #--------------------------------------------------------------------------
    # Optimization Routine
    #--------------------------------------------------------------------------
    for jter in range(n_seeds):
        logger.info(f"------------------------------------------------ Seed {jter} of {n_seeds} ------------------------------------------------")
        np.random.seed(list_seed[jter])

        X1, X2, _, f_grid = generate_grid(x_data, f)
        res = {"X1": X1, "X2": X2, "f_grid": f_grid, "xbar": xbar, "x_data": x_data}

        # Exploitation experiments accumulated over the runs of this seed.
        X_Exp_runs, Y_Exp_runs, Z_Exp_runs = [], [], []
        X_Samples = Y_Samples = Z_Samples = None

        if strategy_type in ("random_sampling", "passive_sampling"):
            Grad_Min_Exact = np.zeros((1, T))
            Grad_Max_Exact = np.zeros((1, T))
            Sample_Quality = np.zeros((1, T))
            Grad = np.zeros((1, T))
            Time = np.zeros((1, T))

            for run in range(T):
                logger.info(f"------------------------------------------------ Run {run} of {T} ------------------------------------------------")
                start = time.time()
                # Exploration: choose the instrument distribution of this experiment.
                if strategy_type == "random_sampling":
                    mu_z = np.random.randn(dz)       # random scenario
                    cov = np.eye(dz) * sigma
                else:
                    mu_z = np.full((dz,), -5.0)      # fixed (passive) scenario
                    cov = np.eye(dz) * 20
                z_samples = np.random.multivariate_normal(mu_z, cov, n_exploitation)
                x_samples, y_samples = generate_data(z_samples, dx, h, f)
                sample_quality = _sample_quality(x_samples, xbar, comp)

                # Exploitation: bound the gradient on everything collected so far.
                X_Samples, Y_Samples, Z_Samples = _append_and_stack(
                    (X_Exp_runs, Y_Exp_runs, Z_Exp_runs), (x_samples, y_samples, z_samples))
                grad_min, grad_max = _gradient_bounds(xbar, X_Samples, Y_Samples, Z_Samples,
                                                      gamma_kernel, comp, lam_c, lam_first, lam_second)

                Grad_Min_Exact[:, run] = grad_min.squeeze()
                Grad_Max_Exact[:, run] = grad_max.squeeze()
                Grad[:, run] = grad_true[comp]
                Sample_Quality[:, run] = sample_quality
                Time[:, run] = time.time() - start

        elif strategy_type == "explore_then_exploit":
            # Always leave at least one exploitation round (T >= 2 is validated above).
            n_explore = min(T_exploration, T - 1)
            n_runs = T - n_explore
            Grad_Min_Exact = np.zeros((1, n_runs))
            Grad_Max_Exact = np.zeros((1, n_runs))
            Sample_Quality = np.zeros((1, n_runs))
            Time = np.zeros((1, n_runs))
            Grad = np.zeros((1, n_runs))

            # Exploration: n_explore random scenarios. Their cost is charged
            # to the first exploitation round's time.
            start = time.time()
            Z_list, X_list, Y_list = [], [], []
            for _ in range(n_explore):
                mu_z = np.random.randn(dz)
                cov = np.eye(dz) * sigma
                z_candidate = np.random.multivariate_normal(mu_z, cov, n_exploration)
                x_candidate, y_candidate = generate_data(z_candidate, dx, h, f)
                Z_list.append(z_candidate)
                X_list.append(x_candidate)
                Y_list.append(y_candidate)

            # Fit a diagonal Gaussian to the instruments of the explored samples
            # closest to xbar. The explored set never changes afterwards, so
            # this selection is done once rather than in every round.
            Z_Exp, _, _ = _closest_to(xbar, np.concatenate(Z_list), np.concatenate(X_list),
                                      np.concatenate(Y_list), K_exploitation)
            mu_z_exp, cov_z_exp = _diag_gaussian_fit(Z_Exp)

            for run in range(n_runs):
                z_samples = np.random.multivariate_normal(mu_z_exp, cov_z_exp, n_exploitation)
                x_samples, y_samples = generate_data(z_samples, dx, h, f)
                sample_quality = _sample_quality(x_samples, xbar, comp)

                X_Samples, Y_Samples, Z_Samples = _append_and_stack(
                    (X_Exp_runs, Y_Exp_runs, Z_Exp_runs), (x_samples, y_samples, z_samples))
                grad_min, grad_max = _gradient_bounds(xbar, X_Samples, Y_Samples, Z_Samples,
                                                      gamma_kernel, comp, lam_c, lam_first, lam_second)

                Grad_Min_Exact[:, run] = grad_min.squeeze()
                Grad_Max_Exact[:, run] = grad_max.squeeze()
                Grad[:, run] = grad_true[comp]
                Sample_Quality[:, run] = sample_quality
                Time[:, run] = time.time() - start
                start = time.time()

        elif strategy_type == "continuous_exploration":
            # Every run performs one exploration and one exploitation experiment.
            n_runs = T // 2
            Grad_Min_Exact = np.zeros((1, n_runs))
            Grad_Max_Exact = np.zeros((1, n_runs))
            Sample_Quality = np.zeros((1, n_runs))
            Grad = np.zeros((1, n_runs))
            Time = np.zeros((1, n_runs))

            # All exploration data gathered so far (for this seed).
            Z_pool, X_pool, Y_pool = [], [], []

            for run in range(n_runs):
                logger.info(f"------------------------------------------------ Run {run} of {T} ------------------------------------------------")
                start = time.time()
                # Exploration: one new random scenario.
                mu_z = np.random.randn(dz)
                cov = np.eye(dz) * sigma
                z_candidate = np.random.multivariate_normal(mu_z, cov, n_exploration)
                x_candidate, y_candidate = generate_data(z_candidate, dx, h, f)
                Z_all, X_all, Y_all = _append_and_stack(
                    (Z_pool, X_pool, Y_pool), (z_candidate, x_candidate, y_candidate))

                # Exploitation: fit a diagonal Gaussian to the instruments of the
                # explored samples (over all rounds) closest to xbar and sample from it.
                Z_Exp, _, _ = _closest_to(xbar, Z_all, X_all, Y_all, K_exploitation)
                mu_z_exp, cov_z_exp = _diag_gaussian_fit(Z_Exp)
                z_samples = np.random.multivariate_normal(mu_z_exp, cov_z_exp, n_exploitation)
                x_samples, y_samples = generate_data(z_samples, dx, h, f)
                sample_quality = _sample_quality(x_samples, xbar, comp)

                X_Samples, Y_Samples, Z_Samples = _append_and_stack(
                    (X_Exp_runs, Y_Exp_runs, Z_Exp_runs), (x_samples, y_samples, z_samples))
                grad_min, grad_max = _gradient_bounds(xbar, X_Samples, Y_Samples, Z_Samples,
                                                      gamma_kernel, comp, lam_c, lam_first, lam_second)

                Grad_Min_Exact[:, run] = grad_min.squeeze()
                Grad_Max_Exact[:, run] = grad_max.squeeze()
                Grad[:, run] = grad_true[comp]
                Sample_Quality[:, run] = sample_quality
                Time[:, run] = time.time() - start

        elif strategy_type == "adaptive_sampling":
            ns = len(step_size_list)
            Grad_Min_Exact = np.zeros((ns, T))
            Grad_Max_Exact = np.zeros((ns, T))
            Sample_Quality = np.zeros((ns, T))
            Grad = np.zeros((ns, 1))
            Time = np.zeros((ns, T))
            K = args.gaussian_K
            num_batches = 10

            for i, step_size in enumerate(step_size_list):
                # Each step size is an independent learning run: fresh mixture
                # parameters (weights Gamma, means Mu, per-dimension Sigma) and a
                # fresh exploitation pool.
                Gamma = np.ones((K,)) / K
                Mu = np.random.randn(K, dz)
                Sigma = np.ones((K, dz))
                X_Exp_runs, Y_Exp_runs, Z_Exp_runs = [], [], []
                logger.info(f"--------- Step Size {step_size} ------------")

                for t in range(T):
                    start = time.time()
                    if t < T_exploration:
                        # ---- Learning step: update the mixture parameters ----
                        logger.info(f"--------- Learning {t} of {T_exploration} ------------")
                        z_candidate = sample_gmm_sigma(n_exploration, Gamma, Mu, Sigma)
                        x_candidate, y_candidate = generate_data(z_candidate, dx, h, f)

                        z_batch, x_batch, y_batch = _create_batches(z_candidate, x_candidate, y_candidate, num_batches)
                        obj = np.zeros(num_batches)
                        gamma_update = np.zeros((num_batches, K))
                        mean_update = np.zeros((num_batches, K, dz))
                        sigma_update = np.zeros((num_batches, K, dz))
                        for j in range(num_batches):
                            obj[j] = objective_bounds_kernel(z_batch[j], x_batch[j], y_batch[j], xbar, comp,
                                                             gamma_kernel, lam_c, lam_first, lam_second)
                            gamma_update[j, :], mean_update[j, :, :], sigma_update[j, :, :] = \
                                log_likelihood_derivatives_sigma(z_batch[j], Gamma, Mu, Sigma)
                            # Weight each batch's log-likelihood derivatives by its
                            # objective value (presumably a score-function gradient
                            # estimate -- confirm against log_likelihood_derivatives_sigma).
                            gamma_update[j, :] *= obj[j]
                            mean_update[j, :, :] *= obj[j]
                            sigma_update[j, :, :] *= obj[j]

                        Gamma = Gamma - step_size * np.mean(gamma_update, axis=0).squeeze()
                        # Repair negative weights before renormalising to a distribution.
                        if (Gamma > 0).sum() < K:
                            logger.info("Gamma should be positive")
                            Gamma[Gamma < 0] = 0.1
                        Gamma = Gamma / np.sum(Gamma)
                        Mu = Mu - step_size * np.mean(mean_update, axis=0).squeeze()
                        Sigma = Sigma - step_size * np.mean(sigma_update, axis=0).squeeze()

                        # Bounds computed on the current learning sample only.
                        grad_min, grad_max = _gradient_bounds(xbar, x_candidate, y_candidate, z_candidate,
                                                              gamma_kernel, comp, lam_c, lam_first, lam_second)
                    else:
                        # ---- Exploitation step: sample from the learned mixture ----
                        logger.info(f"--------- Exploit {t} of {T} ------------")
                        z_samples = sample_gmm_sigma(n_exploitation, Gamma, Mu, Sigma)
                        x_samples, y_samples = generate_data(z_samples, dx, h, f)
                        Sample_Quality[i, t] = _sample_quality(x_samples, xbar, comp)

                        X_Samples, Y_Samples, Z_Samples = _append_and_stack(
                            (X_Exp_runs, Y_Exp_runs, Z_Exp_runs), (x_samples, y_samples, z_samples))
                        grad_min, grad_max = _gradient_bounds(xbar, X_Samples, Y_Samples, Z_Samples,
                                                              gamma_kernel, comp, lam_c, lam_first, lam_second)

                    Grad_Min_Exact[i, t] = grad_min.squeeze()
                    Grad_Max_Exact[i, t] = grad_max.squeeze()
                    Time[i, t] = time.time() - start

                Grad[i, :] = grad_true[comp]

        res["Grad_Min_Exact"] = Grad_Min_Exact
        res["Grad_Max_Exact"] = Grad_Max_Exact
        res["Grad"] = Grad
        res["Z_Samples"] = Z_Samples
        res["X_Samples"] = X_Samples
        res["Y_Samples"] = Y_Samples
        res["Sample_Quality"] = Sample_Quality
        res["Time"] = Time

        # Per-seed results.
        np.save(os.path.join(experiment_dir, f"results_{jter}.npy"), res)

        # Aggregate over seeds; rewritten after every seed so results of the
        # seeds finished so far are on disk even if the run is interrupted.
        res_seed["Grad_Min_Exact"].append(Grad_Min_Exact)
        res_seed["Grad_Max_Exact"].append(Grad_Max_Exact)
        res_seed["Time"].append(Time)
        np.save(os.path.join(experiment_dir, "results_seed.npy"), res_seed)




if __name__ == "__main__":

    args = parser.parse_args()

    if args.output_dir is None:
        # No explicit directory given: create a timestamped experiment
        # directory below a default base directory. ``output_dir`` was
        # previously undefined on this path (NameError).
        output_dir = "results"
        now_str = datetime.now().strftime("%Y%m%d_%H%M")
        experiment_name = f"{args.experiment_name}_{args.strategy_type}_{now_str}"
        experiment_dir = os.path.join(output_dir, experiment_name)
    else:
        experiment_dir = args.output_dir

    os.makedirs(experiment_dir, exist_ok=True)
    main(experiment_dir, args)
            
        

    