import numpy as np
import gym
from scipy.linalg import solve_discrete_are
from lqr_env import LQREnv
from IPython import embed
import torch
import scipy
import itertools
from bandit_env import Controller
import time
# Device onto which tensors built in this module are moved: CUDA when
# available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Unbounded limits. Not referenced in the visible part of this file;
# presumably consumed by other modules -- TODO confirm before removing.
UPPER_BOUND = np.inf
LOWER_BOUND = -np.inf


def sample(dim, H):
    """Build a ``DarkroomEnv`` whose goal cell is drawn uniformly at random.

    Args:
        dim: Side length of the grid; each goal coordinate is drawn from
            ``[0, dim)``.
        H: Episode horizon forwarded to ``DarkroomEnv``.

    Returns:
        A new ``DarkroomEnv`` using the default (non-random) initial state.
    """
    random_goal = np.random.randint(low=0, high=dim, size=2)
    return DarkroomEnv(dim, random_goal, H)



class DarkroomEnv(LQREnv):
    """Grid world with one goal cell and five one-hot encoded actions.

    The state is an integer pair clipped to ``[0, dim - 1]`` on both
    coordinates. Action indices 0 / 1 increment / decrement ``state[0]``,
    2 / 3 increment / decrement ``state[1]`` and 4 leaves the state unchanged.
    The reward is 1 when the state reached by a transition equals the goal and
    0 otherwise.
    """

    def __init__(self, dim, goal, H, random_init=False):
        """Store the grid configuration and build the gym spaces.

        Args:
            dim: Side length of the grid.
            goal: Goal cell (two coordinates); stored as a numpy array.
            H: Episode horizon, i.e. the number of steps before ``done``.
            random_init: When true, ``reset`` draws a random start cell
                instead of ``[0, 0]``.
        """
        # NOTE: LQREnv.__init__ is not invoked here; every attribute this
        # class reads is assigned below or in reset().
        self.dim = dim
        self.goal = np.array(goal)
        self.H = H
        self.dx = 2
        self.du = 5
        self.observation_space = gym.spaces.Box(low=0, high=dim - 1, shape=(self.dx,))
        self.action_space = gym.spaces.Discrete(self.du)
        self.random_init = random_init

    def sample_x(self):
        """Return a uniformly random grid cell as a length-2 integer array."""
        return np.random.randint(0, self.dim, 2)

    def sample_u(self):
        """Return a uniformly random one-hot action vector."""
        # Draw from the action space size rather than a literal 5 so the index
        # range always matches the length of the one-hot vector.
        n_actions = self.action_space.n
        i = np.random.randint(0, n_actions)
        a = np.zeros(n_actions)
        a[i] = 1
        return a

    def reset(self):
        """Start a new episode and return a copy of the initial state."""
        self.current_step = 0
        if self.random_init:
            self.state = self.sample_x()
        else:
            self.state = np.array([0, 0])
        # Return a copy, like step() and get_obs(), so that in-place edits of
        # the returned observation cannot change the environment state.
        return self.state.copy()

    def transit(self, s, a):
        """Apply one-hot action ``a`` to state ``s`` without touching ``self``.

        Args:
            s: Current state (two coordinates). It is copied, not modified.
            a: Action scores; the arg-max index selects the move.

        Returns:
            ``(next_state, reward)`` where ``reward`` is 1 if ``next_state``
            equals the goal and 0 otherwise.
        """
        a = np.argmax(a)
        assert a in np.arange(self.action_space.n)
        s = np.array(s)  # work on a copy of the caller's state
        if a == 0:
            s[0] += 1
        elif a == 1:
            s[0] -= 1
        elif a == 2:
            s[1] += 1
        elif a == 3:
            s[1] -= 1
        # Index 4 is the "stay" action: no coordinate changes.
        s = np.clip(s, 0, self.dim - 1)

        r = 1 if np.all(s == self.goal) else 0
        return s, r

    def step(self, action):
        """Advance the episode by one step.

        Returns:
            ``(state_copy, reward, done, {})`` where ``done`` becomes true
            once ``H`` steps have been taken.

        Raises:
            ValueError: If called after the horizon has been reached.
        """
        if self.current_step >= self.H:
            raise ValueError("Episode has already ended")

        self.state, r = self.transit(self.state, action)
        self.current_step += 1
        done = (self.current_step >= self.H)
        return self.state.copy(), r, done, {}

    def get_obs(self):
        """Return a copy of the current state."""
        return self.state.copy()

    def opt_a(self, x):
        """Return the one-hot action that moves ``x`` towards the goal.

        The first coordinate is aligned with the goal before the second one;
        once both match, the "stay" action (index 4) is returned.
        """
        if x[0] < self.goal[0]:
            a = 0
        elif x[0] > self.goal[0]:
            a = 1
        elif x[1] < self.goal[1]:
            a = 2
        elif x[1] > self.goal[1]:
            a = 3
        else:
            a = 4
        zeros = np.zeros(self.action_space.n)
        zeros[a] = 1
        return zeros


