import torch


def triangular_inverse(A: torch.Tensor, upper: bool = False):
    """Invert a triangular matrix by solving ``A X = I``.

    Args:
        A: Tensor whose last two dimensions hold a square triangular
            matrix. The identity is built from ``A.size(-1)`` with A's
            dtype and device.
        upper: Forwarded to ``torch.linalg.solve_triangular``; set to True
            when ``A`` is upper triangular.

    Returns:
        The solution ``X`` of ``A X = I`` as returned by
        ``torch.linalg.solve_triangular``.
    """
    n = A.size(-1)
    identity = torch.eye(n, device=A.device, dtype=A.dtype)
    return torch.linalg.solve_triangular(A, identity, upper=upper)


def bmv(A, x):
    """Batched matrix-vector product ``A x``.

    ``x`` gets a trailing singleton axis so it is treated as a column,
    is multiplied by ``A`` with the usual matmul batch broadcasting, and the
    singleton axis is removed again.
    """
    column = x.unsqueeze(-1)
    return torch.matmul(A, column).squeeze(-1)


def bop(x1, x2):
    """Batched outer product of two vectors.

    For inputs whose last axes have lengths ``i`` and ``j`` the result has
    trailing shape ``(i, j)``; leading axes are handled by the einsum
    ellipsis.
    """
    outer_spec = '...i, ...j -> ...ij'
    return torch.einsum(outer_spec, x1, x2)


def bip(x1, x2):
    """Batched inner product: elementwise product summed over the last axis."""
    elementwise = x1 * x2
    return elementwise.sum(dim=-1)


def bqp(A, x):
    """Batched quadratic form ``x^T A x``.

    The last axis of ``x`` is contracted against both matrix axes of ``A``;
    leading axes are handled by the einsum ellipsis.
    """
    quadratic_spec = '...i, ...ij, ...j -> ...'
    return torch.einsum(quadratic_spec, x, A, x)


def chol_bmv_solve(chol_f, v):
    """Solve a linear system for a vector right-hand side via a Cholesky factor.

    ``v`` is turned into a single-column matrix, handed to
    ``torch.cholesky_solve`` together with ``chol_f`` (used with that
    function's default ``upper=False``, i.e. as a lower factor), and the
    column axis is dropped from the solution.
    """
    rhs = v[..., None]
    solution = torch.cholesky_solve(rhs, chol_f)
    return solution.squeeze(-1)


def bmv_lr_pd_inv(D_diag, K, x):
    """Compute ``(D + K K^T)^{-1} x`` for diagonal ``D`` and low-rank ``K``.

    Uses the Woodbury identity

        (D + K K^T)^{-1} x = D^{-1} x - D^{-1} K (I + K^T D^{-1} K)^{-1} K^T D^{-1} x

    so only a matrix of size ``K.shape[-1]`` is factorised.

    Args:
        D_diag: Tensor holding the diagonal of ``D`` along its last axis;
            leading batch axes are supported.
        K: Low-rank factor; its last axis is the rank.
        x: Right-hand-side vector(s), last axis matching ``D_diag``.

    Returns:
        A tuple ``(y, triple_chol)`` where ``y`` is the solution and
        ``triple_chol`` is the Cholesky factor of ``I + K^T D^{-1} K``.
    """
    D_inv_diag = 1 / D_diag
    u = D_inv_diag * x

    # Scale the rows of K by D^{-1}. Adding the trailing axis keeps the
    # diagonal aligned with K's row axis even when D_diag carries batch axes.
    D_inv_K = D_inv_diag[..., None] * K

    triple_p = D_inv_K.mT @ K
    # Build the identity in the dtype/device of the matrix it is added to.
    eye = torch.eye(K.shape[-1], dtype=triple_p.dtype, device=triple_p.device)
    triple_chol = torch.linalg.cholesky(eye + triple_p)

    # Work with a single-column matrix and drop the column axis at the end.
    w = torch.cholesky_solve(K.mT @ u[..., None], triple_chol)
    v = (D_inv_K @ w).squeeze(-1)

    return u - v, triple_chol


