from functools import partial

from heapq import heapify, heappop, heappush

from jax import Array
from jax import jit
import jax.numpy as jnp
import jax.random as random
from jax.lax import cond
from jax.scipy.stats import norm as normal_dist
from jax.experimental.ode import odeint
from .gaussian_density_ratios import IsotropicGaussianDensityRatio
from .gaussian_density_ratios import OneDimensionalTruncatableGaussianDensityRatio
from greedy_rejection_process.util import trunc_gumbel


# =============================================================================
# =============================================================================
#
# Slow GPRS
#
# =============================================================================
# =============================================================================

@partial(jit, static_argnames=['gauss_dr'])
def slow_gauss_loop_body_with_inv_stretch(k, key, log_time, gauss_dr):
  """Run one jitted iteration of the slow sampler.

  Args:
    k: iteration counter; folded into `key` so each iteration uses its own
      randomness.
    key: base PRNG key shared by all iterations.
    log_time: log-time produced by the previous iteration.
    gauss_dr: density-ratio object (static jit argument) providing `sample_p`
      and `inv_stretch_ode`.

  Returns:
    `(x, h, log_time)`: the draw from `gauss_dr.sample_p`, the final state of
    integrating `gauss_dr.inv_stretch_ode` from 0 to `exp(log_time)`, and the
    new log-time.
  """
  step_key = random.fold_in(key, k)
  time_key, sample_key = random.split(step_key, num=2)

  # NOTE(review): trunc_gumbel is presumably a Gumbel draw truncated at
  # `bound`; its definition lives in greedy_rejection_process.util.
  next_log_time = -trunc_gumbel(time_key, shape=(), loc=0., bound=-log_time)
  proposal = gauss_dr.sample_p(sample_key, shape=())

  # Integrate from time 0 (state 0) up to exp(next_log_time); only the state
  # at the end of the grid is needed.
  time_grid = jnp.array([0., jnp.exp(next_log_time)])
  trajectory = odeint(gauss_dr.inv_stretch_ode, jnp.array(0.), time_grid)

  return proposal, trajectory[-1], next_log_time


def slow_gauss_encoder_v2(seed, gauss_dr, max_iter=1_000):
  """Iterate the slow loop body until a proposal is accepted.

  Args:
    seed: integer seed for the base PRNG key.
    gauss_dr: density-ratio object; must provide `ratio` in addition to what
      `slow_gauss_loop_body_with_inv_stretch` needs.
    max_iter: maximum number of iterations to attempt.

  Returns:
    `(x, k)`: the accepted sample and the zero-based iteration at which it
    was accepted.

  Raises:
    ValueError: if no proposal is accepted within `max_iter` iterations.
  """
  base_key = random.PRNGKey(seed)
  current_log_time = -jnp.inf

  for k in range(max_iter):
    x, h, current_log_time = slow_gauss_loop_body_with_inv_stretch(
      k, base_key, current_log_time, gauss_dr)

    # Accept as soon as the integrated value drops below the density ratio.
    accepted = h < gauss_dr.ratio(x)
    if accepted:
      return x, k

  # The loop can only finish by exhausting max_iter (acceptance returns).
  raise ValueError("did not terminate!")


# =============================================================================
# =============================================================================
#
# Slow Parallel GPRS
#
# =============================================================================
# =============================================================================


@partial(jit, static_argnames=['gauss_dr', 'n_branches'])
def parallel_slow_gauss_loop_body(k, key, log_times, gauss_dr, n_branches, gumbel_loc):
  """Advance all parallel branches by one step and pick the best accepted one.

  Args:
    k: iteration counter, folded into `key`.
    key: base PRNG key shared by all iterations.
    log_times: per-branch log-times from the previous iteration.
    gauss_dr: density-ratio object (static jit argument) providing
      `sample_p`, `inv_stretch_ode` and `ratio`.
    n_branches: number of branches (static jit argument).
    gumbel_loc: location passed to `trunc_gumbel` for every branch.

  Returns:
    `(x_candidate, branch_index, earliest_log_arrival, log_times)`: the sample
    and index of the branch with the smallest accepted log-time in this step,
    that log-time (`inf` when no branch accepted), and the updated per-branch
    log-times.
  """
  step_key = random.fold_in(key, k)
  time_key, sample_key = random.split(step_key, num=2)

  log_times = -trunc_gumbel(time_key, shape=(n_branches,), loc=gumbel_loc, bound=-log_times)
  xs = gauss_dr.sample_p(sample_key, shape=(n_branches,))

  # A single odeint call handles every branch: integrate along the sorted
  # times (odeint expects an increasing grid), then undo the permutation.
  order = jnp.argsort(log_times)
  undo_order = jnp.argsort(order)

  time_grid = jnp.concatenate([jnp.zeros(1), jnp.exp(log_times[order])])
  # Drop the first entry: it is the initial state at time 0, not a branch.
  hs_in_time_order = odeint(gauss_dr.inv_stretch_ode, jnp.array(0.), time_grid)[1:]
  hs = hs_in_time_order[undo_order]

  # Branches whose point falls below the graph keep their arrival time so it
  # can tighten the upper bound; all others are pushed to infinity so they
  # can never be selected ahead of an accepted branch.
  is_accepted = hs < gauss_dr.ratio(xs)
  candidate_log_times = jnp.where(is_accepted, log_times, jnp.inf)

  branch_index = jnp.argsort(candidate_log_times)[0]

  return xs[branch_index], branch_index, candidate_log_times[branch_index], log_times


