import numpy as np
import time
from copy import deepcopy


class timeit:
    """Registry of cumulative call timings for decorated functions.

    ``timeit.stats`` maps a registered name to ``(total_seconds, call_count)``.
    """
    stats = {}

    @staticmethod
    def reset():
        """Zero every recorded (total time, call count) pair, keeping all names registered."""
        timeit.stats = {k: (0.0, 0) for k in timeit.stats}

    @staticmethod
    def time_f(name):
        """Return a decorator that accumulates the duration of each call under ``name``.

        The name is registered (with zeroed stats) when the decorator is applied.
        Registering the same name twice fails with AssertionError.
        Calls that raise are not recorded.
        """
        import functools

        def _time_f(f):
            assert name not in timeit.stats, '{} already in stats'.format(name)
            timeit.stats[name] = (0.0, 0)

            @functools.wraps(f)
            def g(*args, **kwargs):
                # perf_counter is monotonic and high-resolution, unlike time.time(),
                # which can jump if the system clock is adjusted.
                start_t = time.perf_counter()
                res = f(*args, **kwargs)
                elapsed = time.perf_counter() - start_t
                # Read timeit.stats at call time: reset() rebinds the dict.
                total, count = timeit.stats[name]
                timeit.stats[name] = (total + elapsed, count + 1)
                return res
            return g
        return _time_f


def q_value_iteration(c, p, eps=1e-6):
    """
    Find an approximately optimal Q function for an SSP with costs c and transition kernel p
    using (undiscounted) value iteration:
        Q(s, a) <- c(s, a) + sum_s' p(s' | s, a) * min_a' Q(s', a')
    Convergence presumably relies on the usual SSP assumptions (e.g. an absorbing, zero-cost
    goal state reachable under some policy) -- not checked here.
    :param c: list or numpy array of shape (nb_states, nb_actions) representing costs
    :param p: list or numpy array of shape (nb_actions, nb_states, nb_states) representing transition
    kernel
    :param eps: stop when the sup-norm change of Q between two sweeps is at most eps
    :return: tuple (opt_cost, g, q) where opt_cost is the array of shape (nb_states,) with
    min_a Q(s, a), g is the greedy action per state (shape (nb_states,)) and q is the
    approximately optimal Q, a numpy array of shape (nb_states, nb_actions).
    """
    # Accept lists as documented (the original called c.shape on the raw argument).
    c = np.asarray(c)
    p = np.asarray(p)
    n, m = c.shape
    q = np.zeros([n, m])
    while True:
        # np.dot(p, v) has shape (nb_actions, nb_states); transpose to (nb_states, nb_actions).
        q_new = c + np.dot(p, np.min(q, axis=1)).T
        if np.max(np.abs(q_new - q)) <= eps:
            break
        # q_new is a freshly allocated array each sweep, so no copy is needed.
        q = q_new
    opt_cost = np.min(q_new, axis=1)
    g = np.argmin(q_new, axis=1)
    return opt_cost, g, q_new


def compute_v_pi(c, p, pi):
    """
    Exact policy evaluation for a deterministic policy in an SSP whose goal is the LAST state.
    Solves (I - P_pi) v = c_pi over the non-goal states, where
    c_pi[s] = c[s, pi[s]] and P_pi[s, s'] = p[pi[s], s, s'].
    :param c: array-like of shape (nb_states, nb_actions) with costs
    :param p: array-like of shape (nb_actions, nb_states, nb_states) with the transition kernel
    :param pi: array-like of length nb_states with the action chosen in each state
    :return: numpy array of shape (nb_states - 1,) with the cost-to-go of each non-goal state
    :raises numpy.linalg.LinAlgError: if I - P_pi is singular (e.g. pi never reaches the goal)
    """
    c, p, pi = np.asarray(c), np.asarray(p), np.asarray(pi)
    c, p, pi = c[:-1, :], p[:, :-1, :-1], pi[:-1]  # remove goal state
    nb_states = len(pi)
    states = np.arange(nb_states)
    c_pi = c[states, pi]
    p_pi = p[pi, states, :]
    # Solving the linear system is cheaper and more accurate than forming the inverse.
    return np.linalg.solve(np.eye(nb_states) - p_pi, c_pi)


