import numpy as np



from scipy.linalg import solve
from scipy.stats import multivariate_normal
from sklearn.metrics.pairwise import rbf_kernel


def log_likelihood_derivatives(Z_Samples, Gamma, Mu):
    """
    Gradients of the mean log-likelihood of a Gaussian mixture whose
    components all have identity covariance.

    The model is p(z) = sum_k Gamma[k] * N(z | Mu[k], I) and the objective is
    the average of log p(z_n) over the rows of Z_Samples.

    :param Z_Samples: array of shape (N, dz), one sample per row.
    :param Gamma: array of shape (K,), mixture weights.
    :param Mu: array of shape (K, dz), component means.
    :return: tuple (dL_dGamma, dL_dMu) with shapes (K,) and (K, dz).
        dL_dGamma[k] is mean_n(r[n, k] / Gamma[k]) (no projection onto the
        simplex is applied) and dL_dMu[k] is mean_n(r[n, k] * (z_n - Mu[k])),
        where r are the posterior responsibilities.
    """
    Z_Samples = np.asarray(Z_Samples, dtype=float)
    Gamma = np.asarray(Gamma, dtype=float)
    Mu = np.asarray(Mu, dtype=float)

    N, dz = Z_Samples.shape
    K = len(Gamma)
    identity_cov = np.identity(dz)

    # Work with log-densities: raw pdf values underflow to 0 for samples far
    # from every mean, which would turn the normalisation into 0 / 0 = NaN.
    log_pdf = np.empty((N, K))
    for k in range(K):
        log_pdf[:, k] = multivariate_normal.logpdf(
            Z_Samples, mean=Mu[k], cov=identity_cov
        )

    # Shift each row by its maximum so the largest entry is exp(0) = 1.
    # The shift cancels in every ratio formed below.
    scaled_pdf = np.exp(log_pdf - log_pdf.max(axis=1, keepdims=True))
    mixture = scaled_pdf @ Gamma  # (N,), proportional to p(z_n)

    # density_ratio[n, k] = N(z_n | Mu[k], I) / p(z_n) = r[n, k] / Gamma[k].
    # Using the ratio directly keeps the result finite when Gamma[k] == 0.
    density_ratio = scaled_pdf / mixture[:, np.newaxis]
    responsibilities = Gamma * density_ratio  # rows sum to one

    dL_dGamma = density_ratio.mean(axis=0)

    dL_dMu = np.empty((K, dz))
    for k in range(K):
        diff = Z_Samples - Mu[k]  # (N, dz)
        dL_dMu[k] = np.mean(responsibilities[:, k, np.newaxis] * diff, axis=0)

    return dL_dGamma, dL_dMu


def log_likelihood_derivatives_old(Z_Samples, Gamma, Mu, Sigma):
    """
    Gradients of the summed log-likelihood of a Gaussian mixture with
    diagonal covariances.

    The objective is sum_n log( sum_k Gamma[k] * N(z_n | Mu[k], diag(Sigma[k])) ).
    Unlike log_likelihood_derivatives_sigma, the result is summed (not
    averaged) over samples and the weight gradient is not centred.

    :param Z_Samples: array of shape (N, dz), one sample per row.
    :param Gamma: array of shape (K,), mixture weights.
    :param Mu: array of shape (K, dz), component means.
    :param Sigma: array of shape (K, dz), per-dimension variances (the
        diagonal of each component covariance).
    :return: tuple (dL_dGamma, dL_dMu, dL_dSigma) with shapes (K,), (K, dz)
        and (K, dz).
    """
    Z_Samples = np.asarray(Z_Samples, dtype=float)
    Gamma = np.asarray(Gamma, dtype=float)
    Mu = np.asarray(Mu, dtype=float)
    Sigma = np.asarray(Sigma, dtype=float)

    N, dz = Z_Samples.shape
    K = len(Gamma)

    # Log-densities avoid the 0 / 0 that raw pdf values produce when a
    # sample is far from every component.
    log_pdf = np.empty((N, K))
    for k in range(K):
        log_pdf[:, k] = multivariate_normal.logpdf(
            Z_Samples, mean=Mu[k], cov=np.diag(Sigma[k])
        )

    # Row-wise shift by the maximum; it cancels in the ratios below.
    scaled_pdf = np.exp(log_pdf - log_pdf.max(axis=1, keepdims=True))
    mixture = scaled_pdf @ Gamma  # (N,), proportional to p(z_n)

    # density_ratio[n, k] = N(z_n | component k) / p(z_n)
    density_ratio = scaled_pdf / mixture[:, np.newaxis]
    # The posterior responsibility must include the component likelihood,
    # not only Gamma[k] / p(z_n).
    responsibilities = Gamma * density_ratio

    dL_dGamma = density_ratio.sum(axis=0)
    dL_dMu = np.zeros((K, dz))
    dL_dSigma = np.zeros((K, dz))

    for k in range(K):
        diff = Z_Samples - Mu[k]  # (N, dz)
        resp_k = responsibilities[:, k, np.newaxis]  # (N, 1)
        precision = 1.0 / Sigma[k]  # inverse of the diagonal variances

        # d log N / d mu = diff / sigma
        dL_dMu[k] = np.sum(resp_k * diff * precision, axis=0)
        # d log N / d sigma = 0.5 * (diff^2 / sigma^2 - 1 / sigma)
        dL_dSigma[k] = 0.5 * np.sum(
            resp_k * (diff ** 2 * precision ** 2 - precision), axis=0
        )

    return dL_dGamma, dL_dMu, dL_dSigma



