################################################################################
# experilog/stats.py
#
# 
# 
# 
# 2024
#
# Helper functions for statistical calculations.

import numpy       as np
import scipy.stats as sps

from typing import Tuple

def stdevs_percentile(
    # Arguments:
    p: float
  ) -> float:
  """
  Calculates the two-sided critical value (z-score) of a unit normal
  distribution for the central coverage probability ``p``.

  The result ``z`` satisfies ``P(-z <= Z <= z) = p`` for ``Z ~ N(0, 1)``.
  In other words, the interval ``mean +/- z * sigma`` contains a fraction
  ``p`` of the distribution.

  Examples: ``p = 0.95`` gives ``z ~= 1.96``, ``p = 0.0`` gives ``0.0``, and
  ``p = 1.0`` gives ``inf``. This is NOT the one-sided z-score of the
  ``p``-th percentile (which would be ~1.645 for ``p = 0.95``).

  Args:
    p (float):
      The central coverage probability / confidence level, in the closed
      range [0.0, 1.0].

  Returns:
    float:
      The number of standard deviations from the mean to either edge of the
      central interval.

  Raises:
    ValueError:
      If ``p`` is outside [0.0, 1.0] or is NaN.
  """
  # Reject out-of-range inputs explicitly. Otherwise p > 1 silently yields
  # NaN and p < 0 yields a meaningless negative width. NaN also fails this
  # comparison chain, so it is rejected too.
  if not 0.0 <= p <= 1.0:
    raise ValueError(f"p must be in the range [0.0, 1.0], got {p!r}")
  # The normal distribution is symmetric, so ``upper == -lower`` up to
  # floating-point rounding. Averaging the two half-widths keeps the result
  # balanced between both tails.
  upper = sps.norm.ppf(0.5 + (p / 2.0))
  lower = sps.norm.ppf(0.5 - (p / 2.0))
  return float((upper - lower) / 2.0)

def estimate_confidence_interval(
    # Arguments:
    mu:      float,
    sigma_m: float,
    sigma_p: float
  ) -> Tuple[float, float]:
  """
  Builds a symmetric, normal-approximation confidence interval around a
  sample mean: ``mu +/- sigma_p * sigma_m``.

  The normal approximation is only reasonable for large sample sizes.

  Args:
    mu (float):
      The sample mean, which becomes the centre of the interval.
    sigma_m (float):
      The standard error of the mean (ideally, the unbiased estimate).
    sigma_p (float):
      The critical value, i.e. how many standard errors the interval extends
      on each side (e.g. ~1.96 for a two-sided 95% interval).

  Returns:
    Tuple[float, float]:
      The (lower, upper) confidence bounds, as plain Python floats.
  """
  # Half-width of the interval: critical value times the standard error.
  half_width = sigma_m * sigma_p
  lower_bound = float(mu - half_width)
  upper_bound = float(mu + half_width)
  return lower_bound, upper_bound

def standard_error(
    # Arguments:
    sigma: float,
    N:     int
  ) -> float:
  """
  Calculates the standard error of the mean, ``sigma / sqrt(N)``.

  Args:
    sigma (float):
      The standard deviation (sample or population).
    N (int):
      The sample / population size. Must be positive.

  Returns:
    float:
      The standard error of the mean (sigma_m).

  Raises:
    ValueError:
      If ``N`` is zero or negative.
  """
  # Without this guard, N == 0 produces inf and N < 0 produces NaN. In both
  # cases NumPy emits only a RuntimeWarning rather than raising an error.
  if N <= 0:
    raise ValueError(f"N must be a positive sample size, got {N!r}")
  return float(sigma / np.sqrt(float(N)))