import copy
import os
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, Generator, List, Optional, Tuple

import geopandas as gpd
import numpy as np
import pandas as pd
import shapely
from tqdm import tqdm, trange
from trajdata import MapAPI, UnifiedDataset, VectorMap
from trajdata.caching import EnvCache, SceneCache
from trajdata.caching.df_cache import DataFrameCache
from trajdata.data_structures import Scene
from trajdata.maps.vec_map_elements import MapElementType


# Environments whose agent data provides per-timestep "length"/"width" columns; for all
# other environments a single extent per agent is taken from the scene's agent metadata.
_PER_TIMESTEP_EXTENT_ENVS: Tuple[str, ...] = ("lyft", "waymo")

# Environments for which this analysis loads no vector map, so offroad stats are skipped.
_MAPLESS_ENVS: Tuple[str, ...] = ("waymo", "sdd", "eupeds")


def _env_matches(env_name: str, keywords: Tuple[str, ...]) -> bool:
    """Return True if any of ``keywords`` is a substring of ``env_name``."""
    return any(keyword in env_name for keyword in keywords)


def _load_map_polygons(cache_path: Path, env_name: str, location: str) -> np.ndarray:
    """Load the cached vector map for ``env_name:location`` and return its drivable polygons.

    Road-area elements (with their interior holes) are used when the map has any;
    otherwise each lane that has both edges is closed into a polygon by following the
    left edge forward and the right edge backward.

    Returns:
        A 1-D object array of shapely Polygons.
    """
    vec_map: VectorMap = MapAPI(cache_path).get_map(
        f"{env_name}:{location}",
        incl_road_lanes=True,
        incl_road_areas=True,
        incl_ped_crosswalks=False,
        incl_ped_walkways=False,
    )

    if MapElementType.ROAD_AREA in vec_map.elements:
        road_areas = vec_map.elements[MapElementType.ROAD_AREA]
        map_polygons = np.empty(len(road_areas), dtype="object")
        map_polygons[:] = [
            shapely.Polygon(
                shell=shapely.LinearRing(road_area.exterior_polygon.xy),
                holes=[shapely.LinearRing(hole.xy) for hole in road_area.interior_holes],
            )
            for road_area in road_areas.values()
        ]
        return map_polygons

    return shapely.polygons(
        geometries=[
            shapely.LinearRing(np.concatenate((road_lane.left_edge.xy, road_lane.right_edge.xy[::-1]), axis=0))
            for road_lane in vec_map.elements[MapElementType.ROAD_LANE].values()
            if road_lane.left_edge is not None and road_lane.right_edge is not None
        ]
    )


def _path_efficiency(
    positions: pd.DataFrame, agent_type_map: Dict[str, str]
) -> Tuple[Dict[str, np.ndarray], np.ndarray, int, int]:
    """Histogram per-agent path efficiency (net displacement / distance travelled) by type.

    Args:
        positions: "x"/"y" columns indexed by ("agent_id", "scene_ts"); rows of an agent
            are assumed to be in time order.
        agent_type_map: agent id -> agent type name.

    Returns:
        (agent type -> histogram counts, bin edges, number of stationary agents,
        total number of agents). An agent is stationary when its net displacement or its
        distance travelled is zero; stationary agents are excluded from the histograms.
    """
    by_agent = positions.groupby("agent_id")
    displacement = by_agent.last() - by_agent.first()
    displacement_norm = pd.Series(np.linalg.norm(displacement, axis=1), index=displacement.index)

    # Distance travelled is the sum of per-step Euclidean distances. Each agent's first
    # diff is NaN and is skipped by sum(), so a single-sample agent travels 0.
    steps = by_agent.diff()
    path_length = np.hypot(steps["x"], steps["y"]).groupby(level="agent_id").sum()

    moving = (displacement_norm > 0.0) & (path_length > 0.0)
    efficiency = displacement_norm[moving] / path_length[moving]

    # Upper edge slightly above 1 so floating-point rounding of straight paths stays in range.
    bins = np.linspace(0, 1.01, 102)
    efficiency_types = efficiency.index.map(agent_type_map)
    hists = {
        agent_type: np.histogram(efficiency[efficiency_types == agent_type], bins=bins)[0]
        for agent_type in efficiency_types.unique()
    }

    num_total = len(moving)
    num_stationary = num_total - int(moving.sum())
    return hists, bins, num_stationary, num_total