def log_likelihood_derivatives_sigma(Z_Samples, Gamma, Mu, Sigma):
    """
    Gradients of the mean log-likelihood of a Gaussian mixture with diagonal
    covariances.

    The objective is mean_n log( sum_k Gamma[k] * N(z_n | Mu[k], diag(Sigma[k])) ).

    :param Z_Samples: array of shape (N, dz), one sample per row.
    :param Gamma: array of shape (K,), mixture weights.
    :param Mu: array of shape (K, dz), component means.
    :param Sigma: array of shape (K, dz), per-dimension variances (the
        diagonal of each component covariance).
    :return: tuple (dL_dGamma, dL_dMu, dL_dSigma) with shapes (K,), (K, dz)
        and (K, dz). dL_dGamma has its mean subtracted so that its entries
        sum to zero; dL_dMu and dL_dSigma are the plain gradients.
    """
    Z_Samples = np.asarray(Z_Samples, dtype=float)
    Gamma = np.asarray(Gamma, dtype=float)
    Mu = np.asarray(Mu, dtype=float)
    Sigma = np.asarray(Sigma, dtype=float)

    N, dz = Z_Samples.shape
    K = len(Gamma)  # Number of components

    # Log-densities avoid the 0 / 0 that raw pdf values produce when a
    # sample is far from every component.
    log_pdf = np.empty((N, K))
    for k in range(K):
        log_pdf[:, k] = multivariate_normal.logpdf(
            Z_Samples, mean=Mu[k], cov=np.diag(Sigma[k])
        )

    # Row-wise shift by the maximum; it cancels in the ratios below.
    scaled_pdf = np.exp(log_pdf - log_pdf.max(axis=1, keepdims=True))
    mixture = scaled_pdf @ Gamma  # (N,), proportional to p(z_n)

    # density_ratio[n, k] = N(z_n | component k) / p(z_n) = r[n, k] / Gamma[k]
    density_ratio = scaled_pdf / mixture[:, np.newaxis]
    responsibilities = Gamma * density_ratio  # rows sum to one

    # Gradient w.r.t. the mixture weights
    dL_dGamma = density_ratio.mean(axis=0)

    dL_dMu = np.empty((K, dz))
    dL_dSigma = np.empty((K, dz))
    for k in range(K):
        diff = Z_Samples - Mu[k]  # (N, dz)
        resp_k = responsibilities[:, k, np.newaxis]  # (N, 1)
        Sigma_Inv = 1.0 / Sigma[k]  # inverse of the diagonal variances

        # Gradient w.r.t. the means: d log N / d mu = diff / sigma
        dL_dMu[k] = np.mean(resp_k * diff * Sigma_Inv, axis=0)

        # Gradient w.r.t. the variances:
        #   d log N / d sigma = 0.5 * (diff^2 / sigma^2 - 1 / sigma)
        # Both terms are weighted by the responsibility; leaving the
        # -1 / (2 sigma) term unweighted is only correct when K == 1.
        dL_dSigma[k] = 0.5 * np.mean(
            resp_k * (diff ** 2 * Sigma_Inv ** 2 - Sigma_Inv), axis=0
        )

    # Centre the weight gradient so the summed update is zero, which keeps
    # sum(Gamma) unchanged under an additive step.
    dL_dGamma -= np.mean(dL_dGamma)

    return dL_dGamma, dL_dMu, dL_dSigma