def bmm_lr_pd_inv(D_diag, K, X, triple_chol=None):
    """Compute ``(D + K K^T)^{-1} X`` for diagonal ``D`` and low-rank ``K``.

    Matrix right-hand-side variant of the Woodbury solve

        (D + K K^T)^{-1} X = D^{-1} X - D^{-1} K (I + K^T D^{-1} K)^{-1} K^T D^{-1} X

    Args:
        D_diag: Tensor holding the diagonal of ``D`` along its last axis;
            leading batch axes are supported.
        K: Low-rank factor; its last axis is the rank.
        X: Right-hand-side matrix whose second-to-last axis matches
            ``D_diag``.
        triple_chol: Optional precomputed Cholesky factor of
            ``I + K^T D^{-1} K``. When given, it is reused instead of being
            recomputed.

    Returns:
        ``(Y, triple_chol)`` when ``triple_chol`` was not supplied, otherwise
        just ``Y``.
    """
    D_inv_diag = 1 / D_diag
    U = D_inv_diag[..., None] * X
    # D^{-1} K is needed both for the small system and for the correction
    # term; the trailing axis keeps it correct for batched D_diag.
    D_inv_K = D_inv_diag[..., None] * K

    computed_here = triple_chol is None
    if computed_here:
        triple_p = D_inv_K.mT @ K
        # Build the identity in the dtype/device of the matrix it is added to.
        eye = torch.eye(K.shape[-1], dtype=triple_p.dtype, device=triple_p.device)
        triple_chol = torch.linalg.cholesky(eye + triple_p)

    V = D_inv_K @ (torch.cholesky_solve(K.mT, triple_chol) @ U)
    Y = U - V

    # Preserve the original contract: the factor is only returned when this
    # call was the one that computed it.
    if computed_here:
        return Y, triple_chol
    return Y


def bmm_lr_pstruct_inv(struct_inv_bmm, K, X):
    """Compute ``(S + K K^T)^{-1} X`` given a solver for the structured part ``S``.

    Applies the Woodbury identity with ``S^{-1}`` supplied as a callable:

        (S + K K^T)^{-1} X = S^{-1} X - S^{-1} K (I + K^T S^{-1} K)^{-1} K^T S^{-1} X

    Args:
        struct_inv_bmm: Callable applying ``S^{-1}`` to a matrix. It is called
            as ``struct_inv_bmm(X)`` and must return a pair
            ``(S^{-1} X, aux)``; it is then called as
            ``struct_inv_bmm(K, aux)`` and must return ``S^{-1} K`` alone.
        K: Low-rank factor; its last axis is the rank.
        X: Right-hand-side matrix.

    Returns:
        ``(Y, triple_chol_f)`` where ``Y`` is the solution and
        ``triple_chol_f`` is the Cholesky factor of ``I + K^T S^{-1} K``.
    """
    U, triple_chol_p = struct_inv_bmm(X)
    struct_K = struct_inv_bmm(K, triple_chol_p)

    triple_f = K.mT @ struct_K
    # Build the identity in the dtype/device of the matrix it is added to.
    eye = torch.eye(K.shape[-1], dtype=triple_f.dtype, device=triple_f.device)
    triple_chol_f = torch.linalg.cholesky(eye + triple_f)

    # Parenthesised so the product is evaluated right-to-left; evaluating it
    # left-to-right would first form a full square matrix of the size of S.
    V = struct_K @ (torch.cholesky_solve(K.mT, triple_chol_f) @ U)
    Y = U - V

    return Y, triple_chol_f