def parallel_slow_gauss_encoder(seed, gauss_dr, log2_num_parallel, max_iter=1_000):
  """Run `2 ** log2_num_parallel` searches in lockstep and return the selected sample.

  Args:
    seed: integer seed for the base PRNG key.
    gauss_dr: density-ratio object forwarded to `parallel_slow_gauss_loop_body`.
    log2_num_parallel: int K; 2^K parallel searches are instantiated, so that
      the coding cost for the choice of the branch is K bits.
    max_iter: maximum number of lockstep iterations.

  Returns:
    `(x_candidate, k - 1, branch_index)`: the candidate kept from the previous
    iteration, that iteration's index, and the branch it came from. The
    function returns at the first iteration whose earliest accepted log-time
    is larger than the bound recorded so far.

  Raises:
    ValueError: if `max_iter` iterations pass without returning.
  """
  branch_count = 2**log2_num_parallel
  gumbel_loc = -log2_num_parallel * jnp.log(2.)

  base_key = random.PRNGKey(seed)
  log_times = jnp.ones(branch_count) * -jnp.inf

  # Best (smallest) accepted log arrival time seen so far, with its sample
  # and branch. Infinity means "nothing recorded yet".
  best_log_arrival = jnp.inf
  best_x = None
  best_branch = None

  for k in range(max_iter):
    step_x, step_branch, step_log_arrival, log_times = parallel_slow_gauss_loop_body(
      k, base_key, log_times, gauss_dr, branch_count, gumbel_loc)

    # This step did not improve on the recorded bound: report the stored
    # candidate, which was recorded during iteration k - 1.
    no_improvement = step_log_arrival > best_log_arrival
    if no_improvement:
      return best_x, k - 1, best_branch

    best_log_arrival, best_x, best_branch = step_log_arrival, step_x, step_branch

  # Only reachable when the loop ran out of iterations.
  raise ValueError("did not terminate!")


# =============================================================================
# =============================================================================
#
# Split-on-sample GPRS without efficient heap index
#
# =============================================================================
# =============================================================================


@partial(jit, static_argnames=['gauss_dr'])
def sac_gauss_encoder_loop_body_with_inv_stretch(base_key: random.PRNGKey,
                               k: int,
                               bounds: Array,
                               log_time0: Array,
                               log_h0: Array,
                               gauss_dr: IsotropicGaussianDensityRatio):
  """Run one jitted step of the split-on-sample encoder.

  Args:
    base_key: base PRNG key; `k` is folded in to get this step's key.
    k: step counter.
    bounds: two-element array `[lower, upper]` restricting the uniform draw.
    log_time0: log-time of the previous step (`-inf` on the first step).
    log_h0: log of the integrated value at `log_time0`.
    gauss_dr: density-ratio object (static jit argument) providing
      `log_ratio`, `inv_stretch_ode` and `log_inv_stretch_ode_log_time`.

  Returns:
    `(log_time, x, u, log_r, log_h, bound_size)`: the new log-time, the
    sample and the uniform value it was mapped from, the log density ratio at
    `x`, the log integrated value at `log_time`, and the width of `bounds`.
  """
  # Executes only while jit traces the function, so it reveals retracing.
  print(f"tracing")

  step_key = random.fold_in(base_key, k)
  time_key, uniform_key = random.split(step_key, num=2)

  lower, upper = bounds[0], bounds[1]
  bound_size = upper - lower

  log_time = -trunc_gumbel(time_key, shape=(), loc=jnp.log(bound_size), bound=-log_time0)

  # Uniform draw on [lower, upper), mapped through the standard normal
  # quantile function.
  u = lower + bound_size * random.uniform(uniform_key, shape=())
  x = normal_dist.ppf(u, 0., 1.)

  def integrate_from_origin():
    # First step: nothing to continue from, so integrate in linear time
    # starting at 0 and take the log afterwards.
    h_at_time = odeint(gauss_dr.inv_stretch_ode, 0., jnp.array([0., jnp.exp(log_time)]))[-1]
    return jnp.log(h_at_time)

  def integrate_from_previous():
    # Later steps: continue the log-space ODE from (log_time0, log_h0).
    return odeint(gauss_dr.log_inv_stretch_ode_log_time, log_h0, jnp.array([log_time0, log_time]))[-1]

  is_first_step = jnp.all(log_time0 == -jnp.inf)
  log_h = cond(is_first_step, integrate_from_origin, integrate_from_previous)

  return log_time, x, u, gauss_dr.log_ratio(x), log_h, bound_size


