import numpy as np

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.legend import Legend
from matplotlib.lines import Line2D
from mpl_toolkits.axes_grid1 import make_axes_locatable
import pylab

from env import GridWorld


def _configure_rc_params() -> None:
    """Set the global matplotlib rcParams used by the gridworld plots."""
    plt.rcParams.update(
        {
            "image.cmap": "pink_r",
            "figure.dpi": 300,
            "figure.figsize": (8, 6),
            "font.size": 12,
            "legend.fontsize": "large",
            "figure.titlesize": "medium",
            "axes.titlesize": "medium",
            "axes.labelsize": "medium",
            "xtick.labelsize": "medium",
            "ytick.labelsize": "medium",
        }
    )


def _legend_handles() -> list:
    """Return proxy artists for the Start / Goal / Obstacle / Policy legend entries."""
    return [
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label="Start",
            markerfacecolor="b",
            markersize=15,
            linestyle="None",
        ),
        Line2D(
            [0],
            [0],
            marker="p",
            color="w",
            label="Goal",
            markerfacecolor="g",
            markersize=15,
            linestyle="None",
        ),
        Line2D(
            [0],
            [0],
            marker="s",
            color="maroon",
            label="Obstacle",
            markerfacecolor="maroon",
            markersize=15,
            linestyle="None",
        ),
        Line2D(
            [0],
            [0],
            marker=r"$\rightarrow$",
            color="black",
            label="Policy",
            markerfacecolor="black",
            markersize=24,
            linestyle="None",
        ),
    ]


def plot_gridworld(
    model: GridWorld,
    value_function: np.ndarray = None,
    policy: np.ndarray = None,
    state_counts: np.ndarray = None,
    title: str = None,
    path: str = None,
    legend_path: str = "gridworld_legend.pdf",
) -> None:
    """Plot the gridworld with an optional value function / state counts and policy.

    Args:
        model (GridWorld): The gridworld model.
        value_function (np.ndarray, optional): Per-state values used to colour
            the cells. Defaults to None.
        policy (np.ndarray, optional): Per-state action indices drawn as
            arrows. Defaults to None.
        state_counts (np.ndarray, optional): Per-state counts used to colour
            the cells instead of ``value_function``. Defaults to None.
        title (str, optional): The title of the plot. Defaults to None.
        path (str, optional): Where to save the plot. The plot is saved before
            the legend is added, so the legend is not part of this file.
            Defaults to None (not saved).
        legend_path (str, optional): Where to save the legend on its own.
            Pass None to skip saving it. Defaults to "gridworld_legend.pdf".

    Raises:
        ValueError: If both ``value_function`` and ``state_counts`` are given.
    """
    if value_function is not None and state_counts is not None:
        raise ValueError("Must supply either value function or state_counts, not both!")

    # These are global settings. Apply them before creating the figure so that
    # figure.figsize and figure.dpi also take effect for this figure.
    _configure_rc_params()
    fig, ax = plt.subplots()

    add_patches(model, ax)
    add_policy(model, policy)

    # Colour the cells by the state counts if given. Otherwise use the value
    # function, which draws an empty grid when that is None as well.
    if state_counts is not None:
        add_value_function(model, state_counts, "State counts")
    else:
        add_value_function(model, value_function, "Value function")

    # Hide the tick labels; the ticks themselves may still mark the cell borders.
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    plt.tight_layout()

    if title is not None:
        plt.title(title, fontdict=None, loc="center")
    if path is not None:
        plt.savefig(path, dpi=300, bbox_inches="tight")

    legend = plt.legend(
        handles=_legend_handles(),
        loc="upper center",
        bbox_to_anchor=(0.5, -0.12),
        fancybox=True,
        shadow=True,
        ncol=4,
    )

    if legend_path is not None:
        # Save the legend separately by cropping the figure to the legend's
        # extent. The extent is only known after a draw, and it is given in
        # display pixels; bbox_inches expects inches.
        fig.canvas.draw()
        bbox = legend.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
        fig.savefig(legend_path, dpi=300, bbox_inches=bbox, transparent=True)


def add_value_function(
    model: GridWorld,
    value_function: np.ndarray,
    name: str,
) -> None:
    """Colour the grid cells on the current axes and add a labelled colour bar.

    Args:
        model (GridWorld): The gridworld model.
        value_function (np.ndarray): Per-state values, read as
            ``value_function[state, 0]``. All entries except the last are laid
            out row-major over the ``(num_rows, num_cols)`` grid. If None, an
            all-zero grid with black cell borders is drawn instead.
        name (str): Label for the colour bar.
    """
    if value_function is not None:
        # Copy the grid: slicing plus reshape usually returns a view. Writing
        # the obstacle marker below would otherwise modify the caller's array.
        val = value_function[:-1, 0].reshape(model.num_rows, model.num_cols).copy()
        if model.obs_states is not None:
            obs = model.obs_states
            # Mark obstructed cells with a value below the [0, 1] colour range
            # set at the end, so they render in the colormap's "under" colour.
            val[obs[:, 0], obs[:, 1]] = -100
        plt.imshow(val, zorder=0)
    else:
        plt.imshow(np.zeros((model.num_rows, model.num_cols)), zorder=0)
        # Put ticks on the cell borders so the grid lines outline every cell.
        plt.yticks(np.arange(-0.5, model.num_rows + 0.5, step=1))
        plt.xticks(np.arange(-0.5, model.num_cols + 0.5, step=1))
        plt.grid(linewidth=2, color="black")

    cbar = plt.colorbar(label=name)
    cbar.ax.tick_params(labelsize=18)
    cbar.ax.set_ylabel(name, fontsize=24)
    # NOTE(review): the colour range is fixed to [0, 1] whatever the data, so
    # values outside it (e.g. raw state counts) saturate. Confirm this is intended.
    plt.clim(0, 1.0)


