from __future__ import annotations

import itertools
import math
from itertools import combinations
from typing import List, Tuple, Union

import cvxpy as cp
import numpy as np
from numpy.linalg import det, matrix_rank, norm
from numpy.random import randn
from scipy.spatial import ConvexHull


class Zonotope:
    """Zonotope Z = {c + sum_i G[:, i] * alpha_i | alpha_i in [-1, 1]}.

    Attributes:
        d (int): dimension of the space (number of rows of G).
        g (int): number of generators (number of columns of G).
        boundary_points (list): cache filled by ``get_boundary_points``; each entry is
            a column vector of shape (d, 1).
    """

    def __init__(self, G: Union[np.ndarray, None], c: Union[np.ndarray, None] = None):
        """Create a zonotope from its generator matrix and its center.

        Args:
            G: generator matrix of shape (d, g). ``None`` stands for a zonotope without
                generators, i.e. the single point ``c`` (``c`` is then required).
            c: center of shape (d,), (d, 1) or (1, d). Defaults to the origin.

        Raises:
            ValueError: if both ``G`` and ``c`` are None.
        """
        if G is None:
            if c is None:
                raise ValueError("At least one of G and c must be given")
            # no generators: the zonotope degenerates to the single point c
            G = np.zeros((np.size(c), 0))

        if c is None:
            c = np.zeros((G.shape[0], 1))
        c = np.asarray(c)

        # normalize 1-D vectors and row vectors to a column vector
        # (reshaping to (c.shape[0], 1) would fail for a (1, d) row vector)
        if c.ndim != 2 or c.shape[1] != 1:
            c = c.reshape((-1, 1))

        assert len(G.shape) == 2, "G must be a matrix"
        assert len(c.shape) == 2 and c.shape[1] == 1, "c must be a vector"
        assert c.shape[0] == G.shape[0], "c and G must have matching dimensions"

        self.d, self.g = np.shape(G)  # dimension and number of generators

        self._c = c
        self._G = G

        # Cached values that are expensive to compute; the setters below reset them.
        self._volume = None
        self.boundary_points = []

    @property
    def ndim(self) -> int:
        """Dimension of the space the zonotope lives in."""
        return self.d

    @property
    def order(self) -> float:
        """Zonotope order: number of generators divided by the dimension."""
        return self.g / self.d

    @property
    def c(self) -> np.ndarray:
        """Center as a column vector of shape (d, 1)."""
        return self._c

    @c.setter
    def c(self, value: np.ndarray):
        assert value.ndim == 2 and value.shape[1] == 1, "c must be a column vector"
        self._c = value
        # boundary points are absolute coordinates, so they move with the center
        self.boundary_points = []

    @property
    def G(self) -> np.ndarray:
        """Generator matrix of shape (d, g)."""
        return self._G

    @G.setter
    def G(self, value: np.ndarray):
        assert value.ndim == 2, "G must be a matrix"
        self._G = value
        # keep the derived sizes and the caches consistent with the new generators
        self.d, self.g = value.shape
        self._volume = None
        self.boundary_points = []

    @classmethod
    def from_numpy(cls, a: np.ndarray) -> Zonotope:
        """Create a zonotope from an array whose column 0 is the center and columns 1: are G."""
        G = a[:, 1:]
        c = a[:, 0]
        return cls(G, c)

    @classmethod
    def from_unit_box(cls, dim: int, c: Union[np.ndarray, None] = None) -> Zonotope:
        """Create the box [-1, 1]^dim (identity generators), shifted by ``c`` if given."""
        return cls(np.eye(dim), c)

    @classmethod
    def from_random(cls, n_d: int, n_g: int, c: Union[np.ndarray, None] = None, distribution: str = None) -> Zonotope:
        """
        Create a zonotope with random generators.

        The strategy is to sample the generator "directions" as points on the unit
        sphere and then scale each of them by a random factor.
        Refer to CORA's zonotope.generateRandom().

        Args:
            n_d: dimension.
            n_g: number of generators.
            c: center; defaults to the origin.
            distribution: distribution of the scaling factors, one of "uniform"
                ([0, 1)), "normal", "exponential", "gamma". ``None`` means "uniform".

        Raises:
            ValueError: for an unknown ``distribution``.
        """
        if distribution is None:
            distribution = "uniform"

        # the scaling factors are drawn before the directions (keeps seeded runs reproducible)
        if distribution == 'uniform':
            scaling_factors = np.random.uniform(size=n_g)
        elif distribution == 'normal':
            scaling_factors = np.random.normal(0, 1, size=n_g)
        elif distribution == 'exponential':
            scaling_factors = np.random.exponential(1, size=n_g)
        elif distribution == 'gamma':
            scaling_factors = np.random.gamma(2, 1, size=n_g)
        else:
            raise ValueError("Invalid distribution")

        # Sample shape (n_g, n_d) and transpose, so that column i consumes the random
        # stream in the same order as sampling one generator after another.
        directions = np.random.normal(0, 1, size=(n_g, n_d)).T
        G = directions / np.linalg.norm(directions, axis=0) * scaling_factors

        return cls(G, c)

    def __add__(self, other: Zonotope) -> Zonotope:
        """Minkowski sum of two zonotopes (generators concatenated, centers added)."""
        return self.__class__(np.concatenate((self._G, other._G), axis=1), self._c + other._c)

    def __mul__(self, other: float) -> Zonotope:
        """Multiplication of a zonotope with a scalar."""
        return self.__class__(self._G * other, self._c * other)

    def __rmul__(self, other: float) -> Zonotope:
        """Multiplication of a scalar with a zonotope (same as ``self * other``)."""
        return self.__mul__(other)

    def map(self, other: np.ndarray) -> Zonotope:
        """Image of the zonotope under the linear map x -> other @ x."""
        return self.__class__(other @ self._G, other @ self._c)

    def to_vertices_2d(self, dimensions: Tuple[int, int] = (0, 1)) -> np.ndarray:
        """Vertices of the projection onto two dimensions, as a closed polygon.

        Non-zero generators are flipped into the upper half plane, sorted by angle
        and summed to trace one half of the boundary; the other half is the point
        reflection of the first one.

        Args:
            dimensions: the two dimensions to project onto.

        Returns:
            np.ndarray: array of shape (2, 2k + 1), where k is the number of non-zero
                projected generators; the first and last column coincide.
        """
        assert dimensions[0] < self.ndim and dimensions[1] < self.ndim, "dimensions out of range"

        # fancy indexing returns copies, so the in-place flips below leave self._G intact
        G = self._G[list(dimensions), :]
        c = self._c[list(dimensions), :]

        # remove zero generators
        G = G[:, np.sum(np.abs(G), axis=0) > 0]

        # half-widths of the enclosing box
        xmax = np.sum(np.abs(G[0, :]))
        ymax = np.sum(np.abs(G[1, :]))

        # flip generators so that all of them point up
        flip = G[1, :] < 0
        G[:, flip] = -G[:, flip]

        # sort generators by angle in [0, 2*pi)
        ang = np.arctan2(G[1, :], G[0, :])
        ang[ang < 0] += 2 * np.pi
        order = np.argsort(ang)

        # chain the (doubled) generators in angle order
        n = G.shape[1]
        points = np.zeros((2, n + 1))
        points[:, 1:] = np.cumsum(2 * G[:, order], axis=1)

        # align the chain with the enclosing box
        points[0, :] = points[0, :] + xmax - np.max(points[0, :])
        points[1, :] = points[1, :] - ymax

        # mirror the upper half to obtain the lower half
        lower = points[:, [n]] + points[:, [0]] - points[:, 1:]
        points = np.hstack((points, lower))

        # shift vertices by the center vector
        return points + c

    def volume_complexity(self) -> float:
        """
        Number of generator subsets that ``volume`` has to evaluate.

        Returns:
            float: binomial coefficient "g choose d".
        """
        return math.comb(self.g, self.d)

    def volume(self) -> float:
        """
        Exact volume: 2^d times the sum of |det| over all d-subsets of generators.

        Based on:
        [1] E. Gover and N. Krikorian, "Determinants and the volumes of parallelotopes and zonotopes",
        Linear Algebra and its Applications, vol. 433, no. 1, pp. 28-40, 2010.
        Specifically corollary 3.4, p.39.
        Note: It looks like they are doing the same in volesti.

        The result is cached until the generator matrix is replaced.
        """
        if self._volume is None:
            if self.g < self.d or matrix_rank(self._G) < self.d:
                # generators do not span the space: the zonotope is flat
                self._volume = 0.0
            else:
                total = sum(
                    np.absolute(det(self._G[:, comb])) for comb in combinations(range(self.g), self.d)
                )
                self._volume = 2**self.d * total

        return self._volume

    def volume_triangulation(self) -> float:
        """Volume of the convex hull of ``to_vertices`` (needs a full-dimensional zonotope)."""
        vertices = self.to_vertices()
        return ConvexHull(vertices.T).volume

    def volume_approx(self, method: str = "frob") -> float:
        """
        Approximate the volume of a zonotope by a norm of its generator matrix.

        Args:
            method (str, optional): Method to use.
                One of:
                "frob" (Frobenius norm),
                "1" (1-norm),
                "int" (interval norm),
                "inf" (infinity norm).
                Defaults to "frob".

        Returns:
            float: approximate volume

        Raises:
            NotImplementedError: for an unknown ``method``.
        """
        if method == "frob":  # square root of the sum of squared entries
            vol = np.linalg.norm(self._G, ord="fro")
        elif method == "1":  # maximum absolute column sum
            vol = np.linalg.norm(self._G, ord=1)
        elif method == "inf":  # maximum absolute row sum
            vol = np.linalg.norm(self._G, ord=np.inf)
        elif method == "int":  # sum of the absolute values of all entries
            vol = sum(np.linalg.norm(col, ord=1) for col in self._G.T)
        else:
            raise NotImplementedError("Invalid method")

        return vol

    def to_numpy(self) -> np.ndarray:
        """
        Create a numpy array containing the center and the generator matrix.
        Center: array[:, 0]
        Generator: array[:, 1:]
        """
        return np.concatenate((self._c, self._G), axis=1)

    def to_vertices(self) -> np.ndarray:
        """Vertices of the zonotope as the columns of a (d, k) array.

        Starts from the 2^m corners spanned by the first m = min(d, g) generators and
        adds the remaining generators one at a time, keeping only convex-hull vertices
        after each step. With g <= d no pruning happens. ConvexHull is likely to raise
        for degenerate (lower-dimensional) zonotopes.
        """
        G = self._G
        m = min(G.shape[0], G.shape[1])

        # corners of the parallelotope spanned by the first m generators
        signs = np.array(list(itertools.product([-1, 1], repeat=m)))
        V = self._c + G[:, :m] @ signs.T

        for gen in G[:, m:].T:
            # Minkowski sum with the segment [-gen, gen], then prune to hull vertices
            V = np.hstack([V + gen[:, np.newaxis], V - gen[:, np.newaxis]])
            V = V[:, ConvexHull(V.T).vertices]

        return V

    def get_boundary_points(self, n: Union[int, None] = None, force_new: bool = False) -> List[np.ndarray]:
        """
        Get boundary points of the zonotope in random directions.

        If no points are cached (or ``force_new`` is set), ``n`` new points are computed
        (default 10 * d * g), cached in ``self.boundary_points`` and all of them returned.
        Otherwise a random subset of ``n`` cached points is returned.

        Args:
            n (Union[int, None], optional): Number of points. If the number of available
                points is at most n, all points are returned. Defaults to None.
            force_new (bool, optional): Force computation of new points.

        Returns:
            List[np.ndarray]: boundary points, each a column vector of shape (d, 1).
        """
        if len(self.boundary_points) == 0 or force_new:
            n_new = 10 * self.d * self.g if n is None else n
            new_points = []
            for _ in range(n_new):
                direction = randn(self.d, 1)
                direction = direction / norm(direction)
                new_points.append(self.boundary_point(direction))
            self.boundary_points = new_points

        if n is None or n >= len(self.boundary_points):
            return self.boundary_points

        # np.random.choice needs a 1-D array, so sample indices instead of the points
        idx = np.random.choice(len(self.boundary_points), size=n, replace=False)
        return [self.boundary_points[i] for i in idx]

    def boundary_point(self, direction: np.ndarray, point: np.ndarray = None) -> np.ndarray:
        """
        Computes the boundary point of the zonotope in the given direction, starting from point.

        Solves the LP  max alpha  s.t.  point + alpha * direction = c + G @ gamma,
        -1 <= gamma <= 1, i.e. the farthest point of the zonotope on the line through
        ``point`` along ``direction``.

        Args:
            direction (np.ndarray): The direction column vector of shape (d, 1).
            point (np.ndarray, optional): The origin of the line, shape (d, 1).
                Defaults to the center if it is None.

        Returns:
            np.ndarray: The boundary point in the given direction, shape (d, 1).

        Raises:
            ValueError: if the solver does not report an optimal solution.
        """
        assert direction.ndim == 2 and direction.shape[1] == 1, "direction must be a column vector"

        point = self._c if point is None else point

        assert point.ndim == 2 and point.shape[1] == 1, "point must be a column vector"

        alpha = cp.Variable()
        gamma = cp.Variable((self.g, 1))

        constraints = [
            point + alpha * direction == self._c + self._G @ gamma,
            gamma <= 1,
            -gamma <= 1,
        ]
        # maximizing alpha walks along +direction as far as the zonotope allows
        problem = cp.Problem(cp.Maximize(alpha), constraints)
        problem.solve(solver=cp.CLARABEL)

        if problem.status not in [cp.OPTIMAL]:
            raise ValueError("Solver did not converge to an optimal solution.")

        return point + alpha.value * direction

    def contains_point(self, point: np.ndarray) -> bool:
        """True if the column vector ``point`` (shape (d, 1)) lies in the zonotope."""
        return self.zonotope_norm(point) <= 1

    def zonotope_norm(self, direction: np.ndarray) -> float:
        """Calculate the zonotope norm of a point.

        Solves  min w  s.t.  G @ gamma = direction - c,  -w <= gamma <= w,
        so the result is <= 1 exactly when ``direction`` (a point, not an offset)
        lies in the zonotope.
        Based on: Kulmburg, A., Althoff, M., (2021): "On the co-NP-Completeness
            of the Zonotope Containment Problem", Eq. (8).

        Args:
            direction (np.ndarray): column vector of shape (d, 1).

        Raises:
            ValueError: if the solver does not report an optimal solution.
        """
        assert (
            len(direction.shape) == 2 and direction.shape[1] == 1
        ), f"direction must be a vector. {direction} {direction.shape}"
        assert direction.shape[0] == self.d, "direction must have the same dimension as the zonotope"

        gamma = cp.Variable((self.g, 1))
        w = cp.Variable()

        constraints = [self._G @ gamma == direction - self._c, gamma <= w, -gamma <= w]

        problem = cp.Problem(cp.Minimize(w), constraints)
        problem.solve(solver=cp.CLARABEL)

        if problem.status not in [cp.OPTIMAL]:
            raise ValueError("Solver did not converge to an optimal solution.")

        return w.value


