from collections import Counter
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import ticker

from .radian_formatter import Multiple


def traj_length_time(df: pd.DataFrame, DT: float, ax: Optional[plt.Axes] = None, bins: int = 10, **kwargs):
    """Plot or histogram the duration of every agent trajectory.

    A trajectory is identified by its (scene_id, agent_id, agent_type) index
    levels. Its duration is (last scene_ts - first scene_ts) * DT.

    Args:
        df: Frame whose index contains the levels ``scene_ts``, ``scene_id``,
            ``agent_id`` and ``agent_type``. The first/last values are taken
            in row order within each group, so rows are presumably sorted by
            ``scene_ts`` -- verify against callers.
        DT: Time per ``scene_ts`` step. The x-axis label is "(s)", so this is
            presumably seconds.
        ax: When given, draw a seaborn histogram on it. Otherwise compute numpy
            histograms per agent type.
        bins: Bin specification. In the plotting branch ``binwidth=DT`` is
            also passed; seaborn gives ``binwidth`` precedence over ``bins``,
            so ``bins`` likely has no effect there -- verify.
        **kwargs: Forwarded to ``sns.histplot`` or ``np.histogram``.

    Returns:
        None when ``ax`` is given. Otherwise a dict mapping each agent type to
        the ``(counts, bin_edges)`` tuple returned by ``np.histogram``.
    """
    timesteps = df.reset_index(level="scene_ts")["scene_ts"]
    per_trajectory = timesteps.groupby(["scene_id", "agent_id", "agent_type"])
    durations = (per_trajectory.last() - per_trajectory.first()) * DT

    if ax is None:
        return {
            kind: np.histogram(kind_durations, bins=bins, **kwargs)
            for kind, kind_durations in durations.groupby("agent_type")
        }

    # Bins are one timestep wide, so each bin corresponds to a step count.
    ax = sns.histplot(durations, bins=bins, stat="proportion", ax=ax, binwidth=DT, **kwargs)
    ax.set_xlabel("Trajectory Length (s)")
    ax.yaxis.set_major_formatter(ticker.PercentFormatter(1.0, decimals=0))


def max_heading_delta(df: pd.DataFrame, ax: Optional[plt.Axes] = None, bins: int = 100, **kwargs):
    """Plot or histogram each trajectory's total heading excursion.

    For every (scene_id, agent_id, agent_type) group, the ``heading`` sequence
    is unwrapped (so jumps across the +/-pi seam do not count as large
    turns). The spread of the unwrapped values (max - min) is then taken.

    Args:
        df: Frame with a ``heading`` column, indexed by (at least) the levels
            ``scene_id``, ``agent_id`` and ``agent_type``. Headings are
            presumably radians (the x-axis label says so) -- confirm.
        ax: When given, draw a log-scaled seaborn histogram on it. Otherwise
            compute numpy histograms per agent type.
        bins: Bin specification forwarded to the histogram routine.
        **kwargs: Forwarded to ``sns.histplot`` or ``np.histogram``.

    Returns:
        None when ``ax`` is given. Otherwise a dict mapping each agent type to
        the ``(counts, bin_edges)`` tuple returned by ``np.histogram``.
    """

    def _unwrapped_span(headings):
        # Peak-to-peak range of the seam-free heading sequence.
        return np.ptp(np.unwrap(headings))

    heading_spans = df.heading.groupby(["scene_id", "agent_id", "agent_type"]).apply(_unwrapped_span)

    if ax is None:
        return {
            kind: np.histogram(kind_spans, bins=bins, **kwargs)
            for kind, kind_spans in heading_spans.groupby("agent_type")
        }

    ax = sns.histplot(heading_spans, bins=bins, stat="proportion", ax=ax, **kwargs)

    # Only the tick *labels* are customised; tick positions stay with
    # matplotlib's default locator. Multiple comes from .radian_formatter and
    # presumably labels ticks as multiples of pi/4 -- confirm there.
    radian_ticks = Multiple(denominator=4)
    ax.xaxis.set_major_formatter(radian_ticks.formatter())
    ax.set_xlabel("Max. Heading Change (radians)")
    ax.set_yscale("log")


