'''
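Basic reinforcement-learning environments: finite-horizon tabular MDPs with
deterministic, Gaussian or Bernoulli rewards, plus constructors for benchmark
MDPs (RiverSwim, deterministic and stochastic chains, random Gaussian MDPs).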
'''

import numpy as np
from scipy.special import gammaln, softmax
#-------------------------------------------------------------------------------


class Environment:
    '''
    Base class for reinforcement-learning environments.

    Holds the current state, the step counter of the running episode and a
    seeded NumPy random generator. Subclasses define the dynamics by
    overriding `sample_reward` and `sample_next_state`.
    '''

    def __init__(self, seed=0):
        '''
        Create an environment with no current state.

        Args:
            seed    - int - seed for the environment's NumPy random generator

        Returns:
            Environment object
        '''
        # The attribute keeps its historical spelling (`rangomgenerator`)
        # because subclasses draw their samples from it under that name.
        self.rangomgenerator = np.random.default_rng(seed)
        self.timestep = 0
        self.state = None

    def reset(self):
        '''
        Start a new episode.

        Returns:
            The initial state, which is None for this base class.
        '''
        self.state, self.timestep = None, 0
        return self.state

    def step(self, action):
        '''
        Advance the environment by one transition.

        The reward is drawn before the next state; both are drawn for the
        state the environment is in when `step` is called.

        Args:
            action - the action to take

        Returns:
            reward   - reward drawn by `sample_reward`
            newState - next state drawn by `sample_next_state`
            done     - always 0; this base class never ends an episode
        '''
        current = self.state
        reward = self.sample_reward(current, action)
        next_state = self.sample_next_state(current, action)
        self.timestep += 1
        self.state = next_state
        return reward, next_state, 0

    def sample_reward(self, state, action, size=None):
        '''
        Draw reward(s) for taking `action` in `state`.

        Args:
            state
            action
            size - int - number of reward samples; None for a single sample

        Returns:
            reward - a single reward if `size=None`, otherwise `size` rewards

        Raises:
            NotImplementedError - always; subclasses must override this
        '''
        raise NotImplementedError("Subclasses should implement `sample_reward`")

    def sample_next_state(self, state, action, size=None):
        '''
        Draw next state(s) for taking `action` in `state`.

        Args:
            state
            action
            size - int - number of next-state samples; None for a single sample

        Returns:
            next_state - a single state if `size=None`, otherwise `size` states

        Raises:
            NotImplementedError - always; subclasses must override this
        '''
        raise NotImplementedError("Subclasses should implement `sample_next_state`")

#-------------------------------------------------------------------------------
# Finite MDP

class FiniteMDP(Environment):
    '''
    MDP whose state and action spaces are finite.

    Records the number of states and actions; episodes start (and restart)
    in state 0.
    '''
    def __init__(self, nState, nAction, seed=0):
        '''
        Initialize a finite MDP

        Args:
            nState  - int - number of states
            nAction - int - number of actions
            seed    - int - random seed

        Returns:
            Environment object
        '''
        super().__init__(seed)
        self.nState, self.nAction = nState, nAction
        # Replace the base class's `None` start state with state 0.
        self.state = 0

    def reset(self):
        '''Begin a new episode in state 0 and return that state.'''
        self.state = 0
        self.timestep = 0
        return self.state

#-------------------------------------------------------------------------------
# Finite Horizon Finite MDP

class FiniteHorizonFiniteMDP(FiniteMDP):
    '''
    Finite MDP with a finite horizon.

    Episodes last `epLen` steps: `step` reports `done = 1` on the step that
    brings `timestep` up to `epLen`. The planning helpers run backward
    induction on the tabular model returned by `to_tabular_MDP`, which
    subclasses must provide.
    '''
    def __init__(self, nState, nAction, epLen, seed=0):
        '''
        Initialize a finite-horizon finite MDP

        Args:
            nState  - int - number of states
            nAction - int - number of actions
            epLen   - int - episode length
            seed    - int - random seed

        Returns:
            Environment object
        '''
        super().__init__(nState, nAction, seed)
        self.epLen = epLen

    def step(self, action):
        '''
        Move one step in the environment

        Args:
            action - int - chosen action

        Returns:
            reward   - double - reward
            newState - int - new state
            done     - 0/1 - 1 on the step where `timestep` reaches `epLen`
        '''
        # Reward and next state are both drawn for the pre-transition state,
        # reward first (this fixes the order of random draws).
        reward = self.sample_reward(self.state, action)
        newState = self.sample_next_state(self.state, action)

        # Update the environment
        self.state = newState
        self.timestep += 1

        done = 1 if self.timestep == self.epLen else 0
        return reward, newState, done

    def to_tabular_MDP(self):
        '''
        Build the dense tabular model of the environment.

        Args:
            NULL - subclasses implement this from their own model

        Returns:
            P - S*A*S transition tensor
            R - S*A reward matrix

        Raises:
            NotImplementedError - in this base class
        '''
        raise NotImplementedError("Subclasses should implement `to_tabular_MDP`")

    def _init_backward_induction(self):
        '''
        Shared setup for the backward-induction routines below.

        Returns:
            P     - S*A*S transition tensor from `to_tabular_MDP`
            qVals - S*H*A float array; every horizon slice starts as R
            vVals - H*S zero array
        '''
        P, R = self.to_tabular_MDP()
        # Cast to float. Otherwise an integer R from a subclass makes qVals
        # integer, and the in-place `+= P @ V` below fails with a casting error.
        R = np.asarray(R, dtype=float)
        qVals = np.repeat(R[:, np.newaxis, :], self.epLen, axis=1)
        vVals = np.zeros((self.epLen, self.nState))
        return P, qVals, vVals

    def value_iteration(self):
        '''
        Compute optimal Q and V values by backward induction over the horizon.

        Args:
            NULL - uses the model from `to_tabular_MDP`

        Returns:
            qVals - an S*H*A tensor of optimal Q values
            vVals - an H*S matrix of optimal values
        '''
        P, qVals, vVals = self._init_backward_induction()
        last = self.epLen - 1

        vVals[last] = np.max(qVals[:, last, :], axis=-1)
        for h in reversed(range(last)):
            # Q_h(s, a) = R(s, a) + sum_s' P(s' | s, a) V_{h+1}(s')
            qVals[:, h, :] += P @ vVals[h + 1]
            vVals[h] = np.max(qVals[:, h, :], axis=-1)
        return qVals, vVals

    def softmax_policy_evaluation(self, policy):
        '''
        Evaluate a stochastic policy by backward induction.

        NOTE(review): `policy` is used directly as action probabilities.
        No softmax is applied here despite the method name, so callers that
        hold logits must normalise them first. Confirm this is intended.

        Args:
            policy - S*H*A array - policy[s, h, :] weights the actions at (s, h)

        Returns:
            qVals - an S*H*A tensor of the policy's Q values
            vVals - an H*S matrix of the policy's values
        '''
        P, qVals, vVals = self._init_backward_induction()
        last = self.epLen - 1

        # V_h(s) = sum_a policy[s, h, a] * Q_h(s, a), computed for all s at once
        vVals[last] = np.sum(qVals[:, last, :] * policy[:, last, :], axis=-1)
        for h in reversed(range(last)):
            qVals[:, h, :] += P @ vVals[h + 1]
            vVals[h] = np.sum(qVals[:, h, :] * policy[:, h, :], axis=-1)
        return qVals, vVals

    def policy_iteration(self, policy_matrix, alpha):
        '''
        Apply one advantage-weighted update to a stochastic policy table.

        Evaluates `policy_matrix` with `softmax_policy_evaluation`, then sets
            policy_matrix[s, h, a] += alpha * (qVals[s, h, a] - vVals[h, s])
        for every (s, h, a). No renormalisation is performed.

        Args:
            policy_matrix - S*H*A ndarray - policy weights; updated IN PLACE
            alpha         - float - step size

        Returns:
            policy_matrix - the same (mutated) array
        '''
        qVals, vVals = self.softmax_policy_evaluation(policy_matrix)
        # vVals is H*S: transpose to S*H and add an action axis so it
        # broadcasts against the S*H*A qVals.
        advantage = qVals - vVals.T[:, :, np.newaxis]
        policy_matrix += alpha * advantage
        return policy_matrix

    def policy_evaluation(self, policy):
        '''
        Evaluate the value of a deterministic policy

        Args:
            policy - indexable as policy[s, h] -> action index (e.g. an S*H
                     integer array or a dict keyed by (s, h))

        Returns:
            qVals - an S*H*A tensor of the policy's Q values
            vVals - an H*S matrix of the policy's values
        '''
        P, qVals, vVals = self._init_backward_induction()
        last = self.epLen - 1

        # Per-state loop kept so that dict-keyed policies keep working.
        for s in range(self.nState):
            vVals[last, s] = qVals[s, last, policy[s, last]]

        for h in reversed(range(last)):
            qVals[:, h, :] += P @ vVals[h + 1]
            for s in range(self.nState):
                vVals[h, s] = qVals[s, h, policy[s, h]]

        return qVals, vVals

#-------------------------------------------------------------------------------
# Finite Horizon Tabular MDP

class FiniteHorizonTabularMDP(FiniteHorizonFiniteMDP):
    '''
    Deterministic Reward Finite Horizon Tabular MDP

    P - indexable by (s,a) (e.g. a dict) - each P[s,a] = transition vector size S
    R - indexable by (s,a) (e.g. a dict) - each R[s,a] = meanReward
    '''

    def __init__(self, nState, nAction, epLen, P, R, seed=0):
        '''
        Initialize a tabular episodic MDP

        Args:
            nState  - int - number of states
            nAction - int - number of actions
            epLen   - int - episode length
            P - indexable by (s,a) - each P[s,a] = transition vector size S
            R - indexable by (s,a) - each R[s,a] = meanReward
            seed    - int - random seed

        Returns:
            Environment object
        '''
        super().__init__(nState, nAction, epLen, seed)

        # Now initialize R and P
        self.P = P
        self.R = R

    def to_tabular_MDP(self):
        '''
        Materialise the model as dense arrays.

        Returns:
            P - S*A*S transition tensor with P[s, a] = self.P[s, a]
            R - S*A matrix with R[s, a] = self.mean_reward(s, a)
        '''
        P = np.zeros((self.nState, self.nAction, self.nState))
        R = np.zeros((self.nState, self.nAction))
        for s in range(self.nState):
            for a in range(self.nAction):
                P[s, a] = self.P[s, a]
                R[s, a] = self.mean_reward(s, a)

        return P, R

    def sample_reward(self, state, action, size=None):
        '''
        Return the deterministic reward R[state, action].

        Args:
            state, action - indices into R
            size - int or tuple - number/shape of samples; None for one sample

        Returns:
            ndarray filled with R[state, action]; 0-d when `size` is None
        '''
        # np.full(None, ...) relies on `shape=None` as an alias for `()`,
        # which NumPy deprecated. `step()` hits that path on every call, so
        # pass the empty shape explicitly; the result is unchanged.
        shape = () if size is None else size
        return np.full(shape, self.R[state, action])

    def sample_next_state(self, state, action, size=None):
        '''Draw next state index(es) from P[state, action] with the env's generator.'''
        return self.rangomgenerator.choice(self.nState, size=size, p=self.P[state, action])

    def mean_reward(self, state, action):
        '''Mean reward of (state, action); here R stores it directly.'''
        return self.R[state, action]


#-------------------------------------------------------------------------------
# Finite Horizon Tabular MDP with Gaussian Reward

class GaussianRewardFiniteHorizonTabularMDP(FiniteHorizonTabularMDP):
    '''
    Gaussian Reward Finite Horizon Tabular MDP

    P - indexable by (s,a) - each P[s,a] = transition vector size S
    R - indexable by (s,a) - each R[s,a] = (meanReward, stdReward)
    '''

    def sample_reward(self, state, action, size=None):
        '''
        Draw Gaussian reward(s) for (state, action).

        Args:
            state, action - indices into R
            size - int or tuple - number/shape of samples; None for one sample

        Returns:
            Normal sample(s) with mean R[s,a][0] and std R[s,a][1]. When the
            std is below 1e-9 (including negative values), returns an ndarray
            filled with the mean (0-d when `size` is None) and consumes no
            random draws.
        '''
        mean = self.R[state, action][0]
        std = self.R[state, action][1]
        if std < 1e-9:
            # Hack for no noise. Pass () rather than None: NumPy deprecated
            # `shape=None` as an alias for the empty shape.
            return np.full(() if size is None else size, mean)
        return self.rangomgenerator.normal(loc=mean, scale=std, size=size)

    def mean_reward(self, state, action):
        '''Mean reward of (state, action): the first entry of R[state, action].'''
        return self.R[state, action][0]

#-------------------------------------------------------------------------------
# Finite Horizon Tabular MDP with Bernoulli Reward

class BernoulliRewardFiniteHorizonTabularMDP(FiniteHorizonTabularMDP):
    '''
    Finite-horizon tabular MDP with Bernoulli (0/1) rewards.

    P - indexable by (s,a) - each P[s,a] = transition vector size S
    R - indexable by (s,a) - each R[s,a] = success probability in [0, 1],
        which is also the mean reward
    '''

    def sample_reward(self, state, action, size=None):
        '''Draw 0/1 reward(s) with success probability R[state, action].'''
        success_prob = self.R[state, action]
        return self.rangomgenerator.binomial(n=1, p=success_prob, size=size)

#-------------------------------------------------------------------------------
# Benchmark environments

def make_random_gaussian_reward_tabular_MDP(dist_param, epLen,seed=0):
    '''
    Create a random Gaussian reward Finite Horizon Tabular MDP

    Args:
        dist_param - np.ndarray - specify the distribution used to generate the MDP
        epLen - int - the length of the episode
        seed - the seed used to generate the random generator

    Returns:
        GaussianRewardFiniteHorizonTabularMDP object
    '''
    randomgenerator=np.random.default_rng(seed)
    nState, nAction, _ = dist_param.shape
    R = np.ones((nState, nAction, 2))
    R[:, :, 0] = randomgenerator.normal(dist_param[:, :, -2], dist_param[:, :, -1])
    P = np.zeros((nState, nAction, nState))
    for s in range(nState):
        for a in range(nAction):
            P[s, a] = randomgenerator.dirichlet(dist_param[s, a, :-2])

    return GaussianRewardFiniteHorizonTabularMDP(nState, nAction, epLen, P, R,seed)


def make_riverSwim(nState=6, epLen=20, seed=0):
    '''
    Build the RiverSwim benchmark as a deterministic-reward tabular MDP.

    Action 0 always moves one state to the left (staying put in state 0) and
    pays 5/1000 only in state 0. Action 1 in an interior state s moves to
    s+1 w.p. 0.35, stays w.p. 0.6 and moves to s-1 w.p. 0.05. In state 0 it
    stays w.p. 0.4 and moves right w.p. 0.6. In the last state it stays
    w.p. 0.6 and moves left w.p. 0.4, and pays reward 1.0 there.

    Args:
        nState - int - the number of states
        epLen - int - the length of the episode
        seed - int - the seed used to generate the random generator

    Returns:
        riverSwim - FiniteHorizonTabularMDP, already reset
    '''
    nAction = 2
    pairs = [(s, a) for s in range(nState) for a in range(nAction)]
    R_true = {pair: 0.0 for pair in pairs}
    P_true = {pair: np.zeros(nState) for pair in pairs}
    last = nState - 1

    # Rewards: small one on the left bank, large one on the right bank.
    R_true[0, 0] = 5. / 1000
    R_true[last, 1] = 1.0

    # Action 0: deterministic move left, clamped at state 0.
    for s in range(nState):
        P_true[s, 0][max(0, s - 1)] = 1.

    # Action 1 in interior states: right / stay / left.
    for s in range(1, last):
        swim = P_true[s, 1]
        swim[s + 1] = 0.35
        swim[s] = 0.6
        swim[s - 1] = 0.05

    # Action 1 at the two ends of the river.
    P_true[0, 1][0] = 0.4
    P_true[0, 1][1] = 0.6
    P_true[last, 1][last] = 0.6
    P_true[last, 1][last - 1] = 0.4

    env = FiniteHorizonTabularMDP(nState, nAction, epLen, P_true, R_true, seed)
    env.reset()
    return env

def make_deterministicChain(nState, epLen, seed=0):
    '''
    Build a deterministic two-action chain with Gaussian rewards.

    Action 0 moves one state left (clamped at 0) and action 1 one state
    right (clamped at nState - 1). Rewards are (mean, std) pairs:
    (0, 1) for action 0 in state 0, (1, 1) for action 1 in the last state,
    and (0, 0), i.e. noise-free zero, everywhere else.

    Args:
        nState - int - number of states
        epLen - int - episode length
        seed - int - the seed used to generate the random generator

    Returns:
        chainMDP - GaussianRewardFiniteHorizonTabularMDP, already reset
    '''
    nAction = 2
    R_true = {}
    P_true = {}

    for s in range(nState):
        move_left = np.zeros(nState)
        move_left[max(0, s - 1)] = 1.
        move_right = np.zeros(nState)
        move_right[min(nState - 1, s + 1)] = 1.
        P_true[s, 0], P_true[s, 1] = move_left, move_right
        R_true[s, 0] = R_true[s, 1] = (0, 0)

    # Noisy rewards at the two ends of the chain.
    R_true[0, 0] = (0, 1)
    R_true[nState - 1, 1] = (1, 1)

    env = GaussianRewardFiniteHorizonTabularMDP(nState, nAction, epLen, P_true, R_true, seed)
    env.reset()
    return env

def make_stochasticChain(nState, seed=0):
    '''
    Build a hard stochastic two-action chain with Gaussian rewards.

    The episode length equals nState. Action 0 moves one state left (clamped
    at 0). Action 1 moves one state right (clamped at nState - 1) with
    probability 1 - 1/nState. Otherwise it moves one state left (clamped
    at 0). Rewards are (mean, std) pairs: (0, 1) for action 0 in state 0,
    (1, 1) for action 1 in the last state and (0, 0) elsewhere.

    Args:
        nState - int - total number of states
        seed - int - the seed used to generate the random generator

    Returns:
        chainMDP - GaussianRewardFiniteHorizonTabularMDP, already reset
    '''
    epLen, nAction = nState, 2
    pNoise = 1. / nState

    keys = [(s, a) for s in range(nState) for a in range(nAction)]
    R_true = dict.fromkeys(keys, (0, 0))
    P_true = {key: np.zeros(nState) for key in keys}

    # Noisy rewards at the two ends of the chain.
    R_true[0, 0] = (0, 1)
    R_true[nState - 1, 1] = (1, 1)

    for s in range(nState):
        left, right = max(0, s - 1), min(nState - 1, s + 1)
        P_true[s, 0][left] = 1.
        # `+=` so that the mass is accumulated when left == right (nState == 1).
        P_true[s, 1][right] = 1. - pNoise
        P_true[s, 1][left] += pNoise

    env = GaussianRewardFiniteHorizonTabularMDP(nState, nAction, epLen, P_true, R_true, seed)
    env.reset()
    return env
