import chunk
import os
import time
from typing import Dict, List, Tuple, Union

import h5py
import numpy as np
import pandas as pd
import torch
import utils
from data.generate_speedy import PARAM_SPACES
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset
from tqdm import tqdm


############################## wind simulation ##############################
class SpeedyWeatherDataset(Dataset):
    def __init__(
        self,
        data_path: str,
        model_name: str = "ShallowWaterModel",
        num_simulations: int = 100,
        mode: str = "train",
        num_views=5,
        include_keys=["u", "v", "vor"],
        shared_ids: List[Union[int, List]] = None,
        factor_sharing: Dict[int, List[int]] = None,
        collate_style: str = "default",
        chunk_size=60,
    ) -> None:
        super().__init__()
        self.data_dir = os.path.join(data_path, model_name)
        os.makedirs(self.data_dir, exist_ok=True)
        self.mode = mode
        assert mode in ["train", "val", "test"]
        # assert data_generation_scheme in ["fixed_graph", "discrete", "continuous"]

        assert num_views > 1, "number of views must be greater than 1"
        self.num_views = num_views

        self.simulation_ids = range(num_simulations)
        self.include_keys = include_keys  # , "vor"]

        self.num_simulations = num_simulations
        self.initialized = False
        # self.data_generation_scheme = data_generation_scheme
        self.param_spaces = {kk: vv for k, v in PARAM_SPACES.items() for kk, vv in v.items()}
        # initialize the shared indices (the underlying causal graph)
        # sharing_ids are defined w.r.t. the first view
        self.collate_stype = collate_style
        assert collate_style in ["default", "random"]

        if collate_style == "default":
            self.shared_ids = shared_ids
            self.factor_sharing = factor_sharing
            assert shared_ids is not None, "shared_ids for the dataset must be provided"
            assert factor_sharing is not None, "factor_sharing for the dataset must be provided"
            self.collate_fn = self.default_collate_fn

        if collate_style == "random":
            assert self.num_views == 2, "random_collate_fn is only implemented for 2 views"
            self.collate_fn = self.random_collate_fn

        if "ShallowWaterModel" in model_name:
            self.param_spaces = {kk: vv for kk, vv in self.param_spaces.items() if kk == "layer_thickness"}
        if "PrimitiveWetModel" in model_name:
            self.param_spaces = {kk: vv for kk, vv in self.param_spaces.items() if kk != "layer_thickness"}

        self.chunk_size = chunk_size

        self.__preprocess_data__()

    def __preprocess_data__(self):
        for i in tqdm(self.simulation_ids):
            file_path = os.path.join(self.data_dir, f"run_{i+1:04d}/output.nc")
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"File {file_path} does not exist.")
            ds = h5py.File(file_path, "r")

            if not self.initialized:
                self.lat_dim = np.asarray(ds["lat"][:]).shape[0]
                self.lon_dim = np.asarray(ds["lon"][:]).shape[0]
                self.level_dim = np.asarray(ds["lev"][:]).shape[0]
                self.time_steps = np.asarray(ds["time"][:]).shape[0]
                self.data = {k: np.asarray(ds[k][:]) for k in ds.keys()}

                for k in self.include_keys:
                    self.data[k] = [self.data[k]]
                # voriticity: shape [time, level, lat, lon]
            else:
                for k in self.include_keys:
                    # stack the states
                    self.data[k] += [np.asarray(ds[k][:])]  # shape [time, level, lat, lon]
            self.initialized = True

        for k in self.include_keys:
            self.data[k] = np.stack(self.data[k], axis=0)
            self.data[k] = (
                StandardScaler()
                .fit_transform(self.data[k].reshape(-1, 1))
                .reshape(self.data[k].shape)
                .transpose(1, 0, 2, 3, 4)
            )  # [time, num_sim, level, lat, lon]

        self.time_steps = self.data[self.include_keys[0]].shape[0]
        self.time_indices = np.arange(self.time_steps - self.chunk_size, self.time_steps)  # last chunk unseen

    def __len__(self):
        return self.num_simulations * self.level_dim * self.lat_dim * self.lon_dim  # num_sim*lag*lon*level

    def __sample_location__(self):
        sampled_lev = np.random.randint(self.level_dim)
        sampled_lat = np.random.randint(self.lat_dim)
        sampled_lon = np.random.randint(self.lon_dim)
        return (sampled_lev, sampled_lat, sampled_lon)

    def __retrieve_item__(self, simulation_index: int, location: Tuple[int]):
        traj = np.stack(
            [self.data[k][:, simulation_index, *location] for k in self.include_keys],
            axis=-1,
        )

        param_file = os.path.join(self.data_dir, f"run_{self.simulation_ids[0]+simulation_index+1:04d}/parameters.txt")
        param_dict = utils.parse_text_to_nested_dict(param_file)
        param = {k: v.scale(float(param_dict[k])) for k, v in self.param_spaces.items()}
        return param, traj

    def __getitem__(self, index: int):
        multi_dim_ind = np.unravel_index(index, (self.num_simulations, self.level_dim, self.lat_dim, self.lon_dim))
        simulation_index = multi_dim_ind[0]
        params, trajectory = self.__retrieve_item__(simulation_index, multi_dim_ind[1:])
        if self.mode != "val":
            if self.chunk_size < self.time_steps and self.mode == "train":
                time_index = np.random.randint(self.time_steps - 2 * self.chunk_size)
            else:
                time_index = 0
            return {
                "index": simulation_index,
                "time_index": time_index,
                "location": multi_dim_ind[1:],
                "gt_params": params,  # n_views, batch_size, 4
                "states": trajectory[time_index : time_index + self.chunk_size],
            }
        else:
            return {
                "index": simulation_index,
                "time_index": self.time_indices[0],
                "location": multi_dim_ind[1:],
                "gt_params": params,  # n_views, batch_size, 4
                "states": trajectory[self.time_indices],
            }  # index of the location (i,j) and (i',j'

    def __get_augmented_view__(self, args, **kwargs):
        raise NotImplementedError

    def default_collate_fn(self, batch: List[Dict]):
        simulation_ids = [[] for _ in range(self.num_views)]
        locations = [[] for _ in range(self.num_views)]
        states = [[] for _ in range(self.num_views)]
        params = {k: [[] for _ in range(self.num_views)] for k in self.param_spaces.keys()}

        for b in batch:
            simulation_index, location, _, _ = b["index"], b["location"], b["gt_params"], b["states"]
            simulation_ids[0] += [simulation_index]
            locations[0] += [location]
            states[0] += [b["states"]]
            for k in params:
                params[k][0] += [b["gt_params"][k]]

            for i, shared_indices in enumerate(self.shared_ids):
                sampled_simulation_index = self.__get_augmented_view__(simulation_index, shared_indices, b["gt_params"])
                sampled_location = self.__sample_location__()  #
                simulation_ids[i + 1] += [sampled_simulation_index]
                locations[i + 1] += [sampled_location]
                sampled_params, sampled_traj = self.__retrieve_item__(sampled_simulation_index, sampled_location)
                states[i + 1] += [sampled_traj[b["time_index"] : b["time_index"] + self.chunk_size]]

                for k in params:
                    params[k][i + 1] += [sampled_params[k]]

        for k in params:
            params[k] = np.stack(params[k])

        batch_dict = {
            "shared_index": self.shared_ids,
            "index": simulation_ids,  # n_views, batch_size # simulation index
            "location": locations,  # n_views, batch_size, 3
            "gt_params": params,  # n_views, batch_size, 3
            "states": torch.from_numpy(np.stack(states)),  # n_views, batch_size, time, 3
        }
        return batch_dict

    def random_collate_fn(self, batch: List[Dict]):
        assert self.num_views == 2, "random_collate_fn is only implemented for 2 views"
        simulation_ids = [[] for _ in range(self.num_views)]
        locations = [[] for _ in range(self.num_views)]
        states = [[] for _ in range(self.num_views)]
        params = {k: [[] for _ in range(self.num_views)] for k in self.param_spaces.keys()}

        shared_ids = [np.random.choice(len(self.grid_size), size=len(self.grid_size) - 1, replace=False)]

        for b in batch:
            simulation_index, location, _, _ = b["index"], b["location"], b["gt_params"], b["states"]
            simulation_ids[0] += [simulation_index]
            locations[0] += [location]
            states[0] += [b["states"]]
            for k in params:
                params[k][0] += [b["gt_params"][k]]

            for i, shared_indices in enumerate(shared_ids):
                sampled_simulation_index = self.__get_augmented_view__(simulation_index, shared_indices, b["gt_params"])
                sampled_location = location  # self.__sample_location__()
                simulation_ids[i + 1] += [sampled_simulation_index]
                locations[i + 1] += [sampled_location]
                sampled_params, sampled_traj = self.__retrieve_item__(sampled_simulation_index, sampled_location)
                states[i + 1] += [sampled_traj]

                for k in params:
                    params[k][i + 1] += [sampled_params[k]]

        for k in params:
            params[k] = np.stack(params[k])

        batch_dict = {
            "shared_index": shared_ids[0],
            "index": simulation_ids,  # n_views, batch_size # simulation index
            "location": locations,  # n_views, batch_size, 3
            "gt_params": params,  # n_views, batch_size, 3
            "states": torch.from_numpy(np.stack(states)),  # n_views, batch_size, time, 3
        }
        return batch_dict

    def sample_augmented_trajectory(self, simulation_index, location, aug_location=True, aug_simulation=True):
        assert aug_location or aug_simulation, "at least choose one augmentation method."
        if aug_simulation:
            if simulation_index % 2 == 0:
                sampled_simulation_index = simulation_index + 1
            else:
                sampled_simulation_index = simulation_index - 1
        else:
            sampled_simulation_index = simulation_index

        if aug_location:
            sampled_location = self.__sample_location__()
        else:
            sampled_location = location
        aug_traj = np.stack(
            (
                self.data["u"][:, sampled_simulation_index, *sampled_location],
                self.data["v"][:, sampled_simulation_index, *sampled_location],
                self.data["vor"][:, sampled_simulation_index, *sampled_location],
            ),
            axis=-1,
        )

        aug_param_file = os.path.join(
            self.data_dir, f"run_{self.simulation_ids[0]+simulation_index+1:04d}/parameters.txt"
        )
        aug_param_dict = utils.parse_text_to_nested_dict(aug_param_file)
        aug_param = {k: float(aug_param_dict[k]) for k, v in self.param_spaces.items()}
        return sampled_simulation_index, sampled_location, aug_param, aug_traj

    def collate_fn(self, batch):
        simulation_ids = [[], [], [], []]
        locations = [[], [], [], []]
        states = [[], [], [], []]
        params = [[], [], [], []]

        args = [(True, False), (False, True), (True, True)]

        # batch: list of dictionary
        for b in batch:
            simulation_index, location, _, _ = b["simulation_index"], b["location"], b["params"], b["states"]

            simulation_ids[0] += [simulation_index]
            locations[0] += [location]
            states[0] += [b["states"]]
            params[0] += [b["params"]]
            for i, settings in enumerate(args):
                sampled_simulation_index, sampled_location, aug_param, aug_traj = self.sample_augmented_trajectory(
                    simulation_index, location, *settings
                )
                simulation_ids[i + 1] += [sampled_simulation_index]
                locations[i + 1] += [sampled_location]
                states[i + 1] += [aug_traj]
                params[i + 1] += [aug_param]

        batch_dict = {
            "simulation_index": simulation_ids,  # n_views, batch_size
            "location": locations,  # n_views, batch_size, 3
            "params": params,  # n_views, batch_size, 3
            "states": torch.from_numpy(np.stack(states)),  # n_views, batch_size, time, 3
        }

        return batch_dict


