import os
import pickle
import re
import warnings
from collections import Counter
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from typing import Dict

import matplotlib

matplotlib.use("agg")
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib import ticker
from tqdm import tqdm
from trajdata import MapAPI, UnifiedDataset, VectorMap
from trajdata.data_structures import AgentType, Scene

from trajdata_analysis.analysis.failure_analyses import (
    parallel_analyze_collisions_offroad,
)
from trajdata_analysis.analysis.point_wise import (
    agent_ego_distance,
    pointwise_acceleration,
    pointwise_heading,
    pointwise_jerk,
    pointwise_speed,
)
from trajdata_analysis.analysis.radian_formatter import Multiple
from trajdata_analysis.analysis.trajectory_wise import (
    max_heading_delta,
    rel_heading_delta,
    traj_length_time,
)
from trajdata_analysis.data_preprocessing.data_preprocessing import (
    parallel_generate_dataframes,
)


def is_cached(curr_results, bins, key: str) -> bool:
    """Return whether ``curr_results`` already holds a valid histogram for ``key``.

    A histogram counts as cached when at least one ``<key>_*_hist`` entry
    exists and the first ``<key>_*_bins`` entry has exactly the same shape
    as ``bins[key]`` and numerically matches it.

    Args:
        curr_results: Mapping of result name to array, e.g. loaded from a
            per-dataframe ``.npz`` cache file.
        bins: Mapping of analysis key to the bin edges currently in use.
        key: Analysis key such as ``"speed"`` or ``"max_dh"``.

    Returns:
        ``True`` if the cached histogram can be reused, ``False`` otherwise
        (including when ``key`` is missing from ``bins`` or the cached bins
        cannot be compared numerically).
    """
    # Escape the key so regex metacharacters in it are matched literally.
    prefix = re.escape(key)
    hist_re = re.compile(f"{prefix}_.*_hist")
    bin_re = re.compile(f"{prefix}_.*_bins")

    if not any(hist_re.match(name) for name in curr_results):
        return False

    cached_bins = next((curr_results[name] for name in curr_results if bin_re.match(name)), None)
    if cached_bins is None:
        return False

    try:
        expected_bins = np.asarray(bins[key])
    except KeyError:
        return False

    cached_bins = np.asarray(cached_bins)
    # Compare shapes explicitly: np.allclose broadcasts, so e.g. a single
    # cached edge equal to every expected edge would otherwise count as a match.
    if cached_bins.shape != expected_bins.shape:
        return False

    try:
        return bool(np.allclose(cached_bins, expected_bins))
    except TypeError:
        # Non-numeric cached data cannot be compared; treat it as not cached.
        return False


# Environments for which the agent-ego distance histogram is skipped
# (presumably because they have no ego agent -- confirm).
_NO_EGO_ENV_MARKERS = ("interaction", "eupeds", "sdd")

# Combined ETH / UCY pedestrian "environments" and the leave-one-out splits
# that make each of them up.
_EUPEDS_DESIRED_DATA = {
    "eupeds_ETH": ["eupeds_eth-test_loo", "eupeds_hotel-test_loo"],
    "eupeds_UCY": ["eupeds_univ-test_loo", "eupeds_zara1-test_loo", "eupeds_zara2-test_loo"],
}