def sac_gauss_encoder_v2(seed: int,
                      gauss_dr: IsotropicGaussianDensityRatio,
                      max_iter: int = 100) -> tuple[Array, int]:
  """Split-on-sample encoder: shrink the search interval at every rejection.

  Args:
    seed: integer seed for the base PRNG key.
    gauss_dr: density-ratio object; `r_loc` decides which side of a rejected
      sample is kept.
    max_iter: maximum number of steps.

  Returns:
    Three values (one more than the annotation states): the accepted sample,
    the one-based step `k` at which it was accepted, and
    `-log2(bound_size)` for the interval that step sampled from.

  Raises:
    ValueError: if no sample is accepted within `max_iter` steps.
  """
  base_key = random.PRNGKey(seed)
  bounds = jnp.array([0., 1.])

  log_time = -jnp.inf
  log_h = -jnp.inf

  for k in range(1, max_iter + 1):
    log_time, x, u, log_r, log_h, bound_size = sac_gauss_encoder_loop_body_with_inv_stretch(
      base_key, k, bounds, log_time, log_h, gauss_dr)

    if log_h < log_r:
      return x, k, -jnp.log2(bound_size)

    # Rejected: cut the interval at u, keeping the side that lies towards
    # gauss_dr.r_loc.
    lower, upper = (u, bounds[1]) if x < gauss_dr.r_loc else (bounds[0], u)
    bounds = jnp.array([lower, upper])

  # Only reachable when every step was a rejection.
  raise ValueError('did not terminate!')


# =============================================================================
# =============================================================================
#
# Split-on-sample GPRS *with* efficient heap index
#
# =============================================================================
# =============================================================================

@jit
def is_interval_intersection(a1, b1, a2, b2):
  """Tell whether the intervals (a1, b1) and (a2, b2) intersect.

  Intervals that merely touch at an endpoint count as intersecting, because
  the comparison is non-strict.
  """
  largest_start = jnp.maximum(a1, a2)
  smallest_end = jnp.minimum(b1, b2)
  # The overlap is non-empty exactly when the later start does not lie past
  # the earlier end.
  gap = largest_start - smallest_end
  return gap <= 0.


@jit
def sac_gauss_encoder_efficient_sampler_loop_body(base_key: random.PRNGKey,
                               heap_index: int,
                               bounds: Array,
                               log_time0: Array):
  """Sample the point attached to one heap node.

  Args:
    base_key: base PRNG key; `heap_index` is folded in, so the draw depends
      on the node rather than on the order in which nodes are visited.
    heap_index: index of the node being sampled.
    bounds: two-element array `[lower, upper]` for the node's interval.
    log_time0: log-time passed to `trunc_gumbel` as the (negated) bound.

  Returns:
    `(log_time, x, u)`: the node's log-time, its sample, and the uniform
    value in `bounds` that the sample was mapped from.
  """
  # Executes only while jit traces the function, so it reveals retracing.
  print(f"tracing sampler")

  node_key = random.fold_in(base_key, heap_index)
  time_key, uniform_key = random.split(node_key, num=2)

  lower = bounds[0]
  width = bounds[1] - lower

  log_time = -trunc_gumbel(time_key, shape=(), loc=jnp.log(width), bound=-log_time0)

  # Uniform draw on the node's interval, pushed through the standard normal
  # quantile function.
  u = lower + width * random.uniform(uniform_key, shape=())
  x = normal_dist.ppf(u, 0., 1.)

  return log_time, x, u


