"""Samplers that draw batches of random indices from ``range(max_size)``."""
import numpy as np
from typing import Any

# Annotation alias for array-like arguments (e.g. BalancedSampler's `probs`).
# It is plain typing.Any, so it documents intent without constraining types.
Array = Any


class RandSampler:
  """Uniform index sampler (with replacement).

  Every call to `sample` draws `batch_size` integers uniformly from
  `range(max_size)` using numpy's global random state, so the same index
  may appear more than once in a batch.
  """

  def __init__(self, max_size: int, batch_size: int = 1) -> None:
    # Exclusive upper bound of the sampled indices.
    self._max_size = max_size
    # Number of indices returned per `sample` call.
    self._batch_size = batch_size

  def sample(self) -> np.ndarray:
    """Return a fresh array of `batch_size` indices in [0, max_size)."""
    upper_bound = self._max_size
    how_many = self._batch_size
    return np.random.randint(upper_bound, size=how_many)


class BalancedSampler(object):
  """Weighted index sampler (without replacement).

  Each call to `sample` draws `batch_size` distinct indices from
  `range(max_size)` using numpy's global random state, weighting index `i`
  by `probs[i]`.

  Args:
    probs: One non-negative, finite weight per index (length `max_size`).
      The weights need not sum to 1; they are normalized (in float64) here.
    max_size: Number of candidate indices.
    batch_size: Number of distinct indices returned per `sample` call.
      Defaults to 1, matching `RandSampler`.

  Raises:
    ValueError: If `probs` is not 1-D with exactly `max_size` entries, contains
      negative or non-finite values, or sums to zero. Checking here fails fast
      instead of storing NaNs (zero sum) or deferring the error to `sample`.
  """

  def __init__(self, probs: Array, max_size: int, batch_size: int = 1) -> None:
    weights = np.asarray(probs, dtype=float)
    if weights.ndim != 1 or weights.shape[0] != max_size:
      raise ValueError(
          f"probs must be 1-D with length max_size={max_size}, "
          f"got shape {weights.shape}")
    if not np.all(np.isfinite(weights)) or np.any(weights < 0):
      raise ValueError("probs must contain only finite, non-negative values")
    total = weights.sum()
    if total <= 0:
      # Normalizing by zero would silently produce NaN probabilities.
      raise ValueError("probs must have a positive sum")
    self._max_size = max_size
    self._batch_size = batch_size
    self._probs = weights / total

  def sample(self) -> np.ndarray:
    """Return `batch_size` distinct indices drawn with the normalized weights.

    numpy raises ValueError here if fewer than `batch_size` indices have a
    non-zero weight, since sampling is without replacement.
    """
    return np.random.choice(
        self._max_size, size=self._batch_size, replace=False, p=self._probs)
