import logging
import math
import os
from typing import Optional, Dict, Any, Sequence

from pandas import DataFrame
from tableshift.core.utils import make_uid

from rtfm.arguments import DataArguments, make_uid_from_args
from rtfm.datasets import get_task_dataset
from rtfm.datasets.tableshift_utils import (
    fetch_preprocessor_config_from_data_args,
    build_formatted_df,
    get_dataset_info,
)
from rtfm.utils import initialize_dir


def get_cached_uid(data_args: DataArguments) -> str:
    """Build the UID that identifies a cached dataset for ``data_args``.

    This is the single place where cached-dataset UIDs are derived, so that
    code which writes a cache and code which later reads it (e.g. for
    training) stay in agreement when the UID scheme changes.

    The UID is one component of the on-disk cache path, which typically
    looks like ``{CACHE_DIR}/{EXPERIMENT_NAME}/{UID}/{SPLIT}``, for example
    ``tmp/adult/{UID}/train/``.

    Args:
        data_args: the data arguments the UID is derived from.

    Returns:
        The value produced by ``make_uid_from_args(data_args)``.
    """
    cached_uid = make_uid_from_args(data_args)
    return cached_uid


def cache_task(
    task: str,
    overwrite: bool,
    data_arguments: DataArguments,
    splits: Optional[Sequence[str]] = None,
    tabular_dataset_kwargs: Optional[Dict[str, Any]] = None,
):
    """Format the splits of ``task`` and write them to the cache as shards.

    The dataset is first constructed with ``initialize_data=False``. If it
    reports itself as cached and ``overwrite`` is false, a message is logged
    and nothing is written. Otherwise the data is initialized and each
    requested split is formatted via ``build_formatted_df`` and written with
    ``write_shards`` to ``{cache_dir}/{dataset name}/{uid}/{split}``.

    Args:
        task: name of the task to cache.
        overwrite: when true, write the shards even if the dataset reports
            that it is already cached.
        data_arguments: supplies ``cache_dir``, the preprocessor config and
            the formatting options; also determines the cache UID.
        splits: names of the splits to cache; defaults to every key of the
            dataset's ``splits`` mapping.
        tabular_dataset_kwargs: forwarded unchanged to ``get_task_dataset``.

    Returns:
        None.
    """
    cache_root = data_arguments.cache_dir
    config = fetch_preprocessor_config_from_data_args(data_arguments, task)
    dataset = get_task_dataset(
        task,
        cache_dir=data_arguments.cache_dir,
        preprocessor_config=config,
        initialize_data=False,
        tabular_dataset_kwargs=tabular_dataset_kwargs,
    )

    # Guard clause: nothing to do when a cache exists and we may not overwrite.
    if dataset.is_cached() and not overwrite:
        # NOTE(review): this uid comes from make_uid(task, splitter), whereas
        # the write path below uses get_cached_uid(data_arguments) -- confirm
        # the two are meant to differ.
        uid = make_uid(task, dataset.splitter)
        logging.info(f"dataset with uid {uid} is already cached; skipping")
        return

    dataset._initialize_data()
    if splits is None:
        splits = dataset.splits.keys()
    logging.info(f"caching splits {splits}")

    uid = get_cached_uid(data_arguments)
    for split in splits:
        raw_split = dataset._get_split_df(split)
        info = get_dataset_info(dataset)
        formatted = build_formatted_df(raw_split, info, data_arguments)
        out_dir = os.path.join(cache_root, dataset.name, uid, split)
        write_shards(formatted, out_dir, split=split)


def write_shards(
    to_shard: DataFrame,
    dirname: str,
    file_type="parquet",
    rows_per_shard=4096,
    split="train",
):
    """Write ``to_shard`` into ``dirname`` as consecutive row-wise shards.

    Rows are taken positionally in chunks of ``rows_per_shard``; the last
    shard holds whatever remains. Files are named
    ``{split}_{index:05d}.{file_type}``. An empty DataFrame produces no
    files (the directory is still initialized).

    Args:
        to_shard: the DataFrame to split into shards.
        dirname: output directory; passed to ``initialize_dir`` before
            any file is written.
        file_type: one of ``"csv"``, ``"arrow"`` or ``"parquet"``; also used
            as the file extension.
        rows_per_shard: maximum number of rows per shard; must be positive.
        split: split name used as the filename prefix.

    Raises:
        ValueError: if ``file_type`` is not supported or ``rows_per_shard``
            is not positive.
    """
    # Validate before touching the filesystem: previously an unknown
    # file_type (or a negative rows_per_shard) silently wrote nothing.
    supported_file_types = ("csv", "arrow", "parquet")
    if file_type not in supported_file_types:
        raise ValueError(
            f"unsupported file_type {file_type!r}; "
            f"expected one of {supported_file_types}"
        )
    if rows_per_shard <= 0:
        raise ValueError(f"rows_per_shard must be positive; got {rows_per_shard!r}")

    initialize_dir(dirname)
    num_shards = math.ceil(len(to_shard) / rows_per_shard)
    for i in range(num_shards):
        fp = os.path.join(dirname, f"{split}_{i:05d}.{file_type}")
        logging.debug("writing file to %s", fp)
        start, end = i * rows_per_shard, (i + 1) * rows_per_shard
        shard_df = to_shard.iloc[start:end]
        if file_type == "csv":
            shard_df.to_csv(fp, index=False)
        elif file_type == "arrow":
            # Feather output requires a default index, so reset it on the slice.
            shard_df.reset_index(drop=True).to_feather(fp)
        else:  # "parquet" -- guaranteed by the validation above
            shard_df.to_parquet(fp, index=False)
