from typing import Optional

import cvxpy as cp
import numpy as np

from action_masking.util.sets import Zonotope

# from sets import Zonotope


def calc_safe_input_set(
    A: np.ndarray,
    B: np.ndarray,
    X: Zonotope,
    tU: Zonotope,
    W: Zonotope,
    UF: Zonotope,
    XS: Zonotope,
    mode: str = "support_max",
    mode_args: Optional[dict] = None,
) -> Optional[Zonotope]:
    """
    Calculates safe input set for a linear system with additive disturbance.

    The input set is searched as a zonotope with free center ``c_U`` and
    generators ``G_U`` obtained by scaling the generators of the template
    ``tU``. Two zonotope-containment conditions are imposed as linear
    constraints: the one-step reach set ``A X + B U + W`` must lie in ``XS``
    and the input set ``U`` must lie in ``UF``.

    Args:
        A (np.ndarray): system matrix
        B (np.ndarray): input matrix
        X (Zonotope): current (uncertain) state
        tU (Zonotope): template input set
        W (Zonotope): constant disturbance set
        UF (Zonotope): feasible (technically possible) input set
        XS (Zonotope): safe state set (assumed to be RCI)
        mode (str, optional): mode of the optimization problem.
            One of ["vol_max", "support_max", "box_max", "bound_min", "scale"].
            "vol_max" maximizes the geometric mean of the scaling factors.
            "support_max" maximizes the geometric mean of support functions in random directions.
            "box_max" maximizes the volume of an enclosed box.
            "bound_min" minimizes the squared distance of random points within the reach set to boundary points on XS.
            "scale" maximizes a single scalar scaling factor applied to all template generators.
            Defaults to "support_max".
        mode_args (dict, optional): additional arguments for the optimization problem.
            For mode="support_max" the key "n_dirs" gives the number of random directions to use.
            Defaults to (number of generators of tU) * (input dimension).
            For mode="bound_min" the key "n_points" gives the number of boundary points to use.
            Defaults to 10 * XS.d * XS.g.
            Defaults to None, which is treated as an empty dict.

    Returns:
        Zonotope: safe input set as zonotope, or None if the solver raised a
        ``cp.error.SolverError``.

    Raises:
        ValueError: if ``mode`` is not one of the supported modes, or if the
            solver terminates with a status other than ``cp.OPTIMAL``.
    """

    # Avoid a shared mutable default argument.
    mode_args = {} if mode_args is None else mode_args

    nx = A.shape[0]
    nu = B.shape[1]

    # Number of generators of each involved zonotope
    ng_X = X.G.shape[1]
    ng_U = tU.G.shape[1]
    ng_XS = XS.G.shape[1]
    ng_W = W.G.shape[1]
    ng_UF = UF.G.shape[1]

    # Containment multipliers for "reach set in XS"
    Gamma_R = cp.Variable((ng_XS, ng_X + ng_U + ng_W))
    beta_R = cp.Variable((ng_XS, 1))

    # Containment multipliers for "input set in UF"
    Gamma_U = cp.Variable((ng_UF, ng_U))
    beta_U = cp.Variable((ng_UF, 1))

    c_U = cp.Variable((nu, 1))  # center vector for template input set
    G_U = cp.Variable((nu, ng_U))  # generator matrix for input set

    # Template approach
    if mode == "scale":
        p = cp.Variable(nonneg=True)  # scalar scaling factor for template input set
        c0 = G_U == tU.G * p
    else:
        p = cp.Variable((ng_U, 1), nonneg=True)  # scaling factors for template input set
        c0 = G_U == tU.G @ cp.diag(p)

    # Reachset in safe set containment constraints

    # Constraint 1
    # G^R = G^{X^S} * Gamma_R, where G^R = [A*G^X, B*G^U, G^W]
    c1 = [
        XS.G @ Gamma_R[:, :ng_X] == A @ X.G,
        XS.G @ Gamma_R[:, ng_X : ng_X + ng_U] == B @ G_U,
        XS.G @ Gamma_R[:, ng_X + ng_U :] == W.G,
    ]

    # Constraint 2
    # c^{X^S} - c^R = G^{X^S} * beta_R, where c^R = A*c^X + B*c^U + c^W
    c2 = XS.c - (A @ X.c + B @ c_U + W.c) == XS.G @ beta_R

    # Constraint 3
    # ||[Gamma_R, beta_R]||_inf <= 1
    # convert to linear constraints:
    # Z \mathbb{1} <= 1 <=> row sums of Z <= 1 (element-wise)
    # vec([Gamma_R, beta_R]) <= vec(Z)
    # -vec([Gamma_R, beta_R]) <= vec(Z)
    #
    # The direct formulation below was tried and did not work:
    # c3 = cp.pnorm(cp.hstack([Gamma_R, beta_R]), "inf") <= 1
    # NOTE(review): presumably pnorm(..., "inf") is the entry-wise max norm and
    # not the induced (max absolute row sum) norm needed here -- confirm.

    Z = cp.Variable((ng_XS, ng_X + ng_U + ng_W + 1))
    c3 = cp.sum(Z, axis=1) <= 1
    c41 = cp.vec(cp.hstack([Gamma_R, beta_R])) <= cp.vec(Z)
    c42 = cp.vec(cp.hstack([-Gamma_R, -beta_R])) <= cp.vec(Z)

    # Safe input set in feasible input set containment constraints
    # Same logic as above

    c5 = G_U == UF.G @ Gamma_U

    c6 = UF.c - c_U == UF.G @ beta_U

    # The direct formulation below was tried and did not work (see note above):
    # c7 = cp.pnorm(cp.hstack([Gamma_U, beta_U]), "inf") <= 1

    Y = cp.Variable((ng_UF, ng_U + 1))
    c7 = cp.sum(Y, axis=1) <= 1
    c81 = cp.vec(cp.hstack([Gamma_U, beta_U])) <= cp.vec(Y)
    c82 = cp.vec(cp.hstack([-Gamma_U, -beta_U])) <= cp.vec(Y)

    constraints = [c0] + c1 + [c2, c3, c41, c42, c5, c6, c7, c81, c82]

    if mode == "vol_max":
        objective = cp.Maximize(cp.geo_mean(p))
    elif mode == "support_max":
        # maximize the geometric mean of support functions in random directions

        n_dirs = mode_args.get("n_dirs", ng_U * nu)
        dirs = np.random.randn(nu, n_dirs)
        # The random matrix is reshaped (not transposed) to one direction per row.
        aux_P = cp.Parameter((n_dirs, nu), value=dirs.reshape((n_dirs, nu)))
        # Row k is sum_i |d_k^T g_i| * p_i for template generators g_i.
        objective = cp.Maximize(cp.geo_mean(cp.sum((cp.abs(aux_P @ tU.G) @ p), axis=1)))

    elif mode == "box_max":
        # max volume of an axis-aligned box enclosed in the scaled template

        box_p = cp.Variable((nu, 1), nonneg=True)  # box edge scalings
        box_Gamma = cp.Variable((ng_U, nu))
        box_Y = cp.Variable(box_Gamma.shape)  # element-wise bound on |box_Gamma|

        c_box = [
            # row sums of |box_Gamma| are bounded by the template scaling factors
            cp.reshape(cp.sum(box_Y, axis=1), (ng_U, 1)) <= p,
            cp.diag(box_p) == tU.G @ box_Gamma,
            cp.vec(box_Gamma) <= cp.vec(box_Y),
            cp.vec(-box_Gamma) <= cp.vec(box_Y),
        ]

        constraints += c_box

        objective = cp.Maximize(cp.geo_mean(box_p))

    elif mode == "bound_min":
        n_points = mode_args.get("n_points", 10 * XS.d * XS.g)
        boundary_points = XS.get_boundary_points(n_points)

        # Create reference points matrix
        ref_pU = cp.Variable((nx, n_points))

        # Fixed random generator coefficients in [-1, 1], one column per
        # reference point, stacked as [X; U; W] coefficients.
        ref_Beta = np.random.uniform(-1, 1, size=(Gamma_R.shape[1], n_points))

        # reference points must be in reach set R

        for j in range(n_points):
            constraints.append(
                cp.reshape(ref_pU[:, j], (nx, 1))
                == A @ X.c
                + A @ cp.reshape(X.G @ ref_Beta[:ng_X, j], (nx, 1))
                + B @ c_U
                + B @ cp.reshape(G_U @ ref_Beta[ng_X : ng_X + ng_U, j], (nu, 1))
                + W.c
                # The disturbance coefficients have one entry per generator of
                # W, so the column vector must have ng_W rows (not the state
                # dimension) for W.G @ ... to be well-defined in general.
                + W.G @ cp.reshape(ref_Beta[ng_X + ng_U :, j], (ng_W, 1))
            )

        objective = cp.Minimize(cp.sum_squares(ref_pU - np.array(boundary_points).reshape((n_points, nx)).T))

    elif mode == "scale":
        # maximize a scalar scaling factor of the template input set
        objective = cp.Maximize(p)

    else:
        raise ValueError("Invalid mode specified.")

    problem = cp.Problem(objective, constraints)

    # Solve the problem
    try:
        problem.solve(solver=cp.CLARABEL)
    except cp.error.SolverError:
        print('Numerical instabilities in solver, switching to fail-safe planner')
        return None

    # Only an exactly optimal status is accepted; anything else is reported
    # to the caller as a failure.
    if problem.status != cp.OPTIMAL:
        raise ValueError("Infeasible, switching to fail-safe planner")

    return Zonotope(G_U.value, c_U.value)