def sample_gmm_sigma(N, Gamma, Mu, Sigma):
    """
    Draw N samples from a Gaussian mixture with diagonal covariances.

    Uses NumPy's global random state: first the N component indices are
    drawn with probabilities Gamma, then one Gaussian draw is made per
    index, in order.

    :param N: number of samples to draw (may be 0).
    :param Gamma: array of shape (K,), mixture probabilities passed to
        np.random.choice as ``p``.
    :param Mu: ndarray of shape (K, dz), component means.
    :param Sigma: array of shape (K, dz), per-dimension variances (the
        diagonal of each component covariance).
    :return: ndarray of shape (N, dz).
    """
    # Number of components and sample dimension
    K, dz = Mu.shape

    # Sample component indices
    indices = np.random.choice(K, size=N, p=Gamma)

    # Pre-allocating fixes the output shape at (N, dz) even when N == 0,
    # where stacking an empty list would give shape (0,).
    samples = np.empty((N, dz))
    for row, component in enumerate(indices):
        samples[row] = np.random.multivariate_normal(
            Mu[component], np.diag(Sigma[component])
        )

    return samples



def objective_bounds_kernel(z, x, y, xbar, comp, gamma_kernel, lam_feasibility, lam_first, lam_second):
    """
    Min-max formulation: magnitude of the bound objective at xbar.

    With Kxx and Kzz the RBF Gram matrices of x and z, M = Kzz, and g the
    column vector returned by gradient_rbf_kernel(x, xbar, gamma_kernel, comp),
    this solves

        (Kxx M Kxx + 4 * lam_first * lam_second * Kxx) v = g

    and returns the 2-norm of  -2 / lam_feasibility * v^T g.

    :param z: samples used for the Kzz Gram matrix, one per row.
    :param x: samples used for the Kxx Gram matrix, one per row.
    :param y: not used by this function.
    :param xbar: single point at which the kernel derivative is evaluated.
    :param comp: index of the xbar component the derivative is taken in.
    :param gamma_kernel: RBF kernel coefficient passed to rbf_kernel.
    :param lam_feasibility: scalar dividing the objective.
    :param lam_first: regularisation factor (multiplied with lam_second).
    :param lam_second: regularisation factor (multiplied with lam_first).
    :return: the objective magnitude with singleton dimensions squeezed out.
    """
    gram_z = rbf_kernel(z, z, gamma=gamma_kernel)
    gram_x = rbf_kernel(x, x, gamma=gamma_kernel)

    # Kernel derivative at xbar as an (n, 1) column
    grad_col = gradient_rbf_kernel(x, xbar, gamma_kernel, comp)[..., np.newaxis]

    # The weighting matrix is taken to be Kzz itself.
    weight = gram_z
    system = gram_x @ weight @ gram_x + 4 * (lam_first * lam_second) * gram_x
    solution = np.linalg.solve(system, grad_col)

    objective = - 2 / lam_feasibility * solution.T @ grad_col

    return np.linalg.norm(objective, 2).squeeze()


