from functools import partial

from jax import Array
from jax import jit
import jax.numpy as jnp
import jax.random as random

from .util import trunc_gumbel
from .util import logsubexp
from .uniform_helpers import ModuUniformDensityRatio


@partial(jit, static_argnames=['modu_dr'])
def slow_modu_encoder_loop_body(k, key, log_time, modu_dr):
  """Run one candidate step of the slow (uniform-proposal) encoder.

  Derives a per-step key from ``key`` and the step index ``k``, draws the
  next log arrival time and a Uniform(0, 1) candidate, and evaluates the
  stretched density ratio at that candidate.

  Args:
    k: Step index; folded into ``key`` so the candidate can be regenerated
      from ``(seed, k)`` alone.
    key: Base PRNG key shared by every step.
    log_time: Log arrival time from the previous step (``-inf`` initially).
    modu_dr: Density-ratio object; it is a static ``jit`` argument, so it
      must be hashable.

  Returns:
    Tuple ``(u, f, log_time)``: the candidate, ``modu_dr.stretch`` applied to
    the density ratio at the candidate, and the new log arrival time.
  """
  step_key = random.fold_in(key, k)
  time_key, cand_key = random.split(step_key, num=2)

  # NOTE(review): negating both the sample and the bound presumably keeps the
  # new log time above the previous one (increasing arrival times) — confirm
  # against trunc_gumbel's definition.
  next_log_time = -trunc_gumbel(time_key, shape=(), loc=0., bound=-log_time)

  # The candidate must come from the *second* split output: the decoder
  # regenerates it the same way.
  candidate = random.uniform(cand_key, shape=())
  stretched = modu_dr.stretch(jnp.exp(modu_dr.log_ratio(candidate)))

  return candidate, stretched, next_log_time


def slow_modu_encoder(seed: int,
                      modu_dr: ModuUniformDensityRatio,
                      max_iter: int = 1_000):
  """Encode a sample by scanning uniform candidates until one is accepted.

  Step ``k`` (1-based) draws a fresh log arrival time and a Uniform(0, 1)
  candidate via ``slow_modu_encoder_loop_body``. The first candidate whose
  log arrival time is strictly below the log of its stretched ratio is
  accepted.

  Args:
    seed: Seed for the base PRNG key.
    modu_dr: Density-ratio object forwarded to the loop body.
    max_iter: Maximum number of candidates to try.

  Returns:
    ``(u, k)``: the accepted candidate and its 1-based index. Given ``seed``,
    ``k`` is enough for ``slow_modu_decoder`` to regenerate ``u``.

  Raises:
    ValueError: if no candidate is accepted within ``max_iter`` steps.
  """
  base_key = random.PRNGKey(seed)
  current_log_time = -jnp.inf

  for step in range(1, max_iter + 1):
    candidate, stretched, current_log_time = slow_modu_encoder_loop_body(
        step, base_key, current_log_time, modu_dr)

    accepted = current_log_time < jnp.log(stretched)
    if accepted:
      return candidate, step

  raise ValueError('Did not terminate!')


def slow_modu_decoder(seed: int,
                      k: int):
  """Regenerate the k-th candidate drawn by ``slow_modu_encoder``.

  Mirrors the key derivation in ``slow_modu_encoder_loop_body``:
  - fold ``k`` into the base key built from ``seed``;
  - split it in two;
  - draw a scalar Uniform(0, 1) sample from the second sub-key.
  """
  step_key = random.fold_in(random.PRNGKey(seed), k)
  candidate_key = random.split(step_key)[1]
  return random.uniform(candidate_key, shape=())