def _add_wedge_markers(ax, states: np.ndarray, color: str, label: str) -> None:
    """Add a wedge marker (swept from 40 to -40 degrees) centred on each (row, col) in ``states``."""
    for i, cell in enumerate(states):
        ax.add_patch(
            patches.Wedge(
                tuple(np.flip(cell)),
                0.2,
                40,
                -40,
                linewidth=1,
                edgecolor=color,
                facecolor=color,
                zorder=1,
                label=label if i == 0 else None,
            )
        )


def add_patches(model: GridWorld, ax: plt.Axes) -> None:
    """Add markers for the start, goal, obstructed, bad and restart states.

    State coordinates are ``(row, col)`` pairs. They are flipped to ``(x, y)``
    for plotting, so each marker is centred on its grid cell. Only the first
    marker of each kind is labelled, so a legend built from these artists
    lists each kind once.

    Args:
        model (GridWorld): The gridworld model.
        ax (matplotlib.axes.Axes): The axes to plot on.
    """
    ax.add_patch(
        patches.Circle(
            tuple(np.flip(model.start_state[0])),
            0.2,
            linewidth=1,
            edgecolor="b",
            facecolor="b",
            zorder=1,
            label="Start",
        )
    )

    for i, goal in enumerate(model.goal_states):
        ax.add_patch(
            patches.RegularPolygon(
                tuple(np.flip(goal)),
                numVertices=5,
                radius=0.25,
                orientation=np.pi,
                edgecolor="g",
                zorder=1,
                facecolor="g",
                label="Goal" if i == 0 else None,
            )
        )

    if model.obs_states is not None:
        for i, cell in enumerate(model.obs_states):
            # A Rectangle is positioned by its corner, so shift by half of its
            # 0.7 side to centre the square on the cell.
            ax.add_patch(
                patches.Rectangle(
                    tuple(np.flip(cell) - 0.35),
                    0.7,
                    0.7,
                    linewidth=1,
                    edgecolor="maroon",
                    facecolor="maroon",
                    zorder=1,
                    label="Obstructed" if i == 0 else None,
                )
            )

    if model.bad_states is not None:
        _add_wedge_markers(ax, model.bad_states, "r", "Bad state")

    if model.restart_states is not None:
        _add_wedge_markers(ax, model.restart_states, "y", "Restart state")


def add_policy(model: GridWorld, policy: np.ndarray) -> None:
    """Draw the policy as a field of arrows on the current axes.

    Nothing is drawn when ``policy`` is None. No arrows are shown on goal
    cells, or on obstructed and restart cells when the model defines them.

    Args:
        model (GridWorld): The gridworld model.
        policy (np.ndarray): Action index per state, as consumed by
            ``create_policy_direction_arrays``.
    """
    if policy is None:
        return

    # One arrow anchor per cell: x runs over the columns, y over the rows.
    xs = np.arange(0, model.num_cols, 1)
    ys = np.arange(0, model.num_rows, 1)

    dx, dy = create_policy_direction_arrays(model, policy)

    def hide(cells):
        # Setting both components to NaN removes the arrow at those cells.
        rows, cols = cells[:, 0], cells[:, 1]
        dx[rows, cols] = np.nan
        dy[rows, cols] = np.nan

    # Goal states are always masked; the other groups only when present.
    hide(model.goal_states)
    for optional_cells in (model.obs_states, model.restart_states):
        if optional_cells is not None:
            hide(optional_cells)

    plt.quiver(xs, ys, dx, dy, zorder=10, label="Policy")


def create_policy_direction_arrays(
    model: GridWorld, policy: np.ndarray
) -> "tuple[np.ndarray, np.ndarray]":
    """Convert a policy into per-cell arrow components for ``plt.quiver``.

    Args:
        model (GridWorld): The gridworld model.
        policy (np.ndarray): Action index per state (0 = up, 1 = down,
            2 = left, 3 = right). Only the first ``model.num_states - 1``
            entries are read.

    Returns:
        tuple[np.ndarray, np.ndarray]: ``(U, V)``, each of shape
        ``(model.num_rows, model.num_cols)``. ``U`` holds the horizontal and
        ``V`` the vertical arrow component, each of magnitude 0.5. Cells whose
        action is none of the four above, or which no state maps to, stay 0.
    """
    # action options
    UP, DOWN, LEFT, RIGHT = 0, 1, 2, 3

    # initialize direction arrays
    U = np.zeros((model.num_rows, model.num_cols))
    V = np.zeros((model.num_rows, model.num_cols))

    # The last state is skipped, matching add_value_function, which also
    # drops the final entry.
    for state in range(model.num_states - 1):
        # (row, col) index of this state in the grid
        i = tuple(model._seq_to_col_row(state)[0])
        action = policy[state]
        if action == UP:
            U[i], V[i] = 0, 0.5
        elif action == DOWN:
            U[i], V[i] = 0, -0.5
        elif action == LEFT:
            U[i], V[i] = -0.5, 0
        elif action == RIGHT:
            U[i], V[i] = 0.5, 0

    return U, V
