import math
import os
import time
from functools import partial

import gym.spaces

from rsm_utils import (
    pretty_time,
    pretty_number,
    get_pmass_grid,
    compute_expected_l,
    v_contains,
    v_intersect,
    jv_contains,
    jv_intersect,
)

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import jax
import tensorflow as tf
import jax.numpy as jnp

from tqdm import tqdm
import numpy as np

# Hide every GPU from TensorFlow so that it never allocates GPU memory. In this file
# TensorFlow is only used for the tf.data input pipeline (TrainBuffer.as_tfds);
# presumably the GPU is meant to be left to JAX (cf. XLA_PYTHON_CLIENT_PREALLOCATE above).
tf.config.experimental.set_visible_devices([], "GPU")


def get_n_for_bound_computation(obs_dim):
    """Return the per-dimension grid resolution used when bounding the RSM.

    Low-dimensional state spaces get a finer grid (200 cells per axis in 2-D,
    60 in 3-D); every other dimensionality falls back to 30 cells per axis.
    """
    for dim, resolution in ((2, 200), (3, 60)):
        if obs_dim == dim:
            return resolution
    return 30


class TrainBuffer:
    """Counterexample training buffer.

    States are stored as a list of chunks (one per ``append`` call). Each chunk is
    an array whose first axis indexes states, and whose last axis is the state
    dimension. The whole buffer can be exposed as a shuffled, batched
    ``tf.data.Dataset``.
    """

    def __init__(self, max_size=20_000_000):
        """
        :param max_size: Maximum number of states to keep. Once the buffer holds at
            least this many states, further appends are ignored. ``None`` disables
            the limit.
        """
        self.s = []
        self.max_size = max_size
        self._cached_ds = None
        # Batch size the cached dataset was built with; the cache is only reused
        # when the same batch size is requested again.
        self._cached_batch_size = None

    def append(self, s):
        """Add one chunk of states.

        The chunk is dropped when the buffer is already full. A chunk is added as
        a whole, so the buffer may end up somewhat larger than ``max_size``.
        """
        if self.max_size is not None and len(self) >= self.max_size:
            # The buffer is full -> don't add any more items
            return
        self.s.append(np.array(s))
        # Any previously built dataset no longer reflects the buffer contents
        self._cached_ds = None

    def extend(self, lst):
        """Append every chunk of ``lst`` (see ``append``)."""
        for s in lst:
            self.append(s)

    def __len__(self):
        """Total number of stored states over all chunks."""
        return sum(chunk.shape[0] for chunk in self.s)

    @property
    def in_dim(self):
        """Dimension of a single stored state (size of the last axis of a chunk)."""
        return self.s[0].shape[-1]

    def as_tfds(self, batch_size=32):
        """
        Returns the buffer contents as a shuffled, batched and prefetched tf.data.Dataset.

        The dataset is cached until the buffer changes or a different batch size is requested.
        :param batch_size: Batch size of the returned dataset
        :raises ValueError: if the buffer is empty
        """
        if self._cached_ds is not None and self._cached_batch_size == batch_size:
            return self._cached_ds
        if not self.s:
            raise ValueError("Cannot build a dataset from an empty TrainBuffer")
        train_s = np.concatenate(self.s, axis=0)
        # Full permutation first; the tf shuffle below only uses a finite window
        train_s = np.random.default_rng().permutation(train_s)
        train_ds = tf.data.Dataset.from_tensor_slices(train_s)
        train_ds = train_ds.shuffle(50000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
        self._cached_ds = train_ds
        self._cached_batch_size = batch_size
        return train_ds


class RSMVerifier:
    """
    Grid-based verifier for the decrease condition of a neural ranking supermartingale (RSM).

    The state space is discretized into a regular grid and the condition is checked at every
    cell center with a Lipschitz margin K = lipschitz_k * delta, where delta is the distance
    (in the chosen norm) from a cell center to its corners.
    """

    def __init__(
        self,
        rsm_learner,
        env,
        batch_size,
        reach_prob,
        fail_check_fast,
        target_grid_size,
        streaming_mode=False,
        dataset_type="all",
        lip_cheat=1.0,
        norm="l1",
    ):
        """

        :param rsm_learner: RSM Learner module that maintains the neural networks
        :param env: Environment
        :param batch_size: Batch size used by the verifier (not used for learning)
        :param reach_prob: Reach-avoid probability (0 < r <= 1)
        :param fail_check_fast: Flag  (True/False) indicating whether to abort the verifier on the
            first found counterexample or whether to compute all counterexamples
        :param target_grid_size: Desired total number of cells of the grid. The number of cells per
            dimension is the largest integer n with n**observation_dim <= target_grid_size
        :param streaming_mode: If True, the grid is generated block-wise on demand (slower, but with a
            much smaller memory footprint); otherwise the whole filtered grid is allocated up front
        :param dataset_type: Either "all" (equivalent to "soft"), or "hard". Whether to add
            all (hard and soft) or just hard counterexamples to the training buffer
        :param lip_cheat: Cheating factor to pretend we have a smaller Lipschitz bound than we actually have
            ( used for debugging)
        :param norm: "l1" or "linf"
        :raises ValueError: if norm is neither "l1" nor "linf"
        """
        if norm not in ("l1", "linf"):
            raise ValueError(f"Unknown norm {norm!r}, expected 'l1' or 'linf'")
        self.learner = rsm_learner
        self.env = env
        self.norm = norm
        self.reach_prob = jnp.float32(reach_prob)
        self.fail_check_fast = fail_check_fast

        self.streaming_mode = streaming_mode
        self.batch_size = batch_size
        self.block_size = 8 * batch_size
        self.refinement_enabled = True
        self.lip_cheat = lip_cheat

        self.grid_size = self._floor_root(target_grid_size, env.observation_dim)

        # Probability mass grid for the expectation computation is shared with the learner
        self._cached_pmass_grid = self.learner._cached_pmass_grid
        # (n, filtered_grid, steps) of the last grid built by get_filtered_grid
        self._cached_filtered_grid = None
        self._debug_violations = None
        self.dataset_type = dataset_type
        self.hard_constraint_violation_buffer = None
        self.train_buffer = TrainBuffer()
        self._perf_stats = {
            "apply": 0.0,
            "loop": 0.0,
        }
        self.v_get_grid_item = jax.vmap(
            self.get_grid_item, in_axes=(0, None), out_axes=0
        )
        self._grid_shuffle_rng = jax.random.PRNGKey(333)

    @staticmethod
    def _floor_root(value, dims):
        """
        Largest integer n with n**dims <= value.
        Unlike int(value ** (1 / dims)) this is robust to floating point error in the root
        (e.g. 1e6 ** (1/3) evaluates to 99.999... and would give 99 instead of 100).
        """
        n = int(round(math.pow(value, 1 / dims)))
        while n > 0 and n**dims > value:
            n -= 1
        while (n + 1) ** dims <= value:
            n += 1
        return n

    def _cell_radius(self, steps):
        """
        Distance (in self.norm) from the center of a grid cell with side lengths `steps` to its corners
        """
        steps = np.asarray(steps)
        if self.norm == "l1":
            # l1-norm of the half the grid cell (=l1 distance from center to corner)
            return 0.5 * np.sum(steps)
        if self.norm == "linf":
            return 0.5 * np.max(steps)
        raise ValueError("Should not happen")

    def _lipschitz_margin(self, lipschitz_k, steps):
        """
        Margin K = lipschitz_k * cell radius used on the coarse grid, shrunk by lip_cheat if it is < 1
        (debugging only; this makes the check unsound)
        """
        K = lipschitz_k * self._cell_radius(steps)
        if self.lip_cheat < 1:
            print(f"Using Lipschitz cheat value {self.lip_cheat}")
            K = K * self.lip_cheat
        return K

    def prefill_train_buffer(self):
        """
        Fills the train buffer with the centers of a coarse grid (at most ~500k cells)
        :return: Side length of a grid cell along the first dimension
        """
        buffer_size = 500_000  # 500k items
        n = self._floor_root(buffer_size, self.env.observation_dim)
        state_grid, _, _ = self.get_unfiltered_grid(n=n)
        self.train_buffer.append(np.array(state_grid))
        return (
            self.env.observation_space.high[0] - self.env.observation_space.low[0]
        ) / n

    @partial(jax.jit, static_argnums=(0, 2))
    def get_grid_item(self, idx, n):
        """
        Maps an integer cell index and grid size to the center of the corresponding grid cell
        :param idx: Integer between 0 and n**obs_dim
        :param n: Grid size (number of cells per dimension)
        :return: jnp.ndarray corresponding to the center of the idx cell
        """
        dims = self.env.observation_dim
        target_points = [
            jnp.linspace(
                self.env.observation_space.low[i],
                self.env.observation_space.high[i],
                n,
                retstep=True,
                endpoint=False,
            )
            for i in range(dims)
        ]
        target_points, retsteps = zip(*target_points)
        target_points = list(target_points)
        for i in range(dims):
            # shift the lower cell corners to the cell centers
            target_points[i] = target_points[i] + 0.5 * retsteps[i]
        # Decompose idx into one index per dimension (dimension 0 varies fastest)
        inds = []
        for i in range(dims):
            inds.append(idx % n)
            idx = idx // n
        return jnp.array([target_points[i][inds[i]] for i in range(dims)])

    def get_refined_grid_template(self, steps, n):
        """
        Splits a grid cell with side lengths `steps` into n**obs_dim smaller cells.
        The returned template can be added to cell centers to create the smaller grid.
        :param steps: Side lengths of the cell that is refined (one per dimension)
        :param n: Number of sub-cells per dimension
        :return: (template, new_steps); template has shape (n**obs_dim, obs_dim) and holds the
            sub-cell centers relative to the center of the refined cell, new_steps are the side
            lengths of the sub-cells
        """
        dims = self.env.observation_dim
        axes, new_steps = [], []
        for i in range(dims):
            samples, new_step = jnp.linspace(
                -0.5 * steps[i],
                +0.5 * steps[i],
                n,
                endpoint=False,
                retstep=True,
            )
            axes.append(samples + new_step * 0.5)
            new_steps.append(new_step)
        mesh = jnp.meshgrid(*axes)
        # Flatten every coordinate array before stacking so that each row is one point.
        # Stacking the n x ... x n arrays directly would interleave coordinates of different
        # points once the result is reshaped to (-1, dims).
        template = jnp.stack([coord.flatten() for coord in mesh], axis=1)
        return template, np.array(new_steps)

    def get_unfiltered_grid(self, n=100):
        """
        Returns the centers, lower corners and upper corners of all cells of an n x ... x n grid
        over the observation space, each as an array of shape (n**obs_dim, obs_dim)
        """
        dims = self.env.observation_dim
        grid, steps = [], []
        for i in range(dims):
            samples, step = np.linspace(
                self.env.observation_space.low[i],
                self.env.observation_space.high[i],
                n,
                endpoint=False,
                retstep=True,
            )
            grid.append(samples)
            steps.append(step)
        grid = np.meshgrid(*grid)
        grid_lb = [x.flatten() for x in grid]
        grid_ub = [grid_lb[i] + steps[i] for i in range(dims)]
        grid_centers = [grid_lb[i] + steps[i] / 2 for i in range(dims)]

        grid_lb = np.stack(grid_lb, axis=1)
        grid_ub = np.stack(grid_ub, axis=1)
        grid_centers = np.stack(grid_centers, axis=1)
        return grid_centers, grid_lb, grid_ub

    def get_filtered_grid(self, n=100):
        """
        Returns the centers of an n x ... x n grid over the observation space, excluding centers that
        lie inside any target space, together with the cell side lengths.
        The result is cached, so repeated calls with the same n do not rebuild the grid.
        :param n: Number of cells per dimension
        :return: (filtered_grid, steps)
        """
        if self._cached_filtered_grid is not None:
            cached_n, cached_grid, cached_steps = self._cached_filtered_grid
            if cached_n == n:
                print(f"Using cached grid of n={n} ", end="", flush=True)
                return cached_grid, cached_steps
            # Release the stale grid before allocating a new one to keep peak memory low
            self._cached_filtered_grid = None
        import gc

        gc.collect()
        size_t = 4 * (n**self.env.observation_dim)
        print(
            f"Allocating grid of n={n} ({pretty_number(size_t)} bytes)",
            end="",
            flush=True,
        )
        dims = self.env.observation_space.shape[0]
        grid, steps = [], []
        for i in range(dims):
            samples, step = np.linspace(
                self.env.observation_space.low[i],
                self.env.observation_space.high[i],
                n,
                endpoint=False,
                retstep=True,
            )
            grid.append(samples)
            steps.append(step)
        print(f" meshgrid with steps={steps} ", end="", flush=True)
        grid = np.meshgrid(*grid)
        grid = [grid[i].flatten() + steps[i] / 2 for i in range(dims)]
        grid = np.stack(grid, axis=1)

        mask = np.zeros(grid.shape[0], dtype=np.bool_)
        for target_space in self.env.target_spaces:
            contains = v_contains(target_space, grid)
            mask = np.logical_or(
                mask,
                contains,
            )
        filtered_grid = grid[np.logical_not(mask)]
        steps = np.array(steps)
        self._cached_filtered_grid = (n, filtered_grid, steps)
        return filtered_grid, steps

    def compute_bound_init(self, n):
        """
        Computes the lower and upper bound of the RSM on the initial state set
        :param n: Discretization (a too high value will cause a long runtime or out-of-memory errors)
        """
        _, grid_lb, grid_ub = self.get_unfiltered_grid(n)

        mask = np.zeros(grid_lb.shape[0], dtype=np.bool_)
        # Include if the grid cell intersects with any of the init spaces
        for init_space in self.env.init_spaces:
            intersect = v_intersect(init_space, grid_lb, grid_ub)
            mask = np.logical_or(
                mask,
                intersect,
            )
        # Exclude if both lb AND ub are in the target set
        for target_space in self.env.target_spaces:
            contains_lb = v_contains(target_space, grid_lb)
            contains_ub = v_contains(target_space, grid_ub)
            mask = np.logical_and(
                mask, np.logical_not(np.logical_and(contains_lb, contains_ub))
            )

        grid_lb = grid_lb[mask]
        grid_ub = grid_ub[mask]
        assert grid_ub.shape[0] > 0

        return self.compute_bounds_on_set(grid_lb, grid_ub)

    def compute_bound_unsafe(self, n):
        """
        Computes the lower and upper bound of the RSM on the unsafe state set
        :param n: Discretization (a too high value will cause a long runtime or out-of-memory errors)
        """
        _, grid_lb, grid_ub = self.get_unfiltered_grid(n)

        # Include only if the grid cell intersects one of the unsafe sets
        mask = np.zeros(grid_lb.shape[0], dtype=np.bool_)
        for unsafe_spaces in self.env.unsafe_spaces:
            intersect = v_intersect(unsafe_spaces, grid_lb, grid_ub)
            mask = np.logical_or(
                mask,
                intersect,
            )
        grid_lb = grid_lb[mask]
        grid_ub = grid_ub[mask]
        assert grid_ub.shape[0] > 0
        return self.compute_bounds_on_set(grid_lb, grid_ub)

    def compute_bound_domain(self, n):
        """
        Computes the lower and upper bound of the RSM on the entire state space except the target space
        :param n: Discretization (a too high value will cause a long runtime or out-of-memory errors)
        """
        _, grid_lb, grid_ub = self.get_unfiltered_grid(n)

        # Exclude if both lb AND ub are in the target set
        mask = np.zeros(grid_lb.shape[0], dtype=np.bool_)
        for target_space in self.env.target_spaces:
            contains_lb = v_contains(target_space, grid_lb)
            contains_ub = v_contains(target_space, grid_ub)
            mask = np.logical_or(
                mask,
                np.logical_and(contains_lb, contains_ub),
            )
        # mask marks cells with both lb and ub in a target -> invert for filtering
        mask = np.logical_not(mask)
        grid_lb = grid_lb[mask]
        grid_ub = grid_ub[mask]
        assert grid_ub.shape[0] > 0
        return self.compute_bounds_on_set(grid_lb, grid_ub)

    def compute_bounds_on_set(self, grid_lb, grid_ub):
        """
        Computes the lower and upper bound of the RSM over the given cells (via interval bound
        propagation, processed in batches of self.batch_size)
        :return: (global_min, global_max) as Python floats
        """
        global_min = jnp.inf
        global_max = -jnp.inf
        num_cells = grid_ub.shape[0]
        for start in range(0, num_cells, self.batch_size):
            end = min(start + self.batch_size, num_cells)
            batch_lb = jnp.array(grid_lb[start:end])
            batch_ub = jnp.array(grid_ub[start:end])
            lb, ub = self.learner.v_ibp.apply(
                self.learner.v_state.params, [batch_lb, batch_ub]
            )
            global_min = jnp.minimum(global_min, jnp.min(lb))
            global_max = jnp.maximum(global_max, jnp.max(ub))
        return float(global_min), float(global_max)

    @partial(jax.jit, static_argnums=(0,))
    def _check_dec_batch(self, l_params, p_params, f_batch, l_batch, K):
        """
        Compute kernel (jit compiled) that checks if a batch of grid cells violate the decrease conditions.
        A cell is a (soft) violation if E[L(s')] + K - L(s) >= 0 and a hard violation if
        E[L(s')] - L(s) >= 0.
        :return: (num violations, violation mask, num hard violations, hard violation mask,
            max of E[L(s')] + K - L(s), max of (E[L(s')] + K) / L(s))
        """
        a_batch = self.learner.p_state.apply_fn(p_params, f_batch)
        pmass, batched_grid_lb, batched_grid_ub = self._cached_pmass_grid
        e = compute_expected_l(
            self.env,
            self.learner.v_ibp.apply,
            l_params,
            f_batch,
            a_batch,
            pmass,
            batched_grid_lb,
            batched_grid_ub,
        )
        e = e.flatten()
        l_batch = l_batch.flatten()

        decrease = e + K - l_batch
        violating_indices = decrease >= 0
        v = violating_indices.astype(jnp.int32).sum()
        hard_violating_indices = e - l_batch >= 0
        hard_v = hard_violating_indices.astype(jnp.int32).sum()
        decay = (e + K) / l_batch
        return (
            v,
            violating_indices,
            hard_v,
            hard_violating_indices,
            jnp.max(decrease),
            jnp.max(decay),
        )

    @partial(jax.jit, static_argnums=(0,))
    def normalize_rsm(self, l, ub_init, domain_min):
        """
        By normalizing the RSM using the global infimum of L and the supremum of L within the init set, we
        improve the Reach-avoid bounds of L slightly.
        """
        l = l - domain_min
        ub_init = ub_init - domain_min
        # now min = 0
        l = l / jnp.maximum(ub_init, 1e-6)
        # now init max = 1
        return l

    def _compute_normalization_bounds(self):
        """Returns (ub_init, domain_min), the values used by normalize_rsm"""
        n = get_n_for_bound_computation(self.env.observation_dim)
        _, ub_init = self.compute_bound_init(n)
        domain_min, _ = self.compute_bound_domain(n)
        return ub_init, domain_min

    def _check_batch(self, x_batch, ub_init, domain_min, K):
        """
        Checks the decrease condition on one batch of cell centers.
        Cells whose normalized RSM value minus K is not below 1/(1-reach_prob) are skipped first.
        :return: None if every cell was skipped, otherwise (x_batch, v, violating_indices, hard_v,
            hard_violating_indices, max_decrease, max_decay) where x_batch is the (filtered) batch
            that the index masks refer to
        """
        l_batch = self.learner.v_state.apply_fn(
            self.learner.v_state.params, x_batch
        ).flatten()
        if self.reach_prob < 1.0:
            # normalize the RSM to obtain slightly better values
            normalized_l_batch = self.normalize_rsm(l_batch, ub_init, domain_min)
            less_than_p = normalized_l_batch - K < 1 / (1 - self.reach_prob)
            if not jnp.any(less_than_p):
                # All cells are filtered -> can skip the expectation computation
                return None
            x_batch = x_batch[less_than_p]
            l_batch = l_batch[less_than_p]
        results = self._check_dec_batch(
            self.learner.v_state.params,
            self.learner.p_state.params,
            x_batch,
            l_batch,
            K,
        )
        return (x_batch,) + tuple(results)

    def _store_counterexamples(self, violation_buffer, hard_violation_buffer):
        """
        Stores the hard violations in self.hard_constraint_violation_buffer and adds the
        counterexamples selected by self.dataset_type to the training buffer
        """
        self.hard_constraint_violation_buffer = (
            None
            if len(hard_violation_buffer) == 0
            else np.concatenate([np.array(g) for g in hard_violation_buffer])
        )
        if self.dataset_type in ["all", "soft"]:
            self.train_buffer.extend(violation_buffer)
        elif self.dataset_type == "hard":
            self.train_buffer.extend(hard_violation_buffer)
        else:
            raise ValueError(f"Unknown dataset type {self.dataset_type}")

    def _refine_soft_violations(
        self,
        violations,
        hard_violations,
        failed_fast,
        violation_buffer,
        lipschitz_k,
        steps,
        ub_init,
        domain_min,
    ):
        """
        If the grid check found only soft violations (and ran to completion), re-checks the violating
        cells on a finer grid.
        :return: (max_decrease, max_decay) of the refined check if it succeeded, otherwise None
        """
        if not (
            self.refinement_enabled
            and hard_violations == 0
            and not failed_fast
            and violations > 0
            and len(violation_buffer) > 0
        ):
            return None
        print(
            f"Zero hard violations -> refinement of {pretty_number(violations)} soft violations"
        )
        refine_start = time.time()
        refinement_buffer = np.concatenate([np.array(g) for g in violation_buffer])
        success, max_decrease, max_decay = self.refine_grid(
            refinement_buffer, lipschitz_k, steps, ub_init, domain_min
        )
        if success:
            print(
                f"Refinement successful! (took {pretty_time(time.time()-refine_start)})"
            )
            return max_decrease, max_decay
        print(
            f"Refinement unsuccessful! (took {pretty_time(time.time()-refine_start)})"
        )
        return None

    def check_dec_cond(self, lipschitz_k):
        """
        This method checks if the decrease condition is fulfilled.
        How the grid is processed (block-wise or allocating the entire grid) is decided by the streaming_mode flag
        :param lipschitz_k: Lipschitz constant of the entire system (environment, policy, RSM)
        :return: (number of violating grid cells, number of hard violating grid cells, max decrease, max decay)
        """
        if self.streaming_mode:
            return self.check_dec_cond_with_stream(lipschitz_k)
        else:
            return self.check_dec_cond_full(lipschitz_k)

    def check_dec_cond_full(self, lipschitz_k):
        """
        This method checks if the decrease condition is fulfilled by allocating the entire grid in memory first.
        This is fast but may require a lot of memory and cause out-of-memory errors.
        If such errors occur, consider streaming mode, which creates sub-blocks of the grid on-demand
        :param lipschitz_k: Lipschitz constant of the entire system (environment, policy, RSM)
        :return: (number of violating grid cells, number of hard violating grid cells, max decrease, max decay)
        """
        verify_start_time = time.time()
        ub_init, domain_min = self._compute_normalization_bounds()
        print(f"computed bounds done: {pretty_time(time.time()-verify_start_time)}")

        grid, steps = self.get_filtered_grid(self.grid_size)
        print(f"allocated grid done: {pretty_time(time.time()-verify_start_time)}")
        K = self._lipschitz_margin(lipschitz_k, steps)
        if self.fail_check_fast:
            # Shuffling the (possibly huge) grid is slow. It is only done when aborting on the first
            # violation, presumably so that the counterexamples found before the abort are spread over
            # the state space instead of all coming from its first rows.
            perm = np.random.default_rng().permutation(grid.shape[0])
            grid = grid[perm]
        number_of_cells = self.grid_size**self.env.observation_dim
        print(
            f"Checking GRID with {pretty_number(number_of_cells)} cells and K={K:0.3g}"
        )
        K = jnp.float32(K)

        violations = 0
        hard_violations = 0
        total_cells_processed = 0
        failed_fast = False
        max_decrease = -jnp.inf
        max_decay = -jnp.inf
        violation_buffer = []
        hard_violation_buffer = []

        print(f"loop start: {pretty_time(time.time()-verify_start_time)}")
        kernel_start_time = time.perf_counter()
        pbar = tqdm(total=grid.shape[0], unit="cells")
        for start in range(0, grid.shape[0], self.batch_size):
            end = min(start + self.batch_size, grid.shape[0])
            result = self._check_batch(
                jnp.array(grid[start:end]), ub_init, domain_min, K
            )
            if result is None:
                pbar.update(end - start)
                continue
            (
                x_batch,
                v,
                violating_indices,
                hard_v,
                hard_violating_indices,
                decrease,
                decay,
            ) = result
            max_decrease = jnp.maximum(max_decrease, decrease)
            max_decay = jnp.maximum(max_decay, decay)
            # Count the number of violations and hard violations
            hard_violations += hard_v
            violations += v
            if v > 0:
                violation_buffer.append(np.array(x_batch[violating_indices]))
            if hard_v > 0:
                hard_violation_buffer.append(np.array(x_batch[hard_violating_indices]))
            total_kernel_time = time.perf_counter() - kernel_start_time
            total_cells_processed = end
            kcells_per_sec = total_cells_processed / total_kernel_time / 1000
            pbar.update(end - start)
            pbar.set_description(
                f"{pretty_number(violations)}/{pretty_number(total_cells_processed)} cell violating @ {kcells_per_sec:0.1f} Kcells/s"
            )
            if self.fail_check_fast and violations > 0:
                failed_fast = True
                break
        pbar.close()
        print(f"loop ends: {pretty_time(time.time()-verify_start_time)}")
        if failed_fast:
            print(
                f"Failed fast after {pretty_number(total_cells_processed)}/{pretty_number(number_of_cells)} cells checked"
            )
        if len(violation_buffer) == 1:
            print("violation_buffer[0][0]:", violation_buffer[0][0])
        self._store_counterexamples(violation_buffer, hard_violation_buffer)
        print(
            f"Verified {pretty_number(total_cells_processed)} cells ({pretty_number(violations)} violations, {pretty_number(hard_violations)} hard) in {pretty_time(time.time()-verify_start_time)}"
        )

        refined = self._refine_soft_violations(
            violations,
            hard_violations,
            failed_fast,
            violation_buffer,
            lipschitz_k,
            steps,
            ub_init,
            domain_min,
        )
        if refined is not None:
            return 0, 0, float(refined[0]), float(refined[1])
        return violations, hard_violations, float(max_decrease), float(max_decay)

    def check_dec_cond_with_stream(self, lipschitz_k):
        """
        This method checks if the decrease condition is fulfilled by creating sub-grids block-wise and checking them
        This is slower but has a much smaller memory footprint than the full allocation method
        :param lipschitz_k: Lipschitz constant of the entire system (environment, policy, RSM)
        :return: (number of violating grid cells, number of hard violating grid cells, max decrease, max decay)
        """
        dims = self.env.observation_dim
        grid_total_size = self.grid_size**dims

        verify_start_time = time.time()
        ub_init, domain_min = self._compute_normalization_bounds()

        # Cell side lengths; they match the cell centers produced by get_grid_item
        # (linspace with endpoint=False -> grid_size cells per dimension)
        steps = (
            self.env.observation_space.high - self.env.observation_space.low
        ) / self.grid_size
        K = self._lipschitz_margin(lipschitz_k, steps)
        max_decrease = -jnp.inf
        max_decay = -jnp.inf
        print(
            f"Checking GRID with {pretty_number(grid_total_size)} cells and K={K:0.3g}"
        )
        K = jnp.float32(K)

        violations = 0
        hard_violations = 0
        violation_buffer = []
        hard_violation_buffer = []

        # block_size size should not be too large
        block_size = min(grid_total_size, self.block_size)
        num_blocks = -(-grid_total_size // block_size)  # ceil division
        kernel_start_time = time.perf_counter()
        total_cells = 0
        pbar = tqdm(total=num_blocks, unit="blocks")
        failed_fast = False
        for block_id in range(0, grid_total_size, block_size):
            # Create array of indices of the grid cell in the current block
            idx = jnp.arange(block_id, min(grid_total_size, block_id + block_size))
            # Map indices -> centers of the cells
            sub_grid = self.v_get_grid_item(idx, self.grid_size)
            # Filter out grid cells that are inside at least one target set
            in_target = jnp.zeros(sub_grid.shape[0], dtype=jnp.bool_)
            for target_space in self.env.target_spaces:
                in_target = jnp.logical_or(
                    in_target, jv_contains(target_space, sub_grid)
                )
            sub_grid = sub_grid[jnp.logical_not(in_target)]

            # We process the block in batches
            for start in range(0, sub_grid.shape[0], self.batch_size):
                end = min(start + self.batch_size, sub_grid.shape[0])
                result = self._check_batch(
                    jnp.array(sub_grid[start:end]), ub_init, domain_min, K
                )
                if result is None:
                    continue
                (
                    x_batch,
                    v,
                    violating_indices,
                    hard_v,
                    hard_violating_indices,
                    decrease,
                    decay,
                ) = result
                max_decrease = jnp.maximum(max_decrease, decrease)
                max_decay = jnp.maximum(max_decay, decay)
                # Count the number of violations and hard violations
                hard_violations += hard_v
                violations += v
                if v > 0:
                    violation_buffer.append(x_batch[violating_indices])
                if hard_v > 0:
                    hard_violation_buffer.append(x_batch[hard_violating_indices])
            pbar.update(1)
            total_kernel_time = time.perf_counter() - kernel_start_time
            total_cells += sub_grid.shape[0]
            kcells_per_sec = total_cells / total_kernel_time / 1000
            pbar.set_description(
                f"{pretty_number(violations)}/{pretty_number(total_cells)} cell violating @ {kcells_per_sec:0.1f} Kcells/s"
            )
            if self.fail_check_fast and violations > 0:
                failed_fast = True
                break
        pbar.close()

        self._store_counterexamples(violation_buffer, hard_violation_buffer)
        print(
            f"Verified {pretty_number(total_cells)} cells ({pretty_number(violations)} violations, {pretty_number(hard_violations)} hard) in {pretty_time(time.time()-verify_start_time)}"
        )

        refined = self._refine_soft_violations(
            violations,
            hard_violations,
            failed_fast,
            violation_buffer,
            lipschitz_k,
            steps,
            ub_init,
            domain_min,
        )
        if refined is not None:
            return 0, 0, float(refined[0]), float(refined[1])
        return violations, hard_violations, float(max_decrease), float(max_decay)

    def refine_grid(self, refinement_buffer, lipschitz_k, steps, ub_init, domain_min):
        """
        Re-checks the decrease condition on a finer grid around the given cell centers.
        Each cell (side lengths `steps`) is split into n**obs_dim sub-cells with a correspondingly
        smaller Lipschitz margin.
        :return: (success, max_decrease, max_decay); success is False as soon as one refined cell
            violates the decrease condition
        """
        n_dims = self.env.observation_dim

        n = 6 if n_dims > 2 else 10
        template_batch, new_steps = self.get_refined_grid_template(steps, n)
        K = jnp.float32(lipschitz_k * self._cell_radius(new_steps))

        # Every buffered cell expands into template_batch.shape[0] sub-cells
        batch_size = self.batch_size // template_batch.shape[0]
        batch_size = max(1, batch_size // 32)  # avoid oom

        template_batch = template_batch.reshape((1, -1, n_dims))
        max_decrease = -jnp.inf
        max_decay = -jnp.inf
        num_batches = int(np.ceil(refinement_buffer.shape[0] / batch_size))
        for i in tqdm(range(num_batches)):
            start = i * batch_size
            end = min(start + batch_size, refinement_buffer.shape[0])
            s_batch = jnp.array(refinement_buffer[start:end]).reshape((-1, 1, n_dims))
            # Broadcast: every cell center + every template offset, then flatten
            r_batch = (s_batch + template_batch).reshape((-1, n_dims))

            result = self._check_batch(r_batch, ub_init, domain_min, K)
            if result is None:
                continue
            _, v, _, _, _, decrease, decay = result
            max_decrease = jnp.maximum(max_decrease, decrease)
            max_decay = jnp.maximum(max_decay, decay)
            if v > 0:
                return False, max_decrease, max_decay
        return True, max_decrease, max_decay