@partial(jit, static_argnames=['gauss_dr'])
def sac_gauss_encoder_efficient_integrator_loop_body(
                               x: Array,
                               log_time: Array,
                               gauss_dr: IsotropicGaussianDensityRatio):
  """Compute the two quantities compared in the acceptance test.

  Args:
    x: candidate sample.
    log_time: log-time of the candidate.
    gauss_dr: density-ratio object (static jit argument) providing
      `log_ratio` and `inv_stretch_ode`.

  Returns:
    `(log_r, log_h)`: the log density ratio at `x`, and the log of the final
    state of integrating `inv_stretch_ode` from 0 to `exp(log_time)`.
  """
  # Executes only while jit traces the function, so it reveals retracing.
  print(f"tracing integrator")

  # The integration always restarts from time 0 with initial state 0.
  time_grid = jnp.array([0., jnp.exp(log_time)])
  h_at_time = odeint(gauss_dr.inv_stretch_ode, 0., time_grid)[-1]

  return gauss_dr.log_ratio(x), jnp.log(h_at_time)


def efficient_sac_gauss_encoder(seed: int,
                                gauss_dr: IsotropicGaussianDensityRatio,
                                max_iter: int = 100) -> tuple[Array, int]:
  """Split-on-sample encoder that identifies the sample by a heap index.

  Points live on a binary tree of dyadic sub-intervals of [0, 1]. Each node's
  point is derived from `seed` and the node's heap index alone, and nodes are
  visited in order of increasing log-time via a priority queue.

  Args:
    seed: integer seed for the base PRNG key.
    gauss_dr: density-ratio object; `r_loc` decides which side of a rejected
      sample is kept.
    max_iter: maximum number of outer steps, and also the maximum number of
      heap pops per outer step.

  Returns:
    `(x, heap_index)`: the accepted sample and the heap index of its node.

  Raises:
    ValueError: if an outer step pops `max_iter` nodes without finding one
      inside the current bounds, or if no sample is accepted within
      `max_iter` outer steps.
  """
  base_key = random.PRNGKey(seed)

  # Interval of uniform values still under consideration; it shrinks after
  # every rejection.
  bounds = jnp.array([0., 1.])

  # Seed the priority queue with the root node (heap index 1, interval [0, 1]).
  root_index = 1
  root_bounds = jnp.array([0., 1.])
  root_log_time, root_x, root_u = sac_gauss_encoder_efficient_sampler_loop_body(
    base_key, root_index, root_bounds, -jnp.inf)

  # Entries are ordered by their first element, the log-time.
  heap = [(root_log_time, root_x, root_u, root_bounds, root_index)]

  for _ in range(max_iter):

    # Find the next arrival that falls inside the current bounds.
    for _ in range(max_iter):
      log_time, x, u, node_bounds, heap_index = heappop(heap)

      node_lower, node_upper = node_bounds[0], node_bounds[1]
      midpoint = (node_lower + node_upper) / 2.

      # Children follow the usual heap numbering: left = 2i, right = 2i + 1.
      children = (
        (2 * heap_index, node_lower, midpoint),
        (2 * heap_index + 1, midpoint, node_upper),
      )

      for child_index, child_lower, child_upper in children:
        # Only expand children whose interval still meets the bounds.
        if is_interval_intersection(child_lower, child_upper, bounds[0], bounds[1]):
          child_bounds = jnp.array([child_lower, child_upper])
          child_log_time, child_x, child_u = sac_gauss_encoder_efficient_sampler_loop_body(
            base_key, child_index, child_bounds, log_time)

          heappush(heap, (child_log_time, child_x, child_u, child_bounds, child_index))

      if bounds[0] < u < bounds[1]:
        break

    else:
      raise ValueError("Inner loop didn't terminate")

    log_r, log_h = sac_gauss_encoder_efficient_integrator_loop_body(x, log_time, gauss_dr)

    if log_h < log_r:
      return x, heap_index

    # Rejected: cut the bounds at u, keeping the side towards gauss_dr.r_loc.
    lower, upper = (u, bounds[1]) if x < gauss_dr.r_loc else (bounds[0], u)
    bounds = jnp.array([lower, upper])

  # Only reachable when every outer step ended in a rejection.
  raise ValueError('did not terminate!')


