"""
SSP environments.
FiniteMDP is a superclass for SSP environments.
Currently available environments:
(1) RandomMDPEnv
(2) GridWorldEnv
"""
import numpy as np
from utils import q_value_iteration, compute_v_pi, to_ssp_kernel
from itertools import product


class FiniteMDP(object):
    """Tabular stochastic-shortest-path (SSP) environment with a goal state.

    An episode starts in state 0 (see ``reset``); every ``step`` charges
    ``costs[state, action]`` and samples the next state from
    ``p[action, state]``. ``step`` reports ``done`` once the sampled state
    equals ``destination``.

    Args:
        nb_states: Number of states; states are ``0 .. nb_states - 1``.
        nb_actions: Number of actions; actions are ``0 .. nb_actions - 1``.
        costs: Array indexed as ``costs[state, action]``.
        p: Transition kernel indexed as ``p[action, state, next_state]``.
        destination: Index of the goal state; must equal ``nb_states - 1``.
    """

    def __init__(self, nb_states, nb_actions, costs, p, destination):
        # Only SSPs whose goal is the last state are accepted.
        assert destination == nb_states - 1, \
            'destination must be the last state (nb_states - 1)'
        self.nb_states = nb_states
        self.nb_actions = nb_actions
        self.states = range(self.nb_states)
        self.actions = range(self.nb_actions)
        self.costs = costs
        self.p = p
        self.destination = destination

        # Current state; None until reset() is called.
        self.state = None
        # Lazily computed solution, filled in by the optimal_* accessors.
        self.opt_cost = None
        self.opt_policy = None
        self.opt_q = None
        self.opt_t = None

    def _solve(self):
        """Run Q-value iteration and cache the cost, policy and Q-values."""
        self.opt_cost, self.opt_policy, self.opt_q = q_value_iteration(self.costs, self.p)

    def optimal_cost(self):
        """Return the optimal cost from ``q_value_iteration`` (cached)."""
        if self.opt_cost is None:
            self._solve()
        return self.opt_cost

    def optimal_expected_hitting_time(self):
        """Return ``compute_v_pi`` of the optimal policy under unit costs (cached).

        NOTE(review): presumably ``compute_v_pi`` evaluates the expected total
        cost of a policy, so unit costs give the expected number of steps to
        the destination -- confirm against ``utils``.
        """
        if self.opt_t is None:
            # Use the accessor so the policy is solved for on demand; reading
            # self.opt_policy directly would pass None when no other optimal_*
            # method has been called yet.
            policy = self.optimal_policy()
            self.opt_t = compute_v_pi(np.ones_like(self.costs), self.p, policy)
        return self.opt_t

    def optimal_policy(self):
        """Return the optimal policy from ``q_value_iteration`` (cached)."""
        if self.opt_policy is None:
            self._solve()
        return self.opt_policy

    def optimal_q(self):
        """Return the optimal Q-values from ``q_value_iteration`` (cached)."""
        if self.opt_q is None:
            self._solve()
        return self.opt_q

    def info(self):
        """Return a multi-line, human-readable summary of the environment.

        Includes the optimal cost, policy and Q-values, which are computed on
        first use.
        """
        parts = [
            'name = {0}\n'.format(self.__class__.__name__),
            'number of states = {0}, number of actions = {1}\n'.format(self.nb_states, self.nb_actions),
            'optimal cost = {0}\n'.format(self.optimal_cost()),
            'optimal policy = {0}\n'.format(self.optimal_policy()),
            'optimal q =\n{0}\n'.format(self.optimal_q()),
            'cost function =\n{0}\n'.format(self.costs),
            'transition kernel=\n{0}\n'.format(self.p),
        ]
        return ''.join(parts)

    def reset(self):
        """Put the environment back in the initial state (always 0) and return it."""
        self.state = 0
        return self.state

    def step(self, action):
        """Apply ``action`` in the current state and sample the next state.

        Args:
            action: Index of the action to take.

        Returns:
            Tuple ``(next_state, cost, done)`` where ``cost`` is the cost of
            ``action`` in the state the step started from and ``done`` is True
            when ``next_state`` is the destination.

        Raises:
            ValueError: If ``reset()`` has not been called yet.
        """
        if self.state is None:
            # Without this guard the None index silently adds an axis and the
            # failure surfaces later as an unrelated numpy shape error.
            raise ValueError('Environment has no current state. Call reset() before step().')
        if self.state == self.destination:
            # Best-effort warning only: the step is still carried out.
            print('SSP is in the absorbing state. Please reset the environment.')
        cost = self.cost(self.state, action)
        self.state = np.random.choice(self.states, p=self.p[action, self.state])
        done = bool(self.state == self.destination)
        return self.state, cost, done

    def cost(self, state, action):
        """Return the cost of taking ``action`` in ``state``."""
        return self.costs[state, action]

    def write_info(self, directory=''):
        """Save the kernel and the costs as text files ``p`` and ``c``.

        Args:
            directory: Directory the two files are written to; the default
                ``''`` yields the bare relative paths ``p`` and ``c``.
        """
        import os
        # savetxt needs at most 2-D input, so the (action, state) axes of the
        # kernel are flattened into rows.
        flat_kernel = self.p.reshape(self.nb_actions * self.nb_states, self.nb_states)
        np.savetxt(os.path.join(directory, 'p'), flat_kernel)
        np.savetxt(os.path.join(directory, 'c'), self.costs)