def zonotope_norm_batched(zonotopes: List[Zonotope], points: np.ndarray) -> np.ndarray:
    """
    Compute the zonotope norm of each point with respect to its own zonotope.

    For every i this is the smallest w_i such that
    points[i] = zonotopes[i].c + zonotopes[i].G @ gamma_i with -w_i <= gamma_i <= w_i,
    i.e. the quantity ``Zonotope.zonotope_norm`` computes for a single pair. The
    sub-problems are independent, so they are solved jointly as one LP that
    minimizes the sum of the individual norms.

    Args:
        zonotopes (List[Zonotope]): one zonotope per point. All must have the same
            dimension d; their numbers of generators may differ.
        points (np.ndarray): array of shape (n, d), one point per row.

    Returns:
        np.ndarray: array of shape (n,) with the norm of each (zonotope, point) pair.

    Raises:
        AssertionError: if the sizes of ``zonotopes`` and ``points`` do not match.
        ValueError: if the solver does not report an optimal solution.
    """
    n, d = points.shape
    assert len(zonotopes) == n, "Number of zonotopes must match the batch size"
    if n == 0:
        return np.zeros(0)
    assert all(z.d == d for z in zonotopes), "Each direction must have the same dimension as the zonotope"

    # one coefficient vector per zonotope, so the generator counts may differ
    gammas = [cp.Variable(z.g) for z in zonotopes]
    w = cp.Variable(n)

    constraints = []
    for i, (zono, gamma, point) in enumerate(zip(zonotopes, gammas, points)):
        constraints += [
            zono.G @ gamma == point - zono.c[:, 0],
            gamma <= w[i],
            -gamma <= w[i],
        ]

    problem = cp.Problem(cp.Minimize(cp.sum(w)), constraints)
    problem.solve(solver=cp.CLARABEL)

    if problem.status not in [cp.OPTIMAL]:
        raise ValueError("Solver did not converge to an optimal solution.")

    return w.value


def zonotope_contains_batch(zonotopes: List[Zonotope], points: np.ndarray) -> np.ndarray:
    """Check, pair by pair, whether ``points[i]`` lies inside ``zonotopes[i]``.

    Args:
        zonotopes (List[Zonotope]): one zonotope per point.
        points (np.ndarray): array of shape (n, d), one point per row.

    Returns:
        np.ndarray: boolean array, True where the zonotope norm of the point is at most 1.
    """
    norms = zonotope_norm_batched(zonotopes, points)
    inside = norms <= 1
    return inside


if __name__ == "__main__":
    zono = Zonotope(np.array([[1, 0.5, -0.5], [0, 0.8, 0.25]]))
    zono = 1.8 * zono.map(np.array([[0.5, 0], [0, -1]]))
    print(f"Exact volume: {zono.volume():.2f}")
    print(f"Approximate volume: {zono.volume_approx():.2f}")

    direction = np.array([[1], [-1]])
    print(f"Boundary point in direction {direction}:\n{zono.boundary_point(direction)}")

    # Zonotope defines no plot method; show the 2-D polygon vertices instead.
    print(f"Vertices of the 2-D projection:\n{zono.to_vertices_2d()}")