# Minimum Norm Estimator KIV
def kiv(x, y, z, lam1, lam2, gamma):
    """
    Kernel Instrumental Variables coefficient estimate.

    With Kxx and Kzz the RBF Gram matrices of x and z and n = z.shape[0]:

        W     = Kxx (Kzz + n * lam1 * I)^{-1} Kzz
        alpha = (W W^T + n * lam2 * I)^{-1} W y

    :param x: samples used for the Kxx Gram matrix, one per row.
    :param y: right-hand side multiplied by W.
    :param z: samples used for the Kzz Gram matrix, one per row.
    :param lam1: regularisation of the first linear solve.
    :param lam2: regularisation of the second linear solve.
    :param gamma: RBF kernel coefficient passed to rbf_kernel.
    :return: alpha, the solution of the second linear system.
    """
    num_samples = z.shape[0]
    eye = np.eye(num_samples)

    gram_x = rbf_kernel(x, x, gamma=gamma)
    gram_z = rbf_kernel(z, z, gamma=gamma)

    # First solve: regularised system in the Kzz Gram matrix
    first_stage = solve(gram_z + num_samples * lam1 * eye, gram_z)
    W = gram_x @ first_stage

    # Second solve: regularised system in W W^T
    return solve(W @ W.T + num_samples * lam2 * eye, W @ y)


def kernel_minmax_gradient(xbar, x, y, z, gamma_kernel, comp, lam_feasibility, lam_first, lam_second, minimum=True):
    """
    Min-max formulation: coefficient vector for the bound at xbar.

    With Kxx and Kzz the RBF Gram matrices of x and z, M = Kzz, and g the
    column vector returned by gradient_rbf_kernel(x, xbar, gamma_kernel, comp),
    this returns

        pinv(Kxx M Kxx + 4 * lam_first * lam_second * Kxx)
            @ (Kxx M y + s * g / lam_feasibility)

    where s = +1 when ``minimum`` is true and -1 otherwise.

    :param xbar: single point at which the kernel derivative is evaluated.
    :param x: samples used for the Kxx Gram matrix, one per row.
    :param y: targets multiplied by Kxx M (presumably an (n, 1) column so it
        adds to g -- TODO confirm against callers).
    :param z: samples used for the Kzz Gram matrix, one per row.
    :param gamma_kernel: RBF kernel coefficient passed to rbf_kernel.
    :param comp: index of the xbar component the derivative is taken in.
    :param lam_feasibility: scalar dividing the derivative term.
    :param lam_first: regularisation factor (multiplied with lam_second).
    :param lam_second: regularisation factor (multiplied with lam_first).
    :param minimum: selects the sign of the derivative term.
    :return: theta_hat, the coefficient array.
    """
    gram_x = rbf_kernel(x, x, gamma=gamma_kernel)
    gram_z = rbf_kernel(z, z, gamma=gamma_kernel)

    # Kernel derivative at xbar as an (n, 1) column, with the requested sign
    grad_col = gradient_rbf_kernel(x, xbar, gamma_kernel, comp)[..., np.newaxis]
    if minimum:
        signed_grad = grad_col
    else:
        signed_grad = - grad_col

    # The weighting matrix is taken to be Kzz itself.
    weight = gram_z
    penalty = 4 * (lam_first * lam_second) * gram_x
    system_pinv = np.linalg.pinv(gram_x @ weight @ gram_x + penalty)

    rhs = gram_x @ weight @ y + 1 / lam_feasibility * signed_grad
    return system_pinv @ rhs


def gradient_rbf_kernel(x, xbar, gamma, comp):
    """
    Partial derivative of the RBF kernel with respect to one component of
    its second argument.

    For k(x_i, xbar) = exp(-gamma * ||x_i - xbar||^2) (the kernel computed by
    sklearn's rbf_kernel) this returns, for every row x_i of x,

        d k(x_i, xbar) / d xbar[comp] = 2 * gamma * (x_i[comp] - xbar[comp]) * k(x_i, xbar)

    :param x: numpy array of shape (n, dx), kernel centres, one per row.
    :param xbar: numpy array holding a single point; it is reshaped to (1, dx).
    :param gamma: scalar RBF kernel coefficient.
    :param comp: index of the xbar component to differentiate with respect to.
    :return: numpy array of shape (n,), one derivative value per row of x.
    """
    # rbf_kernel expects 2D inputs, so treat xbar as a single-row matrix
    point = xbar.reshape(1, -1)

    kernel_values = rbf_kernel(x, point, gamma=gamma)  # shape (n, 1)
    offsets = x - point  # shape (n, dx)

    return offsets[:, comp] * kernel_values.ravel() * gamma * 2
    