def dirichlet(alpha):
    """
    Returns array out with the shape ([nb_actions, nb_states, nb_states]) such that
    out[action, state] is a sample from the dirichlet distribution with parameter
    alpha[action, state] and represents the transition probability p(.|s, a).
    Samples are drawn from the global np.random state, iterating state-major then action,
    so seeded results are reproducible.
    :param alpha: numpy array with shape [nb_actions, nb_states, nb_states]; may be integer
    (e.g. visit counts) -- the output is always floating point
    :return: float numpy array with shape [nb_actions, nb_states, nb_states]
    """
    alpha = np.asarray(alpha)
    nb_actions, nb_states, _ = alpha.shape
    # np.zeros_like(alpha) would inherit an integer dtype from count-valued alpha and
    # silently truncate every sampled probability to 0.
    out = np.zeros(alpha.shape, dtype=float)
    for state in range(nb_states):
        for action in range(nb_actions):
            out[action, state] = np.random.dirichlet(alpha[action, state])
    return out


def to_ssp_kernel(p, destination):
    """
    Make the destination state absorbing in transition kernel p, for every action.
    NOTE: p is modified IN PLACE; the same array object is also returned for convenience.
    :param p: a numpy array with shape nb_actions, nb_states, nb_states.
    :param destination: an integer denoting the destination state
    :return: p itself, with p[:, destination, :] replaced by the one-hot row at destination.
    """
    absorbing_row = np.zeros(p.shape[-1])
    absorbing_row[destination] = 1
    # Broadcast the one-hot row over the action axis.
    p[:, destination, :] = absorbing_row
    return p


def empirical_p(n, n_prime):
    """
    Build the empirical transition kernel from visit counts.
    Unvisited (s, a) pairs get an all-zero row (their count is clamped to 1 before dividing).
    :param n: numpy array of size [nb_states, nb_actions] denoting number of visits.
    :param n_prime: numpy array of size [nb_states, nb_actions, nb_states] denoting number of visits to sas'.
    :return: numpy array of size [nb_actions, nb_states, nb_states] denoting empirical transition kernel for ass'.
    """
    visits = np.maximum(n[..., np.newaxis], 1)
    frequencies = n_prime / visits                 # indexed [s, a, s']
    return np.transpose(frequencies, (1, 0, 2))    # reorder to [a, s, s']


def bernstein_confidence(n, n_prime):
    """
    Per-entry confidence width around the empirical kernel:
        width[a, s, s'] = 1 / N(s, a) + sqrt(p_hat[a, s, s'] / N(s, a)),
    with N clamped to at least 1 so unvisited pairs get width 1.
    :param n: numpy array of shape [nb_states, nb_actions] with visit counts
    :param n_prime: numpy array of shape [nb_states, nb_actions, nb_states] with transition counts
    :return: numpy array of shape [nb_actions, nb_states, nb_states]
    """
    p_hat = empirical_p(n, n_prime)
    visits = np.maximum(n.T[..., np.newaxis], 1)  # [a, s, 1], broadcasts over s'
    return 1.0 / visits + np.sqrt(p_hat / visits)