# =============================================================================
# =============================================================================
#
# Dyadic GPRS
#
# =============================================================================
# =============================================================================


@partial(jit, static_argnames=['gauss_dr'])
def ggrs_gauss_encoder_loop_body_with_inv_stretch(base_key: random.PRNGKey,
                               k: int,
                               bounds: Array,
                               log_time0: Array,
                               log_h0: Array,
                               gauss_dr: OneDimensionalTruncatableGaussianDensityRatio):
  """Run one jitted step of the dyadic encoder.

  Args:
    base_key: base PRNG key; `k` is folded in to get this step's key.
    k: step counter.
    bounds: two-element array `[lower, upper]` of uniform values.
    log_time0: log-time of the previous step.
    log_h0: accepted for signature compatibility; not read by this function.
    gauss_dr: density-ratio object (static jit argument) providing
      `log_ratio`, `inv_stretch_ode` and `log_1m_lower_p_mass`.

  Returns:
    `(log_time, x, log_r, log_h, bound_size, bounds, log_cond_right_prob)`:
    the new log-time, the sample, its log density ratio, the log integrated
    value, the width of the incoming bounds, the half of the bounds chosen
    for the next step, and the log-probability with which the right half was
    chosen.
  """
  # Executes only while jit traces the function, so it reveals retracing.
  print(f"tracing")

  step_key = random.fold_in(base_key, k)
  time_key, uniform_key, branch_key = random.split(step_key, num=3)

  lo, hi = bounds[0], bounds[1]
  bound_size = hi - lo
  center = (lo + hi) / 2.

  log_time = -trunc_gumbel(time_key, shape=(), loc=jnp.log(bound_size), bound=-log_time0)

  # Uniform draw on the bounds, mapped through the standard normal quantile
  # function.
  u = lo + bound_size * random.uniform(uniform_key, shape=())
  x = normal_dist.ppf(u, 0., 1.)

  log_r = gauss_dr.log_ratio(x)

  # Express the interval (and its midpoint) in sample space.
  lower = normal_dist.ppf(lo, 0., 1.)
  upper = normal_dist.ppf(hi, 0., 1.)
  middle = normal_dist.ppf(center, 0., 1.)

  # The ODE is restricted to the current interval and always restarts from
  # time 0 with initial state 0.
  truncated_ode = partial(gauss_dr.inv_stretch_ode, lower=lower, upper=upper)
  time_grid = jnp.array([0., jnp.exp(log_time)])
  log_h = jnp.log(odeint(truncated_ode, 0., time_grid)[-1])

  # Mass of the whole interval and of its right half; their log-difference
  # is the normalized probability of continuing in the right half.
  log_whole = gauss_dr.log_1m_lower_p_mass(log_h, lower=lower, upper=upper, log=True)
  log_right = gauss_dr.log_1m_lower_p_mass(log_h, lower=middle, upper=upper, log=True)
  log_cond_right_prob = log_right - log_whole

  go_right = random.bernoulli(branch_key, jnp.exp(log_cond_right_prob)).astype(jnp.int32)

  # Row 0 is the left half, row 1 the right half; go_right selects the row.
  left_half = jnp.array([lo, center])
  right_half = jnp.array([center, hi])
  next_bounds = jnp.array([left_half, right_half])[go_right]

  return log_time, x, log_r, log_h, bound_size, next_bounds, log_cond_right_prob


def binary_gauss_encoder_v2(seed: int,
                         gauss_dr: OneDimensionalTruncatableGaussianDensityRatio,
                         max_iter: int = 100) -> tuple[Array, int, Array, Array]:
  """Dyadic encoder: halve the search interval at random after every step.

  Args:
    seed: integer seed for the base PRNG key.
    gauss_dr: density-ratio object forwarded to
      `ggrs_gauss_encoder_loop_body_with_inv_stretch`.
    max_iter: maximum number of steps.

  Returns:
    `(x, k, neg_log2_bound_size, lcrps)`: the accepted sample, the one-based
    step at which it was accepted, `-log2` of the width of the interval that
    step sampled from, and an array with the `log_cond_right_prob` value of
    every step up to and including the accepting one.

  Raises:
    ValueError: if no sample is accepted within `max_iter` steps. Previously
      the function fell off the end of the loop and returned `None`.
  """
  log_time = -jnp.inf
  log_h = -jnp.inf
  base_key = random.PRNGKey(seed)

  bounds = jnp.array([0., 1.])

  lcrps = []

  for k in range(1, max_iter + 1):
    log_time, x, log_r, log_h, bound_size, bounds, lcrp = ggrs_gauss_encoder_loop_body_with_inv_stretch(
      base_key, k, bounds, log_time, log_h, gauss_dr)

    lcrps.append(lcrp)

    if log_h < log_r:
      return x, k, -jnp.log2(bound_size), jnp.array(lcrps)

  # Match the other encoders in this module: running out of iterations is an
  # error, not a silent `None` that fails later when the caller unpacks it.
  raise ValueError('did not terminate!')