class SpeedyWeatherDiscreteDataset(SpeedyWeatherDataset):
    def __init__(
        self,
        data_path: str,
        model_name: str = "ShallowWaterModel/discrete_small",
        num_simulations: int = 16,
        mode: str = "train",
        num_views=5,
        include_keys=["u", "v", "vor"],
        shared_ids: List[Union[int, List]] = None,
        factor_sharing: Dict[int, List[int]] = None,
        collate_style: str = "default",
        grid_size: List[int] = [2] * 4,
        chunk_size=30,
    ) -> None:
        super().__init__(
            data_path=data_path,
            model_name=model_name,
            num_simulations=num_simulations,
            mode=mode,
            num_views=num_views,
            include_keys=include_keys,
            shared_ids=shared_ids,
            factor_sharing=factor_sharing,
            collate_style=collate_style,
            chunk_size=chunk_size,
        )
        self.grid_size = grid_size  # defined while generating code
        assert len(grid_size) == len(self.param_spaces)
        param_samples = {
            k: np.linspace(v.min_, v.max_, self.grid_size[i]) for i, (k, v) in enumerate(self.param_spaces.items())
        }
        PARAM_GRID = np.stack(np.meshgrid(*[list(v) for v in param_samples.values()], indexing="ij"), axis=-1).reshape(
            -1, len(param_samples)
        )
        param_sample_values = list(param_samples.values())
        self.params = np.stack(
            [np.searchsorted(param_sample_values[i], PARAM_GRID[:, i]) for i in range(PARAM_GRID.shape[-1])], -1
        )

    def __retrieve_item__(self, simulation_index: int, location: Tuple[int]):
        traj = np.stack(
            [self.data[k][:, simulation_index, *location] for k in self.include_keys],
            axis=-1,
        )
        param = {k: self.params[simulation_index][i] for i, k in enumerate(self.param_spaces.keys())}
        return param, traj

    def __get_augmented_view__(self, simulation_index, shared_indices, *args, **kwargs):
        multi_dim_index = np.unravel_index(simulation_index, self.grid_size)
        aug_multi_dim_index = [
            np.random.choice(np.delete(np.arange(self.grid_size[i]), multi_dim_index[i]))
            if i not in shared_indices
            else multi_dim_index[i]
            for i in range(len(self.grid_size))
        ]
        return np.ravel_multi_index(aug_multi_dim_index, self.grid_size)