def extended_value_iteration(costs, c, n, n_prime, eps=1e-6):
    """
    Optimistic (extended) value iteration over a box-shaped confidence set of kernels.

    For every (action, state) the kernel p(.|s, a) is chosen inside
        max(p_hat - c * width, 0) <= p <= p_hat + c * width
    (width from bernstein_confidence). It starts from the lower bound and hands the
    remaining probability mass to next states in increasing order of the current cost-to-go
    v = min_a Q. This favours cheap successors. The last state is treated as the goal and
    forced to be absorbing.

    :param costs: numpy array of shape (nb_states, nb_actions) with per-step costs
    :param c: scalar multiplier applied to the confidence width
    :param n: numpy array of shape (nb_states, nb_actions) with visit counts
    :param n_prime: numpy array of shape (nb_states, nb_actions, nb_states) with transition counts
    :param eps: stop once the sup-norm change of Q between two sweeps is at most eps
    :return: (policy, p) -- policy is the greedy action per state (shape (nb_states,)) w.r.t.
        the final Q, and p (nb_actions, nb_states, nb_states) is the kernel of the last sweep.
    """
    nb_states, nb_actions = costs.shape
    p_hat = empirical_p(n, n_prime)
    confidence = bernstein_confidence(n, n_prime)
    # The box constraints do not depend on v, so build them once rather than every sweep.
    p_low = np.maximum(p_hat - c * confidence, 0)      # lower bound
    slack = p_hat + c * confidence - p_low             # room to grow up to the upper bound
    base_remaining = 1 - p_low.sum(axis=-1)            # mass still to distribute, [a, s]

    q = np.zeros([nb_states, nb_actions])
    while True:
        v = np.min(q, axis=1)
        p = p_low.copy()
        remaining = base_remaining.copy()
        # Over next states s', cheapest cost-to-go first (NOT over (s, a) pairs).
        for s_next in np.argsort(v):
            inc = np.minimum(slack[:, :, s_next], remaining)
            p[:, :, s_next] += inc
            remaining -= inc
        # NOTE(review): if the upper bounds of a row sum to < 1, `remaining` stays positive
        # and that row of p is sub-stochastic -- confirm this is intended.
        p[:, -1, :] = 0
        p[:, -1, -1] = 1  # make goal state absorbing
        q_new = costs + np.dot(p, v).T
        if np.max(np.abs(q_new - q)) <= eps:
            break
        q = q_new  # q_new is freshly allocated each sweep; no copy needed
    # Greedy w.r.t. the latest Q, which is the one consistent with the returned p.
    return np.argmin(q_new, axis=1), p


def compute_pivot_horizon(pi, p, gamma, max_doublings=64):
    """
    Smallest horizon h >= 1 such that every entry of Q^h is <= gamma, where Q is the
    transition matrix of deterministic policy pi restricted to the non-goal states
    (the goal is the LAST state).

    An exponential search (repeated squaring) finds a power of two that satisfies the bound.
    A binary search then finds the smallest h. This relies on max(Q^h) being non-increasing
    in h, which holds for non-negative Q whose rows sum to at most 1.

    :param pi: array-like of length nb_states with the action chosen in each state
    :param p: numpy array of shape (nb_actions, nb_states, nb_states), the transition kernel
    :param gamma: threshold on the entries of Q^h
    :param max_doublings: give up after this many squarings (h would exceed 2**max_doublings)
    :return: the smallest such h (an int >= 1); 1 if there are no non-goal states
    :raises ValueError: if no h <= 2**max_doublings works, e.g. when pi never reaches the goal
        (improper policy) -- the original loop would spin forever in that case
    """
    pi = np.asarray(pi)
    nb_states = pi.shape[0]
    q_mat = p[pi, np.arange(nb_states), :][:-1, :-1]  # remove goal
    if q_mat.size == 0:
        return 1

    # Exponential search: after k squarings, power == Q^(2^k).
    power, doublings = q_mat, 0
    while np.max(power) > gamma:
        if doublings >= max_doublings:
            raise ValueError(
                'max entry of Q^h still exceeds gamma after 2**{} steps; '
                'policy may be improper'.format(max_doublings))
        power = np.dot(power, power)
        doublings += 1

    # Binary search for the smallest valid h in [1, 2^doublings] (the upper end is known valid).
    lo, hi = 1, 1 << doublings
    while lo < hi:
        mid = (lo + hi) >> 1
        if np.max(np.linalg.matrix_power(q_mat, mid)) <= gamma:
            hi = mid
        else:
            lo = mid + 1
    return lo