if __name__ == "__main__":
    # Demo: compute the safe input set with every supported mode for a random
    # 2D safe set and plot the results on top of the safe set.

    from tictoc import tic, toc

    # 2D
    A = np.array([[1, 0], [0, 1]])
    B = np.array([[1, 0], [0, 1]])

    UF = 20 * Zonotope.from_unit_box(dim=2)
    W = 0 * Zonotope.from_unit_box(dim=2)
    XS = Zonotope.from_random(2, 3)

    # Alternative templates:
    # tU = Zonotope(np.array([[2, 0], [0, 2]]), np.array([0, 0]))
    tU = Zonotope.from_random(2, 50)
    # tU = Zonotope(np.linalg.pinv(B) @ XS.G)

    X = 0 * Zonotope.from_unit_box(dim=2)

    def evaluate_method(A, B, X, tU, W, UF, XS, mode: str, mode_args: Optional[dict] = None):
        """Time one call of calc_safe_input_set, print the result volume and return the set."""
        # Avoid a shared mutable default argument.
        mode_args = {} if mode_args is None else mode_args

        tic()
        safe_U = calc_safe_input_set(A, B, X, tU, W, UF, XS, mode, mode_args)
        toc()

        print(f"volume {mode}: {safe_U.volume}")

        return safe_U

    # (mode, mode_args, plot color)
    modes = [
        ["vol_max", {}, "b"],
        ["support_max", {"n_dirs": 50}, "g"],
        ["box_max", {}, "y"],
        # calc_safe_input_set reads the key "n_points" for this mode
        ["bound_min", {"n_points": 50}, "black"],
        ["scale", {}, "orange"],
    ]

    XS.plot(color="r", show=False)

    for i, (mode, mode_args, color) in enumerate(modes):

        safe_U = evaluate_method(A, B, X, tU, W, UF, XS, mode, mode_args)
        # Only the last plot call displays the figure.
        show = i == len(modes) - 1
        safe_U.plot(color=color, show=show)