def rel_heading_delta(df: pd.DataFrame, ax: Optional[plt.Axes] = None, bins: int = 100, **kwargs):
    """Plot or histogram headings relative to each trajectory's first heading.

    For every (scene_id, agent_id, agent_type) group, the ``heading`` sequence
    is unwrapped and its first value is subtracted from every later sample.
    Each trajectory of N samples therefore contributes N - 1 values.

    Args:
        df: Frame with a ``heading`` column, indexed by (at least) the levels
            ``scene_id``, ``agent_id`` and ``agent_type``. Headings are
            presumably radians (the x-axis label says so) -- confirm.
        ax: When given, draw a log-scaled seaborn histogram on it. Otherwise
            compute numpy histograms per agent type.
        bins: Bin specification forwarded to the histogram routine.
        **kwargs: Forwarded to ``sns.histplot`` or ``np.histogram``.

    Returns:
        None when ``ax`` is given. Otherwise a dict mapping each agent type to
        the ``(counts, bin_edges)`` tuple returned by ``np.histogram``.
    """

    def _relative_to_first(headings: pd.Series) -> np.ndarray:
        # Index the *unwrapped numpy array* positionally. ``headings[0]`` on a
        # MultiIndexed Series is a label lookup on the scene_id level: it is
        # deprecated positional fallback for string ids and a wrong/KeyError
        # lookup for integer ids. np.unwrap keeps element 0 unchanged, so the
        # reference value is the same raw first heading as before.
        unwrapped = np.unwrap(headings.to_numpy(dtype=float))
        return unwrapped[1:] - unwrapped[0]

    deltas = (
        df.heading.groupby(["scene_id", "agent_id", "agent_type"])
        .apply(_relative_to_first)
        .explode()
        # explode() yields object dtype and turns the empty arrays of
        # single-sample trajectories into NaN. Cast to float and drop NaN so
        # np.histogram with integer bins does not fail on a non-finite
        # autodetected range. seaborn already discards NaN, so plots are
        # unaffected.
        .astype(float)
        .dropna()
    )

    if ax is None:
        return {
            kind: np.histogram(kind_deltas, bins=bins, **kwargs)
            for kind, kind_deltas in deltas.groupby("agent_type")
        }

    ax = sns.histplot(deltas, bins=bins, stat="proportion", ax=ax, **kwargs)

    # Only the tick labels are customised; Multiple (from .radian_formatter)
    # presumably labels ticks as multiples of pi/2 -- confirm there.
    radian_ticks = Multiple(denominator=2)
    ax.xaxis.set_major_formatter(radian_ticks.formatter())
    ax.set_xlabel("Rel. Heading Change (radians)")
    ax.set_yscale("log")


def agent_type(scenes, ax: Optional[plt.Axes] = None, bins: int = 10, **kwargs):
    """Plot or count the agent types present across a collection of scenes.

    Args:
        scenes: Iterable of scene objects. Each exposes ``.agents``, whose
            elements have ``.type.name`` (presumably an enum member) -- the
            name string is what gets counted.
        ax: When given, draw a seaborn histogram of the type names on it.
            Otherwise return raw counts.
        bins: Forwarded to ``sns.histplot``; unused when ``ax`` is None.
        **kwargs: Forwarded to ``sns.histplot``.

    Returns:
        None when ``ax`` is given. Otherwise a ``Counter`` mapping each type
        name to its number of occurrences.
    """
    agent_types = [agent.type.name for scene in scenes for agent in scene.agents]

    if ax is None:
        return Counter(agent_types)

    ax = sns.histplot(agent_types, bins=bins, stat="proportion", ax=ax, **kwargs)
    ax.set_xlabel("Agent Type")
    ax.yaxis.set_major_formatter(ticker.PercentFormatter(1.0, decimals=0))