def main():
    """Run the point-wise, plotting, and scene-wise analyses per environment.

    For every configured environment spec this builds a ``UnifiedDataset``,
    accumulates per-agent-type histograms (reusing the per-dataframe cache
    files yielded by ``parallel_generate_dataframes``), optionally renders
    them to PDFs under ``plots/<env_name>``, and pickles scene-level
    statistics to ``plots/<env_name>/scenewise_info.pkl``.
    """
    make_plots = True
    use_cache = True

    # Setting plot defaults.
    sns.set_theme(style="ticks")

    for env_spec in [
        # "nusc_trainval-train",
        # "waymo_val-val",  # TODO: NEED TO MAKE THIS TRAIN
        # "interaction_single-train",
        # "interaction_multi-train",
        # "lyft_sample-mini_train",
        # "nuplan_mini-mini_train",
        # "nusc_mini-mini_train",
        # "sdd-train",
        "eupeds_ETH",
        "eupeds_UCY",
    ]:
        env_name, desired_data_list, data_dirs_dict = _resolve_env(env_spec)

        print(flush=True)
        print("#" * 40, flush=True)
        print(f"Analyzing {env_name}...", flush=True)
        print("#" * 40, flush=True)

        plot_dir = Path(f"plots/{env_name}")
        plot_dir.mkdir(parents=True, exist_ok=True)

        dataset = UnifiedDataset(
            desired_data=desired_data_list,
            num_workers=os.cpu_count(),
            verbose=True,
            data_dirs=data_dirs_dict,
        )

        DT = dataset.envs[0].metadata.dt
        print(f"# Data Samples: {len(dataset):,}")

        bins = _make_bins(DT)

        print("Point-wise Analyses", flush=True)
        hists = _compute_pointwise_hists(dataset, env_name, bins, DT, plot_dir, use_cache)

        if make_plots:
            _make_plots(plot_dir, bins, hists, env_name)

        print("Scene-wise Analyses", flush=True)
        _scenewise_analysis(dataset, plot_dir, use_cache)


def _resolve_env(env_spec: str):
    """Return ``(env_name, desired_data_list, data_dirs_dict)`` for one env spec.

    ``eupeds_*`` specs expand to their leave-one-out splits; any other spec
    ``"<env>-<split>"`` is reduced to ``"<env>"`` (the split is not passed on).
    """
    if "eupeds" in env_spec:
        env_name = env_spec
        desired_data_list = _EUPEDS_DESIRED_DATA[env_spec]
    else:
        env_name = env_spec.split("-")[0]
        desired_data_list = [env_name]

    # NOTE: This assumes the data is already cached (empty data directories)!
    data_dirs_dict = {name.split("-")[0]: "" for name in desired_data_list}
    return env_name, desired_data_list, data_dirs_dict


def _has_ego(env_name: str) -> bool:
    """Return whether the agent-ego distance analysis applies to ``env_name``."""
    return not any(marker in env_name for marker in _NO_EGO_ENV_MARKERS)


def _make_bins(dt) -> Dict[str, np.ndarray]:
    """Return the fixed histogram bin edges for every analysis key.

    These edges are kept constant for the whole computation and are also
    stored alongside cached histograms so stale caches can be detected.
    """
    bins = dict()
    bins["speed"] = np.linspace(0, 60, 101)
    bins["acc"] = np.linspace(0, 100, 101)
    bins["jerk"] = np.linspace(-100, 100, 101)
    # Heading edges are shifted down by half a bin width (pi / (n_headings - 1)),
    # so bins are centred on -pi, -pi + 2*pi/(n_headings - 1), ... The last edge
    # is pi - pi/(n_headings - 1); headings above it fall outside the edges.
    n_headings = 57
    bins["heading"] = np.linspace(-np.pi - np.pi / (n_headings - 1), np.pi - np.pi / (n_headings - 1), n_headings)
    bins["ae_dist"] = np.linspace(0, 200, 101)
    # Observation lengths in steps of one timestep, up to roughly 20 s.
    bins["length"] = np.arange(0, 20 + dt, dt)
    bins["max_dh"] = np.linspace(0, 4 * np.pi, 101)
    bins["rel_dh"] = np.linspace(-4 * np.pi, 4 * np.pi, 100)
    return bins


def _drop_cached_entries(curr_results, key: str) -> None:
    """Remove every ``<key>_*_hist`` / ``<key>_*_bins`` entry from ``curr_results`` in place."""
    entry_re = re.compile(f"{re.escape(key)}_.*_(?:hist|bins)")
    for name in [name for name in curr_results if entry_re.match(name)]:
        del curr_results[name]