class ShallowWaterDiscreteDataset(SpeedyWeatherDiscreteDataset):
    """Discrete dataset with a fixed three-view collate function.

    The constructor forwards all arguments unchanged to the parent class; only
    ``default_collate_fn`` differs.
    """

    def __init__(
        self,
        data_path: str ,
        model_name: str = "ShallowWaterModel/discrete_small",
        num_simulations: int = 16,
        mode: str = "train",
        num_views=5,
        include_keys=["u", "v", "vor"],
        shared_ids: List[Union[int, List]] = None,
        factor_sharing: Dict[int, List[int]] = None,
        collate_style: str = "default",
        grid_size: List[int] = [2] * 4,
        chunk_size=30,
    ) -> None:
        super().__init__(
            data_path=data_path,
            model_name=model_name,
            num_simulations=num_simulations,
            mode=mode,
            num_views=num_views,
            include_keys=include_keys,
            shared_ids=shared_ids,
            factor_sharing=factor_sharing,
            collate_style=collate_style,
            grid_size=grid_size,
            chunk_size=chunk_size,
        )

    def default_collate_fn(self, batch: List[Dict]):
        """Collate items into exactly three views.

        * view 0: the item itself;
        * view 1: the same simulation at a freshly sampled location;
        * view 2: a different simulation (drawn from the first grid axis,
          excluding the item's own index) at the item's location.

        Views 1 and 2 are cut to the item's time window.

        Returns:
            Dict with ``"time_indices"`` (one start index per item),
            ``"shared_index"``, per-view lists ``"index"`` and ``"location"``,
            ``"gt_params"`` (parameter name -> array stacked over views and
            batch) and ``"states"`` (tensor stacked as
            views x batch x time x variables).
        """
        assert self.num_views == 3, "default_collate_fn here is only implemented for 3 views"
        n_views = self.num_views
        view_sim_ids = [[] for _ in range(n_views)]
        view_locations = [[] for _ in range(n_views)]
        view_states = [[] for _ in range(n_views)]
        view_params = {name: [[] for _ in range(n_views)] for name in self.param_spaces.keys()}
        window_starts = []

        for item in batch:
            start = item["time_index"]
            window_starts.append(start)

            anchor_sim = item["index"]
            anchor_loc = item["location"]
            anchor_params = item["gt_params"]
            view_sim_ids[0].append(anchor_sim)
            view_locations[0].append(anchor_loc)
            view_states[0].append(item["states"])
            for name in view_params:
                view_params[name][0].append(anchor_params[name])

            for view in range(1, n_views):
                if view == 1:
                    # Same simulation (same layer thickness), another location.
                    other_sim = anchor_sim
                    other_loc = self.__sample_location__()
                elif view == 2:
                    # Another simulation, same location.
                    candidates = np.delete(np.arange(self.grid_size[0]), anchor_sim)
                    other_sim = np.random.choice(candidates)
                    other_loc = anchor_loc

                view_sim_ids[view].append(other_sim)
                view_locations[view].append(other_loc)
                other_params, other_traj = self.__retrieve_item__(other_sim, other_loc)
                view_states[view].append(other_traj[start : start + self.chunk_size])

                for name in view_params:
                    view_params[name][view].append(other_params[name])

        for name in view_params:
            view_params[name] = np.stack(view_params[name])

        return {
            "time_indices": window_starts,
            "shared_index": self.shared_ids,
            "index": view_sim_ids,  # n_views, batch_size (simulation index)
            "location": view_locations,  # n_views, batch_size, 3
            "gt_params": view_params,  # per parameter: n_views, batch_size
            "states": torch.from_numpy(np.stack(view_states)),  # n_views, batch_size, time, n_keys
        }
    
    