def _agent_density_hists(scene_df: pd.DataFrame) -> Tuple[Dict[str, np.ndarray], np.ndarray]:
    """Histogram per-timestep agent density by agent type.

    For each (scene_ts, agent_type) group the density is the agent count divided by the
    area of the axis-aligned box spanning that group's positions. A group whose box has
    zero area gets an infinite density, which lies outside the bin edges and is therefore
    not counted by ``np.histogram``.

    Args:
        scene_df: agent rows with "x", "y" and "agent_type" columns and a "scene_ts"
            index level.

    Returns:
        (agent type -> histogram counts, bin edges).
    """
    time_type_group = scene_df.groupby(["scene_ts", "agent_type"])
    agent_count = time_type_group["x"].count()
    areas = time_type_group["x"].apply(np.ptp) * time_type_group["y"].apply(np.ptp)
    agent_dens = agent_count / areas

    bins = np.logspace(-4, 0, 21)
    hists = {
        agent_type: np.histogram(agent_dens.xs(agent_type, level="agent_type"), bins=bins)[0]
        for agent_type in agent_dens.index.get_level_values("agent_type").unique()
    }
    return hists, bins


def _agent_footprints(xys: np.ndarray, headings: np.ndarray, lengths: np.ndarray, widths: np.ndarray) -> np.ndarray:
    """Return one rectangular footprint Polygon per row.

    Each rectangle is centred on the row's (x, y). It extends ``length`` along the heading
    direction and ``width`` across it, rotated by ``heading`` (passed to cos/sin, so radians).
    """
    cos_h, sin_h = np.cos(headings), np.sin(headings)
    # (N, 1, 2, 2) rotation matrices; the singleton axis broadcasts over the 4 corners.
    rotations = np.array([[cos_h, -sin_h], [sin_h, cos_h]]).transpose(2, 0, 1)[:, None]

    half_l, half_w = lengths / 2, widths / 2
    # (N, 4, 2, 1) corner offsets in each agent's local frame.
    corners = np.array(
        [
            [-half_l, -half_w],
            [half_l, -half_w],
            [half_l, half_w],
            [-half_l, half_w],
        ]
    ).transpose(2, 0, 1)[..., None]

    corner_pts = xys[:, None, :, None] + np.matmul(rotations, corners)
    return shapely.polygons(corner_pts.squeeze(-1))


def _offroad_mask(xys: np.ndarray, map_polygons: np.ndarray) -> np.ndarray:
    """Return a boolean array that is True for each (x, y) intersecting no map polygon."""
    hit_rows, _ = shapely.STRtree(map_polygons).query(shapely.points(xys), predicate="intersects")
    mask = np.ones(len(xys), dtype=bool)
    mask[hit_rows] = False
    return mask


def _collided_agent_ids(scene_df: pd.DataFrame, footprints: np.ndarray) -> set:
    """Return ids of agents whose footprint intersects a different agent's at some timestep.

    Args:
        scene_df: agent rows indexed by ("agent_id", "scene_ts").
        footprints: one Polygon per row of ``scene_df``, in the same row order.
    """
    agent_ids = scene_df.index.get_level_values("agent_id").to_numpy()
    collided = set()
    # ``indices`` maps each timestep to the integer row positions at that timestep.
    for rows in scene_df.groupby(level="scene_ts").indices.values():
        geoms = footprints[rows]
        ids = agent_ids[rows]
        query_idx, tree_idx = shapely.STRtree(geoms).query(geoms, predicate="intersects")
        # Every footprint intersects itself; keep only pairs of distinct agents.
        distinct = ids[query_idx] != ids[tree_idx]
        collided.update(ids[query_idx[distinct]].tolist())
    return collided