def _run_hist_tasks(tasks, data_df, bins, curr_results, n_compute_workers: int) -> None:
    """Compute the histograms for ``tasks`` in a process pool and store them in ``curr_results``.

    Each task is ``(key, function, extra_kwargs)``; the function is called as
    ``function(data_df, bins=bins[key], **extra_kwargs)``. Judging by the
    indexing below, it returns a mapping of agent type to ``(hist, edges)``.
    """
    with ProcessPoolExecutor(max_workers=n_compute_workers) as executor:
        futures = {
            executor.submit(fn, data_df, bins=bins[name], **extra_kwargs): name for name, fn, extra_kwargs in tasks
        }

        for future in tqdm(
            as_completed(futures),
            desc=f"Computing Histograms ({n_compute_workers} CPUs)",
            total=len(futures),
            position=3,
            leave=False,
        ):
            name = futures[future]
            results_dict = future.result()
            for agent_type in results_dict:
                value_hist, _ = results_dict[agent_type]
                curr_results[f"{name}_{agent_type}_hist"] = value_hist
                curr_results[f"{name}_{agent_type}_bins"] = bins[name]


def _compute_pointwise_hists(dataset, env_name, bins, dt, plot_dir, use_cache, n_compute_workers=8):
    """Accumulate every point-/trajectory-wise histogram over ``dataset``.

    Each dataframe yielded by ``parallel_generate_dataframes`` has a cache
    file; histograms whose cached bin edges still match ``bins`` are reused,
    the rest are recomputed and the cache file is rewritten.

    Returns:
        Dict mapping ``"<key>_<AgentType name>"`` to integer bin counts.
    """
    all_agent_types = [x.name for x in AgentType]
    hists = {
        f"{name}_{agent_type}": np.zeros((bin_arr.shape[0] - 1,), dtype=int)
        for name, bin_arr in bins.items()
        for agent_type in all_agent_types
    }

    # (key, function, extra keyword arguments besides ``bins``).
    tasks = [
        ("speed", pointwise_speed, {}),
        ("acc", pointwise_acceleration, {}),
        ("jerk", pointwise_jerk, {"DT": dt}),
        ("heading", pointwise_heading, {}),
        ("ae_dist", agent_ego_distance, {}),
        ("length", traj_length_time, {"DT": dt}),
        ("max_dh", max_heading_delta, {}),
        ("rel_dh", rel_heading_delta, {}),
    ]
    if not _has_ego(env_name):
        tasks = [task for task in tasks if task[0] != "ae_dist"]

    for (data_df, curr_cache_file) in parallel_generate_dataframes(
        dataset, hist_cache_dir=plot_dir, skip_if_cached=True
    ):
        curr_results: Dict[str, np.ndarray] = dict()
        if use_cache and curr_cache_file.exists():
            # Use the NpzFile as a context manager so its file handle is closed.
            with np.load(curr_cache_file) as loaded_data:
                curr_results = {k: loaded_data[k] for k in loaded_data.files}

        pending = [task for task in tasks if not is_cached(curr_results, bins, task[0])]
        if pending:
            # Stale entries computed with other bin edges (possibly for agent
            # types absent from the new results) would otherwise survive and
            # break the accumulation below with a shape mismatch.
            for name, _, _ in pending:
                _drop_cached_entries(curr_results, name)
            _run_hist_tasks(pending, data_df, bins, curr_results, n_compute_workers)

        for name, value in curr_results.items():
            if name.endswith("_hist"):
                hists[name[: -len("_hist")]] += value

        if pending:
            np.savez(curr_cache_file, **curr_results)

    return hists