@partial(jit, static_argnames=['modu_dr'])
def binary_modu_encoder_loop_body(k, key, bounds, log_time, modu_dr):
  """Run one candidate step of the binary-search encoder.

  Draws the next log arrival time and a candidate uniform on the interval
  ``[bounds[0], bounds[1])``, then evaluates the density ratio and its
  stretched value at that candidate.

  Args:
    k: Step index; folded into ``key`` to derive this step's randomness.
    key: Base PRNG key shared by every step.
    bounds: Length-2 array holding the current interval endpoints.
    log_time: Log arrival time from the previous step (``-inf`` initially).
    modu_dr: Density-ratio object; it is a static ``jit`` argument, so it
      must be hashable.

  Returns:
    Tuple ``(u, r, f, log_time, bound_size, b_key)``:
    - ``u``: the candidate;
    - ``r``: the density ratio at ``u``;
    - ``f``: the stretched ratio;
    - ``log_time``: the new log arrival time;
    - ``bound_size``: the interval width;
    - ``b_key``: an unused sub-key reserved for the branch decision.
  """
  step_key = random.fold_in(key, k)
  time_key, cand_key, branch_key = random.split(step_key, num=3)

  lower, upper = bounds[0], bounds[1]
  width = upper - lower

  # NOTE(review): the location is shifted by log(width); presumably this scales
  # the arrival-time process to the interval's size — confirm against
  # trunc_gumbel.
  next_log_time = -trunc_gumbel(
      time_key, shape=(), loc=jnp.log(width), bound=-log_time)

  # Rescale a Uniform(0, 1) draw onto the current interval.
  candidate = lower + width * random.uniform(cand_key, shape=())

  ratio = jnp.exp(modu_dr.log_ratio(candidate))
  stretched = modu_dr.stretch(ratio)

  return candidate, ratio, stretched, next_log_time, width, branch_key


def binary_modu_encoder(seed: int,
                        modu_dr: ModuUniformDensityRatio,
                        max_iter: int = 100):
  """Encode a sample with a binary-search variant of the modulated encoder.

  Each step draws a candidate uniformly from the current interval
  ``bounds`` (initially [0, 1]) via ``binary_modu_encoder_loop_body``.
  - If the candidate is accepted (log arrival time <= log stretched ratio),
    the search stops.
  - Otherwise the interval is halved and the right half is chosen with a
    probability built from ``modu_dr.width_p`` / ``modu_dr.width_q``.

  Args:
    seed: Seed for the base PRNG key.
    modu_dr: Density-ratio object forwarded to the loop body.
    max_iter: Maximum number of steps to try.

  Returns:
    ``(u, k, depth)``:
    - ``u``: the accepted candidate;
    - ``k``: its 1-based step index;
    - ``depth``: ``-log2`` of the width of the interval ``u`` was drawn from,
      i.e. the number of halvings performed.

  Raises:
    ValueError: if no candidate is accepted within ``max_iter`` steps. This
      matches ``slow_modu_encoder``; previously the function silently
      returned ``None`` here.
  """
  log_time = -jnp.inf
  base_key = random.PRNGKey(seed)

  bounds = jnp.array([0., 1.])

  for k in range(1, max_iter + 1):
    u, r, f, log_time, bound_size, b_key = binary_modu_encoder_loop_body(
        k, base_key, bounds, log_time, modu_dr)

    if log_time <= jnp.log(f):
      return u, k, -jnp.log2(bound_size)

    # NOTE(review): in float32 the halving eventually stops shrinking the
    # interval (center == an endpoint), which would likely trip the p-measure
    # assertion below. Confirm max_iter stays within that range.
    bound_center = (bounds[0] + bounds[1]) / 2
    bound_left = jnp.array([bounds[0], bound_center])
    bound_right = jnp.array([bound_center, bounds[1]])

    log_p_measure = jnp.log(modu_dr.width_p(r, bounds[0], bounds[1]))
    log_q_measure = jnp.log(modu_dr.width_q(r, bounds[0], bounds[1]))

    assert log_p_measure > -jnp.inf, 'current interval has zero p-measure'

    log_right_p_measure = jnp.log(
        modu_dr.width_p(r, bound_right[0], bound_right[1]))
    log_right_q_measure = jnp.log(
        modu_dr.width_q(r, bound_right[0], bound_right[1]))

    log_g_inv = jnp.log(modu_dr.inv_stretch(jnp.exp(log_time)))

    # Assuming logsubexp(a, b) == log(exp(a) - exp(b)), this gives
    #   P(right) = (Q_R - g_inv * P_R) / (Q - g_inv * P),
    # where P/Q are width_p/width_q of the current interval and P_R/Q_R
    # are the same quantities for its right half.
    log_norm_const = logsubexp(log_q_measure, log_g_inv + log_p_measure)
    log_right_prob = logsubexp(log_right_q_measure,
                               log_g_inv + log_right_p_measure)

    right_prob = jnp.exp(log_right_prob - log_norm_const)

    assert 0. <= right_prob <= 1., f'invalid branch probability {right_prob}'

    b = random.bernoulli(b_key, right_prob).astype(jnp.int32)

    # b == 1 selects the right half, b == 0 the left half.
    bounds = [bound_left, bound_right][b]

  raise ValueError('Did not terminate!')