def process_scene(scene_path: Path, cache_path: Path):
    """Compute failure-analysis statistics for a single cached trajdata scene.

    Args:
        scene_path: Path of the cached scene, loaded with ``EnvCache.load``.
        cache_path: Root of the trajdata cache holding the scene's agent data and maps.

    Returns:
        A 15-tuple, in this order:
            num_collisions: number of agents whose footprint intersects a different
                agent's footprint at some timestep.
            num_total_agents: ``len(scene.agents)``.
            num_stationary: agents with zero net displacement or zero distance travelled.
            agent_classes: agent type name -> number of agents.
            collided_agent_classes: agent type name -> number of collided agents.
            offroad_agent_types: ``Counter`` of agent type name -> number of agents with at
                least one position intersecting no map polygon; None when no map is loaded.
            num_offroad_positions: number of such positions; None when no map is loaded.
            num_positions: number of rows in the scene's agent data.
            path_efficiency_hists, path_efficiency_bins: per-type histogram counts of
                (net displacement / distance travelled) for non-stationary agents, and edges.
            agent_density_hists, agent_density_bins: per-type histogram counts of
                per-timestep agent density, and edges.
            max_acc_hist, max_acc_bins: histogram counts of each vehicle's peak
                acceleration magnitude, and edges.
            num_veh_acc_triggers: number of vehicles whose peak acceleration is >= 0.4 g.
    """
    scene: Scene = EnvCache.load(scene_path)
    per_timestep_extents = _env_matches(scene.env_name, _PER_TIMESTEP_EXTENT_ENVS)
    has_map = not _env_matches(scene.env_name, _MAPLESS_ENVS)

    agent_meta_dict: Dict[str, dict] = {}
    for agent in scene.agents:
        agent_meta = {}
        if not per_timestep_extents:
            agent_meta["length"] = agent.extent.length
            agent_meta["width"] = agent.extent.width
        agent_meta["type"] = agent.type.name
        agent_meta_dict[agent.name] = agent_meta
    agent_type_map = {agent_id: meta["type"] for agent_id, meta in agent_meta_dict.items()}

    columns_to_read: List[str] = ["agent_id", "scene_ts", "x", "y", "ax", "ay", "heading"]
    if per_timestep_extents:
        columns_to_read += ["length", "width"]

    scene_dir = SceneCache.scene_cache_dir(cache_path, scene.env_name, scene.name)
    scene_df: pd.DataFrame = pd.read_feather(
        scene_dir / DataFrameCache._agent_data_file(scene.dt),
        use_threads=False,
        columns=columns_to_read,
    ).set_index(["agent_id", "scene_ts"])
    num_positions = len(scene_df)

    # Path efficiency
    path_efficiency_hists, path_efficiency_bins, num_stationary, num_total = _path_efficiency(
        scene_df[["x", "y"]], agent_type_map
    )

    # Agent density
    scene_df["agent_type"] = scene_df.index.get_level_values("agent_id").map(agent_type_map)
    agent_density_hists, agent_density_bins = _agent_density_hists(scene_df)

    # Harsh vehicle accelerations (one_g = 9.81 presumes ax/ay are in m/s^2).
    scene_df["acc"] = np.sqrt(scene_df["ax"] ** 2 + scene_df["ay"] ** 2)
    max_veh_acc = scene_df[scene_df["agent_type"] == "VEHICLE"].groupby("agent_id")["acc"].max()
    one_g = 9.81
    max_acc_bins = np.linspace(0, one_g, 21)
    max_acc_hist = np.histogram(max_veh_acc, bins=max_acc_bins)[0]
    num_veh_acc_triggers = int((max_veh_acc >= 0.4 * one_g).sum())

    # Offroad
    agent_ids = scene_df.index.get_level_values("agent_id")
    xys = scene_df[["x", "y"]].to_numpy()
    if has_map:
        map_polygons = _load_map_polygons(cache_path, scene.env_name, scene.location)
        offroad_mask = _offroad_mask(xys, map_polygons)
        offroad_agent_types = Counter(agent_ids[offroad_mask].unique().map(agent_type_map))
        # TODO: Count how much of each agent's trajectory is offroad instead of only whether any of it is?
        num_offroad_positions = int(offroad_mask.sum())
    else:
        offroad_agent_types = None
        num_offroad_positions = None

    # Collisions
    if per_timestep_extents:
        lengths = scene_df["length"].to_numpy()
        widths = scene_df["width"].to_numpy()
    else:
        lengths = np.array([agent_meta_dict[a]["length"] for a in agent_ids])
        widths = np.array([agent_meta_dict[a]["width"] for a in agent_ids])
    footprints = _agent_footprints(xys, scene_df["heading"].to_numpy(), lengths, widths)
    scene_collided_agent_ids = _collided_agent_ids(scene_df, footprints)

    num_collisions = len(scene_collided_agent_ids)
    num_total_agents = len(scene.agents)
    # Every scene agent is expected to have at least one row in the agent data.
    assert num_total_agents == num_total, (num_total_agents, num_total)

    agent_classes = {meta["type"]: 0 for meta in agent_meta_dict.values()}
    collided_agent_classes = dict(agent_classes)
    for agent_id, meta in agent_meta_dict.items():
        agent_classes[meta["type"]] += 1
        if agent_id in scene_collided_agent_ids:
            collided_agent_classes[meta["type"]] += 1

    return (
        num_collisions,
        num_total_agents,
        num_stationary,
        agent_classes,
        collided_agent_classes,
        offroad_agent_types,
        num_offroad_positions,
        num_positions,
        path_efficiency_hists,
        path_efficiency_bins,
        agent_density_hists,
        agent_density_bins,
        max_acc_hist,
        max_acc_bins,
        num_veh_acc_triggers,
        # TODO: inter-agent distances (e.g. closest distances between agent types)?
    )