class DarkroomEnvStitch(DarkroomEnv):
    """
    Darkroom environment with two goals, one on the right and one on the bottom.
    If the goal is on the right, the agent is initialized on the left during train and top during eval.
    If the goal is on the bottom, the agent is initialized on the top during train and left during eval.

    Following the comments in ``opt_a``, increasing ``x[0]`` is "down" and
    increasing ``x[1]`` is "right", so ``goals[0]`` is the right goal and
    ``goals[1]`` the bottom goal.
    """
    def __init__(self, dim, goal, H, eval=False):
        """Set up the two candidate goals, start cells and demonstration paths.

        Args:
            dim: Side length of the grid.
            goal: Must equal one of the two entries of ``self.goals``.
            H: Episode horizon.
            eval: Selects the eval-time start cells and demonstration paths.
        """
        self.goals = [np.array([dim // 2, dim - 1]), np.array([dim - 1, dim // 2])]
        assert any([np.all(goal == g) for g in self.goals])
        super().__init__(dim, goal, H)

        self.eval = eval
        if eval:
            # Eval: start on the top (right goal) or on the left (bottom goal)
            # and follow an L-shaped path through the centre cell.
            self.initial_states = [np.array([0, dim // 2]), np.array([dim // 2, 0])]
            self.demo_states = [
                [np.array([i, dim // 2]) for i in range(dim // 2)] + [np.array([dim // 2, i]) for i in range(dim // 2, dim)],
                [np.array([dim // 2, i]) for i in range(dim // 2)] + [np.array([i, dim // 2]) for i in range(dim // 2, dim)],
            ]
        else:
            # Train: start on the left (right goal) or on the top (bottom
            # goal) and walk the straight middle row / column.
            self.initial_states = [np.array([dim // 2, 0]), np.array([0, dim // 2])]
            self.demo_states = [
                [np.array([dim // 2, i]) for i in range(dim)],
                [np.array([i, dim // 2]) for i in range(dim)],
            ]

    def _goal_index(self):
        """Return 0 when the active goal is ``goals[0]``, otherwise 1."""
        return 0 if np.all(self.goal == self.goals[0]) else 1

    def reset(self):
        """Start a new episode from the start cell paired with the goal."""
        self.current_step = 0
        # Copy so neither the environment state nor the returned observation
        # aliases the stored start cell; an in-place edit by a caller would
        # otherwise change the start of every later episode.
        self.state = self.initial_states[self._goal_index()].copy()
        return self.state.copy()

    def sample_stitch_x(self):
        """Return a random cell from the union of both eval demo paths."""
        assert self.eval
        all_demo_states = self.demo_states[0] + self.demo_states[1]
        # Copy so callers cannot mutate the stored demonstration cells.
        return all_demo_states[np.random.randint(0, len(all_demo_states))].copy()

    def sample_stitch_opt_a(self, x):
        """Return a one-hot action for a cell on the eval demo paths.

        On the middle row (off-centre) the action is "right" (2) until the
        last column, on the middle column (off-centre) it is "down" (0) until
        the last row; at those ends it is "stay" (4). Every other cell,
        including the centre, gets "right" or "down" with equal probability.
        """
        assert self.eval
        if x[0] == self.dim // 2 and x[1] != self.dim // 2:
            if x[1] < self.dim - 1:
                a = 2
            else:
                a = 4
        elif x[1] == self.dim // 2 and x[0] != self.dim // 2:
            if x[0] < self.dim - 1:
                a = 0
            else:
                a = 4
        else:
            if np.random.rand() < 0.5:
                a = 2
            else:
                a = 0
        zeros = np.zeros(self.action_space.n)
        zeros[a] = 1
        return zeros

    def sample_opt_x(self):
        """Return a random cell from the demo path of the active goal."""
        path = self.demo_states[self._goal_index()]
        # Copy so callers cannot mutate the stored demonstration cells.
        return path[np.random.randint(0, len(path))].copy()

    def opt_a(self, x):
        """Return the one-hot optimal action for ``x``.

        During train, and during eval for the right goal, this is the parent
        policy (down then right). During eval for the bottom goal the second
        coordinate is aligned first (right then down).
        """
        if not self.eval or self._goal_index() == 0:
            # down then right
            return super().opt_a(x)

        # right then down
        if x[1] < self.goal[1]:
            a = 2
        elif x[1] > self.goal[1]:
            a = 3
        elif x[0] < self.goal[0]:
            a = 0
        elif x[0] > self.goal[0]:
            a = 1
        else:
            a = 4
        zeros = np.zeros(self.action_space.n)
        zeros[a] = 1
        return zeros


class DarkroomEnvPermuted(DarkroomEnv):
    """
    Darkroom environment with permuted actions. The goal is always the bottom right corner.

    Agent action ``i`` is executed as underlying action ``perm[i]``, where
    ``perm`` is the ``perm_index``-th permutation of the action indices in
    ``itertools.permutations`` order.
    """
    def __init__(self, dim, perm_index, H):
        """
        Args:
            dim: Side length of the grid.
            perm_index: Index into the list of action permutations.
            H: Episode horizon.
        """
        goal = np.array([dim - 1, dim - 1])
        super().__init__(dim, goal, H)

        self.perm_index = perm_index
        actions = np.arange(self.action_space.n)
        permutations = list(itertools.permutations(actions))
        # Bound taken from the real permutation count (5! = 120 in darkroom)
        # instead of a literal, so it follows the size of the action space.
        assert perm_index < len(permutations)
        self.perm = permutations[perm_index]

    def transit(self, s, a):
        """Translate the agent's action through ``perm`` and apply it."""
        perm_a = np.zeros(self.action_space.n)
        perm_a[self.perm[np.argmax(a)]] = 1
        return super().transit(s, perm_a)

    def opt_a(self, x):
        """Return the one-hot agent action whose permuted effect is optimal."""
        action = np.argmax(super().opt_a(x))
        # ``self.perm`` is a tuple: convert it explicitly so the comparison is
        # element-wise regardless of whether ``action`` is a numpy scalar or a
        # plain int (tuple == int would otherwise evaluate to a single False).
        perm_action = np.where(np.asarray(self.perm) == action)[0][0]
        zeros = np.zeros(self.action_space.n)
        zeros[perm_action] = 1
        return zeros


class DarkroomEnvVec(LQREnv):
    """
    Vectorized Darkroom environment.

    Wraps a list of environments and steps them in lockstep; observations,
    rewards and done flags are returned as plain Python lists, one entry per
    wrapped environment.
    """
    def __init__(self, envs):
        self._envs = envs
        self._num_envs = len(envs)

    def reset(self):
        """Reset every wrapped environment and return their initial states."""
        return [single_env.reset() for single_env in self._envs]

    def step(self, actions):
        """Step each environment with its own action.

        Returns:
            ``(next_obs, rews, dones, {})`` where the first three items are
            lists aligned with the wrapped environments.
        """
        transitions = [env.step(action) for action, env in zip(actions, self._envs)]
        next_obs = [ob for ob, _, _, _ in transitions]
        rews = [rew for _, rew, _, _ in transitions]
        dones = [flag for _, _, flag, _ in transitions]
        return next_obs, rews, dones, {}

    @property
    def num_envs(self):
        return self._num_envs

    @property
    def envs(self):
        return self._envs

    def deploy(self, ctrl, include_partial_hist=False, grow_context=False):
        """Roll out ``ctrl`` in all environments until every one reports done.

        Args:
            ctrl: Controller providing ``act``; when ``include_partial_hist``
                is set it must also expose ``batch`` and ``set_batch``.
            include_partial_hist: Feed each new transition back into the
                controller's ``rollin_*`` context after every step.
            grow_context: Append to the context instead of dropping its
                oldest entry to make room for the new transition.

        Returns:
            Stacked states, actions, next states and rewards, each stacked
            along axis 1 (environments first, time second).
        """
        obs = self.reset()
        xs, us, xps, rs = [], [], [], []

        finished = False
        while not finished:
            act = ctrl.act(obs)
            xs.append(obs)
            us.append(act)

            obs, rew, dones, _ = self.step(act)
            finished = all(dones)
            rs.append(rew)
            xps.append(obs)

            if not include_partial_hist:
                continue

            # Latest transition with a singleton time axis inserted at dim 1.
            latest = {
                'rollin_xs': torch.tensor(np.array(xs[-1])[:, None, :]).float().to(device),
                'rollin_us': torch.tensor(np.array(us[-1])[:, None, :]).float().to(device),
                'rollin_xps': torch.tensor(np.array(xps[-1])[:, None, :]).float().to(device),
                'rollin_rs': torch.tensor(np.array(rew)[:, None, None]).float().to(device),
            }

            updated = {}
            for key, newest in latest.items():
                history = ctrl.batch[key]
                if not grow_context:
                    # Fixed-size window: drop the oldest entry.
                    history = history[:, 1:]
                updated[key] = torch.cat((history, newest), axis=1)
            # The controller's batch is replaced by these four keys only.
            ctrl.set_batch(updated)

        return np.stack(xs, axis=1), np.stack(us, axis=1), np.stack(xps, axis=1), np.stack(rs, axis=1)


class DarkroomOptPolicy(Controller):
    """Controller that always plays the action the environment deems optimal."""

    def __init__(self, env):
        """Keep a handle on ``env`` and mirror its goal on the controller."""
        super().__init__()
        self.env, self.goal = env, env.goal

    def reset(self):
        """No per-episode state to clear."""
        return None

    def act(self, x):
        """Return ``env.opt_a(x)`` for the given state(s) ``x``."""
        optimal_action = self.env.opt_a(x)
        return optimal_action
        
        
class RandCommit(Controller):
    """Act randomly until the context contains a reward, then act optimally.

    ``set_batch`` scans the ``rollin_rs`` rewards; once a positive reward is
    present the matching ``rollin_xps`` entry is stored as ``goal`` and
    ``act`` switches from random actions to ``env.opt_a``.
    """

    def __init__(self, env):
        super().__init__()
        self.goal = None
        self.env = env

    def reset(self):
        """Forget any previously discovered goal."""
        self.goal = None

    def set_batch(self, batch):
        """Store ``batch`` and update ``goal`` from its reward history."""
        self.batch = batch
        rewards = batch['rollin_rs'].flatten().cpu().detach().numpy()
        rewarded = len(rewards) > 0 and np.max(rewards) > 0
        if not rewarded:
            self.goal = None
            return
        # The index comes from the flattened rewards and is applied to row 0
        # of rollin_xps -- presumably the batch holds a single rollout
        # (TODO confirm).
        best_step = np.argmax(rewards)
        self.goal = batch['rollin_xps'][0, best_step, :].cpu().detach().numpy()

    def act(self, x):
        """Return a one-hot action: optimal once a goal is known, else random."""
        if self.goal is not None:
            return self.env.opt_a(x)
        picked = np.random.choice(np.arange(self.env.action_space.n))
        one_hot = np.zeros(self.env.action_space.n)
        one_hot[picked] = 1
        return one_hot


class RandPolicy(DarkroomOptPolicy):
    """Controller that ignores the state and plays a uniformly random action."""

    def __init__(self, env):
        super().__init__(env)
        self.env = env

    def act(self, x):
        """Return a one-hot vector for a uniformly random action index.

        Args:
            x: Current state; unused.
        """
        # Derive the candidate indices from the action space rather than a
        # literal [0, 1, 2, 3, 4] so they always match the one-hot length.
        n_actions = self.env.action_space.n
        a = np.random.choice(np.arange(n_actions))
        zeros = np.zeros(n_actions)
        zeros[a] = 1
        return zeros


class DarkroomTransformerController(Controller):
    """Controller that queries a model for action scores and returns one-hots.

    ``act`` reads and updates ``self.batch``, which is not assigned in this
    class -- presumably ``Controller.set_batch`` provides it and must be
    called first (TODO confirm).
    """

    def __init__(self, model, batch_size=1, sample=False):
        """
        Args:
            model: Callable taking the batch dict; must expose ``config['du']``,
                ``config['dx']`` and ``H``.
            batch_size: Number of states handled per ``act`` call.
            sample: When true, sample actions from a softmax over the model
                output; otherwise take the arg-max.
        """
        self.model = model
        self.du = model.config['du']
        self.dx = model.config['dx']
        self.H = model.H
        # Zero tensors handed to the model under the 'zeros' / 'zerosQ' keys.
        self.zeros = torch.zeros(batch_size, self.dx**2 + self.du + 1).float().to(device)
        self.zerosQ = torch.zeros(batch_size, self.H, self.dx**2).float().to(device)
        self.sample = sample
        self.temp = 1.0  # softmax temperature used when sampling
        self.batch_size = batch_size

    def act(self, x):
        """Return one-hot action(s) for state(s) ``x``.

        Returns:
            An array of shape ``(du,)`` when ``batch_size == 1``, otherwise
            ``(batch_size, du)``.
        """
        self.batch['zeros'] = self.zeros
        self.batch['zerosQ'] = self.zerosQ

        states = torch.tensor(np.array(x)).float().to(device)
        if self.batch_size == 1:
            # A single state gets a leading batch axis.
            states = states[None, :]
        self.batch['states'] = states

        # Assumes the model returns one row of ``du`` scores per state
        # (TODO confirm against the model definition).
        scores = self.model(self.batch).cpu().detach().numpy()
        if self.batch_size == 1:
            scores = scores[0]

        if self.sample:
            # Import the submodule explicitly: a bare ``import scipy`` does not
            # guarantee that ``scipy.special`` is loaded on every SciPy version.
            from scipy.special import softmax

            if self.batch_size > 1:
                rows = [scores[idx] for idx in range(self.batch_size)]
            else:
                rows = [scores]
            action_indices = [
                np.random.choice(np.arange(self.du), p=softmax(row / self.temp))
                for row in rows
            ]
        else:
            action_indices = np.argmax(scores, axis=-1)

        actions = np.zeros((self.batch_size, self.du))
        actions[np.arange(self.batch_size), action_indices] = 1.0
        if self.batch_size == 1:
            actions = actions[0]
        return actions



if __name__ == '__main__':
    # Interactive smoke test: build a small random environment plus a random
    # policy and drop into an IPython shell.
    # sample() needs both the grid size and the horizon; calling it with the
    # size alone raises TypeError. The horizon below is an arbitrary value for
    # manual experimentation.
    env = sample(3, 10)
    ctrl = RandPolicy(env)
    embed()