import joblib
from joblib import Parallel, delayed
import numpy as np
import time


def posify_mtx(Sigma: np.ndarray, tol: float=1e-9):
  """Return a symmetric, slightly regularized copy of a covariance matrix.

  The input is replaced by its symmetric part (Sigma + Sigma^T) / 2 and a
  fixed jitter of 1e-6 is added to the diagonal, which keeps samplers such as
  ``np.random.multivariate_normal`` happy when round-off has broken symmetry.

  Args:
    Sigma: square matrix.
    tol: currently unused; the diagonal jitter is hard-coded to 1e-6.
      NOTE(review): kept only for interface compatibility.

  Returns:
    A new array of the same shape as ``Sigma``.
  """
  dim = Sigma.shape[0]
  symmetric_part = 0.5 * (Sigma + Sigma.T)
  jitter = 1e-6 * np.eye(dim)
  return symmetric_part + jitter


def sigmoid(x):
  """Logistic function 1 / (1 + exp(-x)) evaluated on inputs clipped to [-10, 10].

  Clipping bounds the output to roughly [4.5e-5, 1 - 4.5e-5], so downstream
  weights such as s * (1 - s) never collapse to exactly zero.
  Works elementwise on scalars and arrays.
  """
  clipped = np.clip(x, -10, 10)
  return 1 / (1 + np.exp(-clipped))


def irls(X, y, theta0, Lambda0, theta, irls_error=1e-3, irls_num_iter=30):
  """MAP estimate for Bayesian logistic regression via Newton / IRLS steps.

  Follows Sections 4.3.3 and 4.5.1 of Bishop (2006), Pattern Recognition and
  Machine Learning. The prior is Gaussian with mean ``theta0`` and precision
  ``Lambda0``; the likelihood term uses gradient (sigmoid(X theta) - y) X.

  Args:
    X: design matrix, one row per observation (may have zero rows).
    y: observed rewards, one per row of X (presumably in {0, 1}).
    theta0: prior mean.
    Lambda0: prior precision matrix.
    theta: starting iterate.
    irls_error: stop once the Euclidean norm of a step is below this.
    irls_num_iter: maximum number of Newton steps.

  Returns:
    (theta, Hessian, converged): the last iterate; the Hessian of the negative
    log-posterior evaluated at the iterate *before* the last step; and whether
    the step-size criterion was met within ``irls_num_iter`` steps.
  """
  converged = False
  for _ in range(irls_num_iter):
    previous = np.copy(theta)

    if y.size > 0:
      probs = sigmoid(X.dot(theta))
      weights = probs * (1 - probs)  # IRLS weights s(1 - s)
      Hessian = (X * weights[:, np.newaxis]).T.dot(X) + Lambda0
      grad = (probs - y).dot(X) + Lambda0.dot(theta - theta0)
    else:
      # no observations yet: only the prior contributes
      Hessian = Lambda0
      grad = Lambda0.dot(theta - theta0)

    # Newton step theta - H^{-1} grad, expressed as a single linear solve
    theta = np.linalg.solve(Hessian, Hessian.dot(theta) - grad)

    if np.linalg.norm(theta - previous) < irls_error:
      converged = True
      break

  return theta, Hessian, converged


# Bandit environments and simulator