# =============================================================================
# =============================================================================
#
# Heap index conversion
#
# =============================================================================
# =============================================================================


def uniform_heap_index_conversion(seed: int,
                                  target_heap_index: int,
                                  max_inner_iter: int = 100):
  """Simulate the dyadic heap search along the path to `target_heap_index`.

  The target index is decomposed into left/right directions from the root
  (heap numbering: root 1, children of i are 2i and 2i + 1). For each
  direction, nodes are popped from a priority queue ordered by arrival time
  until a point falls inside the current target interval; the interval is
  then cut at that point on the side given by the direction.

  Args:
    seed: integer seed for the PRNG key.
    target_heap_index: positive heap index whose path is followed.
    max_inner_iter: maximum number of heap pops per direction.

  Returns:
    `(t, u, sim_heap_index, ts, us, his)`: arrival time, uniform value and
    heap index of the last popped node, followed by arrays holding the
    arrival times, uniform values and heap indices of every popped node in
    visiting order.

  Raises:
    ValueError: if `target_heap_index` is smaller than 1 (the path
      decomposition would never reach the root), or if a direction uses up
      `max_inner_iter` pops without finding a point in the target interval.
  """
  target_heap_index = int(target_heap_index)
  if target_heap_index < 1:
    raise ValueError(
      f"target_heap_index must be a positive integer, got {target_heap_index}")

  # Peel off the binary digits below the leading one: 0 = left, 1 = right.
  directions = []
  while target_heap_index != 1:
    target_heap_index, direction = divmod(target_heap_index, 2)
    directions.append(direction)

  # Root-to-leaf order, plus a NaN sentinel: the final pass locates the
  # target node's point but matches neither direction, so the target
  # interval is left untouched.
  directions = directions[::-1] + [jnp.nan]

  left_target_bound = 0.
  right_target_bound = 1.

  key = random.PRNGKey(seed)
  key, t_key, u_key = random.split(key, num=3)

  u = random.uniform(u_key)
  t = random.exponential(t_key)

  # Heap entries: (arrival time, uniform value, left bound, right bound,
  # heap index); the root covers [0, 1] and has heap index 1.
  heap = [(t, u, 0., 1., 1)]

  us = []
  ts = []
  his = []

  for direction in directions:
    for _ in range(max_inner_iter):
      t, u, left_sim_bound, right_sim_bound, sim_heap_index = heappop(heap)

      us.append(u)
      ts.append(t)
      his.append(sim_heap_index)
      midpoint = (left_sim_bound + right_sim_bound) / 2.

      children = (
        (2 * sim_heap_index, left_sim_bound, midpoint),
        (2 * sim_heap_index + 1, midpoint, right_sim_bound),
      )

      for child_index, child_left, child_right in children:
        # Only expand children whose interval still meets the target interval.
        if is_interval_intersection(child_left, child_right, left_target_bound, right_target_bound):
          child_size = child_right - child_left

          # Randomness is tied to the child's heap index, not to visit order.
          child_key = random.fold_in(key, child_index)
          child_t_key, child_u_key = random.split(child_key)

          # The exponential increment is scaled by 1 / interval width and
          # added to the parent's arrival time.
          child_t = t + random.exponential(child_t_key) / child_size
          child_u = child_left + child_size * random.uniform(child_u_key)

          heappush(heap, (child_t, child_u, child_left, child_right, child_index))

      if left_target_bound <= u < right_target_bound:
        break

    else:
      # Without this, the search would carry on with a point that lies
      # outside the target interval and silently return a wrong result.
      raise ValueError("Inner loop didn't terminate")

    # Cut the target interval at the accepted point (not at its midpoint).
    if direction == 0:
      right_target_bound = u
    elif direction == 1:
      left_target_bound = u

  return t, u, sim_heap_index, jnp.array(ts), jnp.array(us), jnp.array(his)