def _make_plots(plot_dir, bins, hists, env_name) -> None:
    """Render all non-empty histograms to PDFs in ``plot_dir`` using a process pool.

    Raises any exception raised by a plotting worker.
    """
    normal_keys_labels_dict = {
        "speed": (r"Speed $(m/s)$", True),
        "acc": (r"Acceleration $(m/s^2)$", True),
        "jerk": (r"Jerk $(m/s^3)$", True),
        "ae_dist": ("Agent-Ego Distance (m)", False),
        "length": ("Agent Observation Length (s)", True),
    }

    heading_keys_labels_dict = {
        "heading": ("Heading (radians)", False),
        "max_dh": ("Max. Heading Change (radians)", True),
        "rel_dh": ("Rel. Heading Change (radians)", True),
    }

    n_plot_workers = os.cpu_count()
    with ProcessPoolExecutor(max_workers=n_plot_workers) as executor:
        futures = list()
        for agent_type in [x.name for x in AgentType]:
            # NOTE(review): this filter is set in the parent process; whether the
            # pool's worker processes inherit it depends on when/how they are
            # started -- confirm it actually silences the FutureWarnings.
            with warnings.catch_warnings():
                warnings.simplefilter(action="ignore", category=FutureWarning)

                for key, (label, logscale) in normal_keys_labels_dict.items():
                    if key == "ae_dist" and not _has_ego(env_name):
                        continue

                    if not np.any(hists[f"{key}_{agent_type}"]):
                        continue

                    futures.append(
                        executor.submit(normal_plot, plot_dir, bins, hists, agent_type, key, label, logscale)
                    )

                if np.any(hists[f"heading_{agent_type}"]):
                    futures.append(executor.submit(polar_heading_plot, plot_dir, bins, hists, agent_type))

                for key, (label, logscale) in heading_keys_labels_dict.items():
                    if not np.any(hists[f"{key}_{agent_type}"]):
                        continue

                    futures.append(
                        executor.submit(radianx_plot, plot_dir, bins, hists, agent_type, key, label, logscale)
                    )

        for future in tqdm(
            as_completed(futures),
            desc=f"Making Plots ({n_plot_workers} CPUs)",
            total=len(futures),
        ):
            # Re-raise worker exceptions; otherwise failed plots go unnoticed.
            future.result()


def _scenewise_analysis(dataset, plot_dir, use_cache) -> None:
    """Pickle per-scene agent statistics and collision/off-road results.

    Does nothing when ``plot_dir / "scenewise_info.pkl"`` already exists and
    ``use_cache`` is set.
    """
    scenewise_cache_file = plot_dir / "scenewise_info.pkl"
    if use_cache and scenewise_cache_file.exists():
        return

    max_sim_agents = []
    sim_agents = []
    agent_counts = Counter()
    scene: Scene
    for scene in tqdm(dataset.scenes(), total=dataset.num_scenes(), desc="Scenes"):
        agent_counts.update(agent.type.name for agent in scene.agents)

        # Number of agents present at each entry of scene.agent_presence
        # (presumably one entry per timestep -- confirm).
        simul_agents_per_timestep = [len(agents_at_time) for agents_at_time in scene.agent_presence]
        sim_agents.extend(simul_agents_per_timestep)
        # A scene without any entries has no maximum; skip it instead of
        # letting max() raise on an empty list.
        if simul_agents_per_timestep:
            max_sim_agents.append(max(simul_agents_per_timestep))

    # 50 equal-width bins over [0, 250]; counts outside that range are dropped.
    sim_agents_hist, sim_agents_bins = np.histogram(sim_agents, bins=np.linspace(0, 250, 51))
    max_sim_agents_hist, max_sim_agents_bins = np.histogram(max_sim_agents, bins=np.linspace(0, 250, 51))

    failure_results_dict = parallel_analyze_collisions_offroad(dataset, max_workers=os.cpu_count())

    with open(scenewise_cache_file, "wb") as f:
        pickle.dump(
            {
                "agent_counts": agent_counts,
                "sim_agents_hist": sim_agents_hist,
                "sim_agents_bins": sim_agents_bins,
                "max_sim_agents_hist": max_sim_agents_hist,
                "max_sim_agents_bins": max_sim_agents_bins,
                **failure_results_dict,
            },
            f,
        )