def bmv_lr_pstruct_inv(struct_inv_bmm, K, x):
    """Vector right-hand-side wrapper around ``bmm_lr_pstruct_inv``.

    ``x`` is lifted to a single-column matrix, solved with
    ``bmm_lr_pstruct_inv``, and the column axis is removed from the result.

    Returns:
        ``(y, triple_chol)``: the solution vector and the Cholesky factor
        returned by ``bmm_lr_pstruct_inv``.
    """
    x_col = x.unsqueeze(-1)
    Y_col, chol = bmm_lr_pstruct_inv(struct_inv_bmm, K, x_col)
    return Y_col.squeeze(-1), chol


def hyperplane_projection(x, a, b):
    """Project ``x`` onto the hyperplane ``{z : <a, z> = b}``.

    The signed offset ``<a, x> - b`` is divided by ``<a, a>`` and that
    multiple of ``a`` is subtracted from ``x``. Inner products are taken over
    the last axis.
    """
    signed_offset = (a * x).sum(dim=-1) - b
    step = signed_offset / (a * a).sum(dim=-1)
    return x - step[..., None] * a


def _cg_safe_div(num, den):
    """Return ``num / den`` elementwise, with 0 wherever ``den`` is exactly 0."""
    nonzero = den != 0
    # Divide by a placeholder 1 in the masked-out entries so no NaN/inf is
    # ever produced, then zero those entries.
    safe_den = torch.where(nonzero, den, torch.ones_like(den))
    return torch.where(nonzero, num / safe_den, torch.zeros_like(num))


def cg_sampler(J, b, n_it=10):
    """Approximately solve ``J x = b`` with conjugate-gradient iterations.

    The iteration starts from ``x_0 = b`` and runs at most ``n_it`` steps,
    stopping early once the squared residual norm, averaged over all batch
    entries, falls below ``1e-3``.

    NOTE: despite the name, no random sample is drawn; only the CG solution
    iterate is computed and returned.

    Args:
        J: System matrix (last two axes), with optional leading batch axes.
            Presumably symmetric positive definite, as CG requires.
        b: Right-hand-side vector(s); also used as the initial guess.
        n_it: Maximum number of CG iterations. With ``n_it=0`` the initial
            guess ``b`` is returned.

    Returns:
        The final iterate ``x`` with the same shape as ``b``.
    """
    x = b
    r = b - (J @ x[..., None]).squeeze(-1)
    p = r
    rr = (r * r).sum(dim=-1)

    for _ in range(n_it):
        # One matrix-vector product per iteration, reused for both the step
        # length and the residual update.
        Jp = (J @ p[..., None]).squeeze(-1)
        # Guarded division: an already-converged entry (p == 0) takes a zero
        # step instead of producing 0/0 = NaN.
        gamma = _cg_safe_div(rr, (p * Jp).sum(dim=-1))

        x = x + gamma[..., None] * p
        r = r - gamma[..., None] * Jp

        rr_next = (r * r).sum(dim=-1)
        if torch.mean(rr_next) < 1e-3:
            break

        # New search direction, conjugate to the previous ones. It must be
        # carried into the next iteration for the method to be CG.
        beta = _cg_safe_div(rr_next, rr)
        p = r + beta[..., None] * p
        rr = rr_next

    return x


def invert_matrix(C, n_it=10):
    """
    Builds an inverse of C (with arbitrary batch dimensions) column by column.

    For every column e_i of the identity, cg_sampler(C, e_i, n_it=n_it) is
    called and its result is used as the i-th column of the output.

    Parameters:
    - C: Matrix to be inverted with shape (*batch_dims, N, N)
    - n_it: Iteration limit forwarded to cg_sampler for each column

    Returns:
    - C_inv: Tensor of shape (*batch_dims, N, N) whose columns are the
      cg_sampler results
    """
    *lead_shape, size, _ = C.shape
    basis = torch.eye(size, device=C.device, dtype=C.dtype).expand(*lead_shape, size, size)

    # One solve per identity column; each result has shape (*batch_dims, N).
    columns = [cg_sampler(C, basis[..., col], n_it=n_it) for col in range(size)]

    # Stacking on a new last axis places solve `col` in output column `col`.
    return torch.stack(columns, dim=-1)




