from typing import Callable

import numpy as np
import scipy
from pysimplicialcubature.simplicialcubature import integrateOnSimplex

from action_masking.rlsampling.sets.zonotope import Zonotope

Function = Callable[[np.ndarray], np.ndarray]


def _integration_failed(integral) -> bool:
    """Return True when ``integral`` is ``None`` or contains a ``None`` entry."""
    if integral is None:
        return True
    # Check identity entry by entry rather than ``integral == np.array([None])``:
    # the truth value of that elementwise comparison is only defined for
    # single-element results.
    return any(entry is None for entry in np.atleast_1d(integral).ravel())


def geometric_integration(zonotope: Zonotope, function: Function, abs_error: float = 0.001) -> float:
    """Integrate ``function`` over ``zonotope`` with adaptive simplicial cubature.

    The vertices of the zonotope are triangulated with a Delaunay triangulation and the
    resulting simplices are handed to ``integrateOnSimplex``. Whenever the reported
    integral is ``None``, the evaluation budget is doubled and the integration is retried
    with otherwise identical settings.

    Args:
        zonotope: Set to integrate over. Its ``to_vertices()`` result is transposed before
            triangulation (assumes one vertex per column -- TODO confirm), and its ``g``
            attribute determines the initial evaluation budget.
        function: Integrand passed unchanged to ``integrateOnSimplex``.
        abs_error: Requested absolute error, forwarded as ``absError``.

    Returns:
        The ``"integral"`` entry of the cubature result, returned as is (not clipped).
    """
    # TODO: Some estimates are over 1 or less than 0. For those the estAbsError can be very high.
    # Maybe it's a good idea to compute the taylor approximation of the gauss dist.

    vertices = zonotope.to_vertices()
    triangles = scipy.spatial.Delaunay(vertices.T)
    # Vertex coordinates of every simplex; identical for all attempts, so build it once.
    simplices = triangles.points[triangles.simplices]

    # A result of manual trial and error (if the integrateOnSimplex function fails often, increase the exponent!)
    max_evals = _get_max_evals(zonotope.g)

    # Use the same tolerance on every attempt so that a retry is not silently
    # held to a different relative-error target than the first call.
    tol = 1e-4

    result = integrateOnSimplex(function, simplices, absError=abs_error, maxEvals=max_evals, tol=tol)

    # Just increase maxEvals!
    # NOTE(review): there is no upper bound on the number of doublings; if the integral
    # stays None for a reason unrelated to the budget, this loop does not terminate.
    while _integration_failed(result["integral"]):
        max_evals += max_evals
        result = integrateOnSimplex(function, simplices, absError=abs_error, maxEvals=max_evals, tol=tol)

    return result["integral"]


def geometric_integration_gaussian(
    zonotope: Zonotope, dist: scipy.stats.multivariate_normal, abs_error: float = 0.001
) -> float:
    """Integrate the density ``dist.pdf`` over ``zonotope`` and clip the result to ``[0.001, 1]``.

    Args:
        zonotope: Set to integrate over.
        dist: Distribution object whose ``pdf`` method is used as the integrand.
        abs_error: Requested absolute error of the cubature.

    Returns:
        The integral estimate, clipped with ``np.clip(..., 0.001, 1)``.
    """
    # Reuse the generic routine instead of duplicating its triangulation and
    # retry logic; only the integrand and the final clipping differ.
    integral = geometric_integration(zonotope, dist.pdf, abs_error=abs_error)

    # The raw estimate is clipped into [0.001, 1] before it is returned
    # (presumably because estimates can fall outside [0, 1] -- confirm the lower bound).
    return np.clip(integral, 0.001, 1)


def _get_max_evals(g: int) -> int:
    # A result of manual trial and error
    return 1.7**g * 1_000