class GLMBandit(object):
  """Generalized linear bandit with a per-round random subset of arms.

  Every call to ``randomize`` draws ``K`` distinct rows of the feature pool
  ``X`` as the current arms, then draws their stochastic rewards. Rewards are
  Gaussian around X theta for ``mean_function="linear"`` and Bernoulli with
  mean sigmoid(X theta) for ``mean_function="logistic"``.
  """

  def __init__(self, X, K, theta, mean_function="linear", sigma=0.5):
    self.X_all = np.copy(X)  # feature pool, one row per candidate arm
    self.K = K  # arms offered per round
    self.d = self.X_all.shape[1]  # feature dimension
    self.theta = np.copy(theta)  # true model parameter
    self.mean_function = mean_function  # "linear" or "logistic"
    self.sigma = sigma  # Gaussian noise std (used by the linear model only)

    self.randomize()

  def randomize(self):
    """Draw this round's K arms, their mean rewards and realized rewards."""
    picked = np.random.choice(self.X_all.shape[0], self.K, replace=False)
    self.X = self.X_all[picked, :]  # K x d features of the current arms

    if self.mean_function == "linear":
      self.mu = self.X.dot(self.theta)
      noise = self.sigma * np.random.randn(self.K)
      self.rt = self.mu + noise
    elif self.mean_function == "logistic":
      self.mu = sigmoid(self.X.dot(self.theta))
      self.rt = (np.random.rand(self.K) < self.mu).astype(float)
    else:
      raise Exception("Unknown mean function.")

    self.best_arm = np.argmax(self.mu)  # arm with the highest mean reward

  def reward(self, arm):
    """Realized reward of ``arm`` in the current round."""
    return self.rt[arm]

  def regret(self, arm):
    """Realized reward gap between the best-mean arm and ``arm``."""
    best_reward = self.rt[self.best_arm]
    return best_reward - self.rt[arm]

  def pregret(self, arm):
    """Expected reward gap between the best-mean arm and ``arm``."""
    best_mean = self.mu[self.best_arm]
    return best_mean - self.mu[arm]

  def print(self):
    return "GLM bandit: %d dimensions, %d arms, %s mean function" % (self.d, self.K, self.mean_function)


class LinBandit(object):
  """Linear bandit.

  Two modes, selected by ``features``:
    * ``features is None``: the arms are the fixed rows of ``X`` and each call
      to ``randomize`` only redraws Gaussian reward noise around X theta.
    * ``features`` given: each round ``K`` rows are sampled without
      replacement from ``features``. If ``labels`` is None, rewards are
      X theta plus Gaussian noise; otherwise the sampled labels are used as
      both the mean and the realized rewards.
  """

  def __init__(self, X, theta, sigma=0.5, labels=None, features=None):
    self.X = np.copy(X)  # K x d matrix of arm features
    self.K = self.X.shape[0]  # number of arms per round
    self.d = self.X.shape[1]  # number of features

    self.theta = np.copy(theta)  # model parameter
    self.sigma = sigma  # reward noise
    self.labels = labels  # optional per-row rewards for `features`
    self.features = features  # optional pool of arm features

    self.mu = self.X.dot(self.theta)  # mean rewards of all arms
    self.randomize()

    self.best_arm = np.argmax(self.mu)  # optimal arm

  def randomize(self):
    """Draw this round's rewards (and, with a feature pool, this round's arms)."""
    if self.features is None:
      self.rt = self.mu + self.sigma * np.random.randn(self.K)
      return

    arms = np.random.choice(len(self.features), self.K, replace=False)
    # `is None`, not `== None`: labels is indexed with an integer array below,
    # so it is an ndarray-like, and `ndarray == None` compares elementwise,
    # which makes the `if` raise "truth value of an array is ambiguous".
    if self.labels is None:
      self.X = self.features[arms]  # K x d matrix of arm features
      self.mu = self.X.dot(self.theta)  # mean rewards of all arms
      self.best_arm = np.argmax(self.mu)  # optimal arm
      self.rt = self.mu + self.sigma * np.random.randn(self.K)
    else:
      self.rt = self.labels[arms]
      self.mu = self.labels[arms]
      self.X = self.features[arms]
      self.best_arm = np.argmax(self.mu)  # optimal arm

  def reward(self, arm):
    # instantaneous reward of the arm
    return self.rt[arm]

  def regret(self, arm):
    # instantaneous (realized) regret of the arm
    return self.rt[self.best_arm] - self.rt[arm]

  def pregret(self, arm):
    # expected regret of the arm
    return self.mu[self.best_arm] - self.mu[arm]

  def print(self):
    return "Linear bandit: %d dimensions, %d arms" % (self.d, self.K)