def polar_heading_plot(plot_dir, bins, hists, agent_type):
    """Save a polar histogram of headings to ``<plot_dir>/<agent_type>_heading_polar.pdf``.

    Args:
        plot_dir: Directory (``Path``) the PDF is written to.
        bins: Mapping of analysis key to bin edges; ``bins["heading"]`` is used.
        hists: Mapping of ``"<key>_<agent_type>"`` to per-bin counts.
        agent_type: Agent type name used for the histogram lookup and filename.
    """
    fig, ax = plt.subplots(subplot_kw=dict(projection="polar"))
    try:
        # The counts are already binned: pass one sample per bin (its left
        # edge) weighted by that bin's count.
        sns.histplot(
            x=bins["heading"][:-1],
            weights=hists[f"heading_{agent_type}"],
            ax=ax,
            bins=bins["heading"],
            stat="proportion",
            linewidth=0.1,
            edgecolor="k",
        )
        ax.set_xlabel("Heading (radians)")
        ax.set_ylabel(None)
        ax.spines["polar"].set_visible(False)
        # Fixed radial ticks at 2%, 4%, ..., 10%.
        ax.set_yticks(np.linspace(0.02, 0.1, 5))

        # Project radian formatter; presumably labels ticks as multiples of pi/4.
        major = Multiple(denominator=4)
        ax.xaxis.set_major_formatter(major.formatter())
        ax.yaxis.set_major_formatter(ticker.PercentFormatter(1.0, decimals=0))
        fig.savefig(plot_dir / f"{agent_type}_heading_polar.pdf", bbox_inches="tight")
    finally:
        # Always release the figure so long-lived pool workers do not
        # accumulate open pyplot figures when plotting fails.
        plt.close(fig)


def radianx_plot(plot_dir, bins, hists, agent_type, key, label, logscale):
    """Save a histogram with a radian-formatted x axis to ``<plot_dir>/<agent_type>_<key>.pdf``.

    Args:
        plot_dir: Directory (``Path``) the PDF is written to.
        bins: Mapping of analysis key to bin edges; ``bins[key]`` is used.
        hists: Mapping of ``"<key>_<agent_type>"`` to per-bin counts.
        agent_type: Agent type name used for the histogram lookup and filename.
        key: Analysis key, e.g. ``"max_dh"``.
        label: X-axis label.
        logscale: If true use a log y axis; otherwise format y as percentages.
    """
    fig, ax = plt.subplots()
    try:
        # The counts are already binned: pass one sample per bin (its left
        # edge) weighted by that bin's count.
        sns.histplot(
            x=bins[key][:-1],
            weights=hists[f"{key}_{agent_type}"],
            ax=ax,
            bins=bins[key],
            stat="proportion",
        )
        ax.set_xlabel(label)
        if logscale:
            ax.set_yscale("log")
        else:
            ax.yaxis.set_major_formatter(ticker.PercentFormatter(1.0, decimals=0))
        # Project radian formatter; presumably labels ticks as multiples of pi/4.
        major = Multiple(denominator=4)
        ax.xaxis.set_major_formatter(major.formatter())
        fig.savefig(plot_dir / f"{agent_type}_{key}.pdf", bbox_inches="tight")
    finally:
        # Always release the figure so long-lived pool workers do not
        # accumulate open pyplot figures when plotting fails.
        plt.close(fig)


def normal_plot(plot_dir: Path, bins, hists, agent_type: str, key: str, label: str, logscale: bool) -> None:
    """Save a histogram for ``key`` and ``agent_type`` to ``<plot_dir>/<agent_type>_<key>.pdf``.

    Args:
        plot_dir: Directory the PDF is written to.
        bins: Mapping of analysis key to bin edges; ``bins[key]`` is used.
        hists: Mapping of ``"<key>_<agent_type>"`` to per-bin counts.
        agent_type: Agent type name used for the histogram lookup and filename.
        key: Analysis key, e.g. ``"speed"``.
        label: X-axis label.
        logscale: If true use a log y axis; otherwise format y as percentages.
    """
    fig, ax = plt.subplots()
    try:
        # The counts are already binned: pass one sample per bin (its left
        # edge) weighted by that bin's count.
        sns.histplot(
            x=bins[key][:-1],
            weights=hists[f"{key}_{agent_type}"],
            ax=ax,
            bins=bins[key],
            stat="proportion",
        )
        ax.set_xlabel(label)
        if logscale:
            ax.set_yscale("log")
        else:
            ax.yaxis.set_major_formatter(ticker.PercentFormatter(1.0, decimals=0))
        fig.savefig(plot_dir / f"{agent_type}_{key}.pdf", bbox_inches="tight")
    finally:
        # Always release the figure so long-lived pool workers do not
        # accumulate open pyplot figures when plotting fails.
        plt.close(fig)


# Script entry point: run all analyses only when executed directly, not on import
# (which also keeps ProcessPoolExecutor child processes from re-running main()).
if __name__ == "__main__":
    main()