def parallel_analyze_collisions_offroad(
    dataset: UnifiedDataset,
    max_workers: int = os.cpu_count(),
) -> Dict:
    """Run ``process_scene`` on every cached scene of ``dataset`` in worker processes.

    Args:
        dataset: dataset whose ``_scene_index`` entries are analyzed, read from
            ``dataset.cache_path``.
        max_workers: number of worker processes for the ``ProcessPoolExecutor``.

    Returns:
        A ``defaultdict(list)`` mapping each statistic name to a list with one entry per
        scene, in ``dataset._scene_index`` order. The ``*_bins`` keys instead hold a
        single array of bin edges and are absent when there are no scenes.

    Raises:
        Any exception raised by ``process_scene`` in a worker is re-raised here.
    """
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        # Remember each future's scene position: results are consumed in completion order
        # (for a live progress bar) but stored in scene order so the output is deterministic.
        future_positions = {}
        for position, scene_path in enumerate(tqdm(dataset._scene_index, desc="Queueing Analysis Tasks")):
            future = executor.submit(process_scene, scene_path, dataset.cache_path)
            future_positions[future] = position

        results = [None] * len(future_positions)
        for future in tqdm(
            as_completed(future_positions),
            desc=f"Running Failure Analyses ({max_workers} CPUs)",
            total=len(future_positions),
        ):
            results[future_positions[future]] = future.result()

    # Output key for each element of process_scene's result tuple, in tuple order.
    result_keys = (
        "num_collisions",
        "num_total_agents",
        "num_stationary",
        "agent_classes",
        "collided_agent_classes",
        "offroad_agent_types",
        "num_offroad_positions",
        "num_positions",
        "path_efficiency_hist",
        "path_efficiency_bins",
        "agent_density_hist",
        "agent_density_bins",
        "max_acc_hist",
        "max_acc_bins",
        "num_veh_acc_triggers",
    )

    result_dict = defaultdict(list)
    for scene_result in results:
        if len(scene_result) != len(result_keys):
            raise ValueError(f"Expected {len(result_keys)} values per scene, got {len(scene_result)}")
        for key, value in zip(result_keys, scene_result):
            if key.endswith("_bins"):
                # Bin edges are kept as a single array (overwritten per scene), not a list.
                result_dict[key] = value
            else:
                result_dict[key].append(value)

    return result_dict