def evaluate_one(Alg, params, env, n, period_size=1, return_logs=True):
  """One run of a bandit algorithm.

  Each round the environment is re-randomized, the agent picks an arm and is
  updated with that arm's reward, and the realized regret is accumulated.

  Args:
    Alg: algorithm class constructed as ``Alg(env, n, params)``.
    params: parameter overrides forwarded to ``Alg``.
    env: bandit environment (needs ``randomize``, ``reward``, ``regret``).
    n: number of rounds.
    period_size: consecutive rounds whose regret is summed into one entry.
    return_logs: also return the algorithm instance.

  Returns:
    ``regret`` of length ceil(n / period_size), where entry k sums rounds
    [k * period_size, (k + 1) * period_size); the last entry covers a
    shorter period when period_size does not divide n. Returned as
    ``(regret, alg)`` when ``return_logs`` is True.
  """
  alg = Alg(env, n, params)

  # ceiling division: a trailing partial period still gets its own slot
  # (with floor division, t // period_size ran past the end of the array)
  num_periods = -(-n // period_size)
  regret = np.zeros(num_periods)
  for t in range(n):
    env.randomize()  # generate this round's state

    # take action and update agent
    arm = alg.get_arm(t)
    alg.update(t, arm, env.reward(arm))

    # track performance
    regret[t // period_size] += env.regret(arm)

  if return_logs:
    return regret, alg
  return regret


def evaluate(Alg, params, env, n=1000, period_size=1, printout=True, return_logs=True, n_jobs=-1):
  """Multiple runs of a bandit algorithm, one per environment, via joblib.

  Args:
    Alg: algorithm class; must expose a static ``print()`` returning its name.
    params: parameter overrides forwarded to every ``Alg`` instance.
    env: sequence of environments (supports ``len`` and integer indexing);
      run ``ex`` uses ``env[ex]``.
    n: number of rounds per run.
    period_size: consecutive rounds whose regret is summed into one entry.
    printout: print timing and a regret summary to stdout.
    return_logs: also return the algorithm instance of every run.
    n_jobs: passed to ``joblib.Parallel``.

  Returns:
    ``regret`` of shape (ceil(n / period_size), len(env)); returned as
    ``(regret, alg)`` with the list of algorithm instances when
    ``return_logs`` is True.
  """
  if printout:
    print("Evaluating %s" % Alg.print(), end="")
  start = time.time()

  num_exps = len(env)
  # ceil(n / period_size) periods so a trailing partial period is kept;
  # identical to n // period_size whenever period_size divides n
  num_periods = -(-n // period_size)
  regret = np.zeros((num_periods, num_exps))
  alg = num_exps * [None]

  output = Parallel(n_jobs=n_jobs)(
    delayed(evaluate_one)(Alg, params, env[ex], n, period_size, return_logs)
    for ex in range(num_exps))

  for ex, result in enumerate(output):
    if return_logs:
      regret[:, ex] = result[0]
      alg[ex] = result[1]
    else:
      regret[:, ex] = result

  if printout:
    print(" %.1f seconds" % (time.time() - start))
    total_regret = regret.sum(axis=0)
    # mean +/- standard error of the total regret across runs
    print("Regret: %.2f +/- %.2f (median: %.2f, max: %.2f, min: %.2f)" %
      (total_regret.mean(), total_regret.std() / np.sqrt(num_exps),
      np.median(total_regret), total_regret.max(), total_regret.min()))

  if return_logs:
    return regret, alg
  return regret


# Bandit algorithms

class LinBanditAlg:
  """Base class for Bayesian linear bandit agents with a Gaussian prior.

  Keeps the Gaussian posterior over theta in natural form: precision
  ``Lambda`` and ``B``, with posterior mean inv(Lambda) @ B. Entries of
  ``params`` override any default attribute (arrays are copied).
  """

  def __init__(self, env, n, params):
    self.env = env  # environment the agent acts in
    self.K = self.env.K  # number of arms
    self.d = self.env.d  # number of features
    self.n = n  # horizon
    self.theta0 = np.zeros(self.d)  # prior mean of theta
    self.Sigma0 = np.eye(self.d)  # prior covariance of theta
    self.sigma = 0.5  # reward noise std

    # caller overrides; ndarrays are copied so the agent never aliases them
    for name, value in params.items():
      setattr(self, name, np.copy(value) if isinstance(value, np.ndarray) else value)

    # posterior statistics, initialized to the prior
    self.Lambda = np.linalg.inv(self.Sigma0)
    self.B = self.Lambda.dot(self.theta0)

  def update(self, t, arm, r):
    """Fold reward ``r`` observed for ``arm`` into ``Lambda`` and ``B``."""
    features = self.env.X[arm, :]
    noise_var = np.square(self.sigma)
    self.Lambda += np.outer(features, features) / noise_var
    self.B += features * r / noise_var


class LinTS(LinBanditAlg):
  """Linear Thompson sampling with a Gaussian prior and Gaussian noise."""

  def get_arm(self, t):
    """Sample theta from the Gaussian posterior and play its greedy arm."""
    post_cov = np.linalg.inv(self.Lambda)
    post_mean = post_cov.dot(self.B)

    # posify_mtx symmetrizes the covariance and adds a small diagonal jitter
    sample = np.random.multivariate_normal(post_mean, posify_mtx(post_cov))
    self.theta_tilde = sample
    self.mu = self.env.X.dot(sample)  # sampled mean reward of every arm
    return np.argmax(self.mu)

  @staticmethod
  def print():
    return "LinTS"


class LinUCB(LinBanditAlg):
  """LinUCB with a confidence width from Abbasi-Yadkori et al. (2011)."""

  def __init__(self, env, n, params):
    LinBanditAlg.__init__(self, env, n, params)

    # a single width, computed once for the full horizon n
    self.cew = self.confidence_ellipsoid_width(n)

  def confidence_ellipsoid_width(self, t):
    """Confidence-ellipsoid radius after ``t`` rounds.

    Theorem 2 in Abbasi-Yadkori, Pal and Szepesvari (2011), "Improved
    Algorithms for Linear Stochastic Bandits", with failure probability
    1 / n, regularizer lambda taken from the prior, R = sigma, and the
    assumed parameter-norm bound S = sqrt(d).
    """
    delta = 1 / self.n
    max_norm = np.amax(np.linalg.norm(self.env.X, axis=1))  # L: largest arm norm
    # V = sigma^2 * (prior covariance)^{-1}; lambda is its largest eigenvalue
    reg = np.square(self.sigma) * np.linalg.eigvalsh(np.linalg.inv(self.Sigma0)).max()
    noise_scale = self.sigma  # R
    param_bound = np.sqrt(self.d)  # S
    log_term = np.log((1 + t * np.square(max_norm) / reg) / delta)
    return np.sqrt(reg) * param_bound + noise_scale * np.sqrt(self.d * log_term)

  def get_arm(self, t):
    """Play the arm with the largest upper confidence bound."""
    post_cov = np.linalg.inv(self.Lambda)
    post_mean = post_cov.dot(self.B)

    inv_v = post_cov / np.square(self.sigma)  # V^{-1} = posterior covariance / sigma^2
    feats = self.env.X
    bonus = np.sqrt((feats.dot(inv_v) * feats).sum(axis=1))  # ||x||_{V^{-1}} per arm
    self.mu = feats.dot(post_mean) + self.cew * bonus
    return np.argmax(self.mu)

  @staticmethod
  def print():
    return "LinUCB"


class LinGreedy(LinBanditAlg):
  """Epsilon-greedy on the posterior mean of a Bayesian linear model."""

  def get_arm(self, t):
    """Explore uniformly with a decaying probability, else act greedily.

    The exploration probability is 0.05 * sqrt(n / (t + 1)) / 2. Summed over
    t = 0..n-1 this is about 0.05 * n, i.e. roughly 5% exploration over the
    horizon. Early rounds can exceed probability 1 when n is large.
    """
    self.mu = np.zeros(self.K)

    if np.random.rand() < 0.05 * np.sqrt(self.n / (t + 1)) / 2:
      # explore: make a uniformly random arm the argmax.
      # np.inf, not np.Inf: the capitalized alias was removed in NumPy 2.0.
      self.mu[np.random.randint(self.K)] = np.inf
    else:
      # exploit: posterior mean inv(Lambda) @ B via a linear solve
      theta = np.linalg.solve(self.Lambda, self.B)
      self.mu = self.env.X.dot(theta)

    arm = np.argmax(self.mu)
    return arm

  @staticmethod
  def print():
    return "Linear e-greedy"


class MixTS:
  """Thompson sampling under a Gaussian-mixture prior on theta.

  Each mixture component runs its own LinTS agent. Each round, a component
  is sampled from its posterior probability, computed from the component's
  Gaussian marginal likelihood, and that component's LinTS picks the arm.
  """

  def __init__(self, env, n, params):
    self.env = env  # environment the agent acts in
    self.K = self.env.K  # number of arms
    self.d = self.env.d  # number of features
    self.n = n  # horizon
    self.num_mix = 2  # default number of mixture components
    self.p0 = np.ones(self.num_mix) / self.num_mix  # prior mixture weights
    self.theta0 = np.zeros((self.num_mix, self.d))  # component prior means
    self.Sigma0 = np.tile(np.eye(self.d), (self.num_mix, 1, 1))  # component prior covariances
    self.sigma = 0.5  # reward noise std

    # caller overrides; ndarrays are copied
    for name, value in params.items():
      setattr(self, name, np.copy(value) if isinstance(value, np.ndarray) else value)

    if self.p0.ndim == 2:
      # a batch of mixture priors was supplied: use one, chosen uniformly at random
      pick = np.random.randint(self.p0.shape[0])
      self.p0 = self.p0[pick, :]
      self.theta0 = self.theta0[pick, :, :]
      self.Sigma0 = self.Sigma0[pick, :, :, :]

    self.num_mix = self.p0.size

    # one LinTS agent per mixture component
    self.algs = [
      LinTS(self.env, self.n,
        {"theta0": self.theta0[i, :], "Sigma0": self.Sigma0[i, :, :], "sigma": self.sigma})
      for i in range(self.num_mix)]

  def update(self, t, arm, r):
    """Forward the observation to every component agent."""
    for component in self.algs:
      component.update(t, arm, r)

  def get_arm(self, t):
    """Sample a mixture component from its posterior and let its LinTS choose."""
    _, prior_logdet = np.linalg.slogdet(self.Sigma0)
    prior_prec = np.linalg.inv(self.Sigma0)
    log_2pi = np.log(2 * np.pi)

    # log marginal likelihood + log prior weight of each component; terms that
    # are identical for all components (e.g. -||y||^2 / (2 sigma^2)) are omitted
    logp = np.zeros(self.num_mix)
    for i, comp in enumerate(self.algs):
      post_cov = np.linalg.inv(comp.Lambda)
      _, post_logdet = np.linalg.slogdet(post_cov)
      logdet_term = (self.d * log_2pi + post_logdet) - \
        ((self.d + t) * log_2pi + 2 * t * np.log(self.sigma) + prior_logdet[i])
      quad_term = comp.B.T.dot(post_cov).dot(comp.B) - \
        self.theta0[i, :].T.dot(prior_prec[i, :, :]).dot(self.theta0[i, :])
      logp[i] = 0.5 * logdet_term + 0.5 * quad_term + np.log(self.p0[i])

    # normalize in log space for numerical stability
    self.p = np.exp(logp - logp.max())
    self.p /= self.p.sum()

    self.component_tilde = np.random.choice(self.num_mix, p=self.p)
    return self.algs[self.component_tilde].get_arm(t)

  @staticmethod
  def print():
    return "MixTS"


class LinDiffTS(LinBanditAlg):
  """Linear Thompson sampling with a learned (diffusion-model) prior.

  Relies on a ``prior`` attribute that is not set in this class; it
  presumably arrives through ``params``. The attribute must provide
  ``posterior_sample_map(fn)``, where ``fn(theta0, Sigma0)`` returns a
  Gaussian (mean, covariance) conditioned on the data.
  """

  def __init__(self, env, n, params):
    LinBanditAlg.__init__(self, env, n, params)

  def map_estimator(self, theta0, Sigma0, t):
    """Combine the Gaussian N(theta0, Sigma0) with the agent's statistics.

    Returns ``(mean, covariance)`` of the product Gaussian. ``t`` is unused.
    """
    # NOTE(review): self.Lambda / self.B are initialized from the base-class
    # prior (inv(self.Sigma0), inv(self.Sigma0) @ self.theta0; identity and
    # zero by default), so that prior is counted on top of the supplied one.
    # Confirm this is intended.
    prior_prec = np.linalg.inv(Sigma0)
    post_cov = np.linalg.inv(prior_prec + self.Lambda)
    post_mean = post_cov.dot(prior_prec.dot(theta0) + self.B)
    return post_mean, post_cov

  def get_arm(self, t):
    """Draw theta from the prior's posterior sampler and play its greedy arm."""
    def conditional_map(theta0, Sigma0):
      return self.map_estimator(theta0, Sigma0, t)

    # presumably returns a batch of samples, one per row; the first row is used
    self.theta_tilde = self.prior.posterior_sample_map(conditional_map)[0, :]
    self.mu = self.env.X.dot(self.theta_tilde)
    return np.argmax(self.mu)

  @staticmethod
  def print():
    return "LinDiffTS"


class LinDiffTSChung:
  """Linear bandit agent that samples theta by guiding a learned prior with
  the log-likelihood gradient.

  Relies on a ``prior`` attribute that is not set in this class; it
  presumably arrives through ``params``. The attribute must provide
  ``posterior_sample_grad(fn)``, where ``fn(theta0)`` returns a gradient.
  """

  def __init__(self, env, n, params):
    self.env = env  # environment the agent acts in
    self.K = self.env.K  # number of arms
    self.d = self.env.d  # number of features
    self.n = n  # horizon
    self.sigma = 0.5  # reward noise

    # override default values (ndarrays are copied)
    for attr, val in params.items():
      if isinstance(val, np.ndarray):
        setattr(self, attr, np.copy(val))
      else:
        setattr(self, attr, val)

    # sufficient statistics
    self.Xh = np.zeros((self.n, self.d))  # history of feature vectors
    self.yh = np.zeros(self.n)  # history of rewards

  def update(self, t, arm, r):
    """Record the features of ``arm`` and its reward ``r`` at round ``t``."""
    self.Xh[t, :] = self.env.X[arm, :]
    self.yh[t] = r

  def loglik_grad(self, theta0, t, linear_growth=False):
    """Normalized likelihood-ascent direction at ``theta0`` from the first ``t`` rounds.

    Returns X^T r / (||r|| * sigma^2) with residual r = y - X theta0, i.e. the
    gradient of -||y - X theta0|| scaled by 1 / sigma^2. Multiplied by sqrt(t)
    when ``linear_growth`` is True. Returns zeros when there is no data
    (t == 0) or when theta0 fits the data exactly (zero residual), where the
    norm's gradient is undefined and the previous code produced 0/0 = NaN.
    """
    if t == 0:
      return np.zeros(self.d)

    X = self.Xh[: t, :]
    y = self.yh[: t]
    residual = y - X.dot(theta0)
    res_norm = np.linalg.norm(residual)
    if res_norm == 0:
      return np.zeros(self.d)

    grad = X.T.dot(residual) / res_norm / np.square(self.sigma)
    if linear_growth:
      grad *= np.sqrt(t)
    return grad

  def get_arm(self, t):
    """Draw theta via the prior's gradient-guided sampler and play its greedy arm."""
    grad_lambda = lambda theta0: self.loglik_grad(theta0, t)
    # presumably returns a batch of samples, one per row; the first row is used
    self.theta_tilde = self.prior.posterior_sample_grad(grad_lambda)[0, :]
    self.mu = self.env.X.dot(self.theta_tilde)

    arm = np.argmax(self.mu)
    return arm

  @staticmethod
  def print():
    return "LinDiffTSChung"


class LogBanditAlg:
  """Base class for Bayesian logistic bandit agents with a Gaussian prior.

  Stores the full history of played features and rewards and refits the MAP
  estimate with ``irls`` when asked. Entries of ``params`` override any
  default attribute (arrays are copied).
  """

  def __init__(self, env, n, params):
    self.env = env  # environment the agent acts in
    self.K = self.env.K  # number of arms
    self.d = self.env.d  # number of features
    self.n = n  # horizon
    self.theta0 = np.zeros(self.d)  # prior mean of theta
    self.Sigma0 = np.eye(self.d)  # prior covariance of theta

    # caller overrides; ndarrays are copied so the agent never aliases them
    for name, value in params.items():
      setattr(self, name, np.copy(value) if isinstance(value, np.ndarray) else value)

    self.Lambda0 = np.linalg.inv(self.Sigma0)  # prior precision
    self.Xh = np.zeros((self.n, self.d))  # row t: features played at round t
    self.yh = np.zeros(self.n)  # entry t: reward observed at round t

    self.irls_theta = np.zeros(self.d)  # warm start for the next irls call

  def update(self, t, arm, r):
    """Record the features of ``arm`` and its reward ``r`` at round ``t``."""
    self.Xh[t, :] = self.env.X[arm, :]
    self.yh[t] = r

  def solve(self, t):
    """Fit the MAP estimate on the first ``t`` observations.

    Returns ``(theta, Hessian, converged)`` from ``irls``. A converged
    solution becomes the next warm start; otherwise the next fit starts
    from zero.
    """
    theta, Hessian, converged = irls(
      self.Xh[: t, :], self.yh[: t], self.theta0, self.Lambda0, np.copy(self.irls_theta))
    self.irls_theta = np.copy(theta) if converged else np.zeros(self.d)
    return theta, Hessian, converged


class LogTS(LogBanditAlg):
  """Logistic Thompson sampling with a Laplace approximation to the posterior."""

  def get_arm(self, t):
    """Sample theta from N(MAP, inv(Hessian)) and play its greedy arm."""
    mode, hessian, _ = self.solve(t)
    cov = np.linalg.inv(hessian)

    sample = np.random.multivariate_normal(mode, cov)
    # ranking by the linear score X theta is equivalent to ranking by
    # sigmoid(X theta), since the sigmoid is monotone
    self.mu = self.env.X.dot(sample)
    return np.argmax(self.mu)

  @staticmethod
  def print():
    return "LogTS"


class LogDiffTS(LogBanditAlg):
  """Logistic Thompson sampling with a learned (diffusion-model) prior.

  Relies on a ``prior`` attribute that is not set in this class; it
  presumably arrives through ``params``. The attribute must provide
  ``posterior_sample_map(fn)``, where ``fn(theta0, Sigma0)`` returns a
  Laplace-approximate (mean, covariance).
  """

  def __init__(self, env, n, params):
    LogBanditAlg.__init__(self, env, n, params)

  def map_estimator(self, theta0, Sigma0, t):
    """Laplace approximation of the posterior under the prior N(theta0, Sigma0).

    Fits ``irls`` from zero on the first ``t`` observations and returns
    ``(MAP estimate, inverse of the returned Hessian)``.
    """
    prior_prec = np.linalg.inv(Sigma0)
    mode, hessian, _ = irls(
      self.Xh[: t, :], self.yh[: t], theta0, prior_prec, np.zeros(self.d))
    return mode, np.linalg.inv(hessian)

  def get_arm(self, t):
    """Draw theta from the prior's posterior sampler and play its greedy arm."""
    def conditional_map(theta0, Sigma0):
      return self.map_estimator(theta0, Sigma0, t)

    # presumably returns a batch of samples, one per row; the first row is used
    sample = self.prior.posterior_sample_map(conditional_map)[0, :]
    self.mu = self.env.X.dot(sample)
    return np.argmax(self.mu)

  @staticmethod
  def print():
    return "LogDiffTS"