class RandomMDPEnv(FiniteMDP):
    """
    SSP environment whose costs and transition kernel are drawn at random.

    Costs c(s, a) are sampled with ``np.random.rand`` and the destination row
    is set to 0. For every action, p(.|s, a) is a row of ``np.random.rand``
    values divided by its sum; the resulting kernel is then passed through
    ``to_ssp_kernel``. The destination is the last state.
    """
    def __init__(self, nb_states=6, nb_actions=2):
        # Must be set before sampling: both helpers read self.destination.
        self.destination = nb_states - 1
        sampled_costs = self._random_costs(nb_states, nb_actions)
        print(sampled_costs)
        kernel = self._random_p(nb_states, nb_actions)
        print(kernel)
        super().__init__(
            nb_states=nb_states,
            nb_actions=nb_actions,
            costs=sampled_costs,
            p=kernel,
            destination=self.destination,
        )

    def _random_costs(self, nb_states, nb_actions):
        """Sample a (state, action) cost table with a zero-cost destination row."""
        table = np.random.rand(nb_states, nb_actions)
        table[self.destination] = 0
        return table

    def _random_p(self, nb_states, nb_actions):
        """Sample one row-normalised transition matrix per action."""
        kernel = np.zeros((nb_actions, nb_states, nb_states))
        for action in range(nb_actions):
            weights = np.random.rand(nb_states, nb_states)
            # Divide every row by its own sum so it becomes a distribution.
            kernel[action] = weights / weights.sum(axis=1, keepdims=True)
        return to_ssp_kernel(kernel, destination=self.destination)


class GridWorldEnv(FiniteMDP):
    """Grid world with ``h`` rows and ``w`` columns, cast as an SSP.

    Cell ``(i, j)`` is state ``i * w + j`` and the destination is the last
    cell ``(h - 1, w - 1)``. Every action costs 1 outside the destination and
    0 in it; the destination only transitions to itself.

    Actions are indices into the (row, column) moves
    ``0: (0, +1)``, ``1: (0, -1)``, ``2: (+1, 0)``, ``3: (-1, 0)``.
    A move that would leave the grid keeps the agent in its current cell.

    Transitions are noisy: whatever the action, each of the four moves is
    taken with probability ``random_p``, and the chosen action's own move
    receives the remaining ``1 - 4 * random_p`` on top.

    Args:
        h: Number of rows.
        w: Number of columns.
        random_p: Probability mass given to each of the four moves regardless
            of the chosen action. Defaults to 0.05.

    Raises:
        ValueError: If ``random_p`` is outside ``[0, 0.25]``, which would make
            some transition probability negative.
    """
    def __init__(self, h, w, random_p=0.05):
        if not 0 <= random_p <= 0.25:
            raise ValueError('random_p must be in [0, 0.25], got {0}'.format(random_p))
        nb_states = h * w
        destination = nb_states - 1
        costs = np.ones((nb_states, 4))
        costs[destination, :] = 0
        print(costs)
        p = np.zeros((4, nb_states, nb_states))
        # The destination is absorbing under every action.
        p[:, destination, destination] = 1
        moves = ((0, 1), (0, -1), (1, 0), (-1, 0))
        for i, j in product(range(h), range(w)):
            state = i * w + j
            if state == destination:
                continue
            for a, (step_i, step_j) in enumerate(moves):
                ni, nj = i + step_i, j + step_j
                if not (0 <= ni < h and 0 <= nj < w):
                    # Bumping into a wall: stay in place.
                    ni, nj = i, j
                target = ni * w + nj
                # Intended move for action a ...
                p[a, state, target] += 1 - 4 * random_p
                # ... plus the slip mass every action puts on this move.
                p[:, state, target] += random_p
        super().__init__(nb_states, 4, costs, p, destination)
