#!/usr/bin/env python3
# Adapted from https://github.com/facebookresearch/SlowFast/blob/master/slowfast/datasets/kinetics.py

import abc
import math
import os
import random
from copy import deepcopy
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional
import torch.utils.data
import torchvision.io
from ffprobe import FFProbe

from utils import utils


class BaseDataset(torch.utils.data.Dataset, abc.ABC):
    """Base dataset from which all the datasets inherit. Do not instantiate"""

    def __init__(self,
                 split,  # train, validate, test, all
                 num_frames=None,  # number of frames per video
                 video_sampling_interval=None,  # sampling interval of frames in a video
                 jitter_info=None,
                 resolution=None,  # spatial resolution
                 video_augmentations=None,  # information about video augmentations
                 num_classes=0,  # in case of classification
                 path=None,  # dataset path
                 max_videos=None,  # maximum size of dataset
                 max_clips_split=None,  # maximum size of the split. Different than max_videos, because for a
                 # given dataset split (related to max_videos) we can change the
                 # max_clips_split
                 max_subclips=None,  # maximum number of subclips per video
                 use_precomputed_splits=True,  # in case we have a .txt file for every split
                 list_splits=None,
                 dataset_name='',
                 options=None,
                 params=None,
                 info_all_splits=None,
                 restart_samples=False,
                 pca_augmentation=None,  # In case we need to do some augmentation of the data
                 extrapolate_future=False,  # Just for inference time
                 predict_interpolate=False,  # Just for inference time
                 uniform=False,  # Sample uniformly
                 uniform_interpolate=False,  # Sample uniformly
                 seed=42
                 ):
        """
        Store the configuration and resolve which clip ids belong to ``split``.

        When ``info_all_splits`` is None, the clip information is loaded from ``dataset_path / name_info``
        (or rebuilt with ``prepare_samples()`` and saved there if the file is missing or ``restart_samples``
        is true), and the clip ids are partitioned positionally according to ``self.proportions``.
        Otherwise ``info_all_splits`` is expected to be the ``[clip_infos, clip_ids_splits, other_info]``
        list produced by a previous instance and is reused without touching the disk.

        ``path`` must be something ``pathlib.Path`` accepts; leaving it as None raises a TypeError.
        ``video_augmentations`` may be None, in which case random erasing is disabled.
        """

        self.dataset_path = Path(path)
        # Only support train, val, test, and "all" split.
        assert split in ["train", "validate", "test", "all"], f"Split '{split}' not supported."
        self.split = split
        self.vid_aug = video_augmentations
        self.num_frames = num_frames
        # frame sampling rate (interval, in original video frames, between two sampled frames).
        self.video_sampling_interval = video_sampling_interval
        self.jitter_info = jitter_info
        self.resolution = resolution
        self.num_classes = num_classes
        self.max_videos = max_videos
        self.max_subclips = max_subclips  # Sampled from ALL videos
        self.use_precomputed_splits = use_precomputed_splits
        self.min_scale = self.jitter_info.train_jitter_scales[0] if self.jitter_info is not None \
            else self.resolution
        self.dataset_name = dataset_name
        self.options = options
        self.params = params
        self.max_clips_split = max_clips_split
        self.pca_augmentation = pca_augmentation
        self.extrapolate_future = extrapolate_future
        self.predict_interpolate = predict_interpolate
        self.uniform = uniform
        self.uniform_interpolate = uniform_interpolate
        self.seed = seed

        if self.split in ['train', 'validate']:
            # -1 indicates random sampling.
            self.temporal_sample_index = -1
            self.spatial_sample_index = -1
            self.max_scale = self.jitter_info.train_jitter_scales[1] if self.jitter_info is not None \
                else self.resolution

        else:  # test or all
            self.temporal_sample_index = 1  # middle
            self.spatial_sample_index = 1  # middle
            self.max_scale = self.jitter_info.train_jitter_scales[0] if self.jitter_info is not None \
                else self.resolution

        # video_augmentations defaults to None, so guard the attribute access like we do for jitter_info
        self.rand_erase = self.vid_aug is not None and self.vid_aug.re_prob > 0
        self.clip_infos = None
        self.clip_ids = None
        self.list_splits = list_splits  # List of all the splits there are

        # Insertion order matters below: 'train', 'validate', 'all', 'test'
        self.proportions = {'train': 0.8, 'validate': 0.1, 'all': 1.}
        self.proportions['test'] = 1 - self.proportions['train'] - self.proportions['validate']

        if info_all_splits is None:
            path_clip_infos = self.dataset_path / self.name_info
            if restart_samples or not os.path.isfile(path_clip_infos):
                clip_infos, *other_info = self.prepare_samples()
                print('Saving dataset')
                torch.save([clip_infos, *other_info], path_clip_infos)
            else:
                print('Loading dataset')
                # NOTE: torch.load unpickles the file; only load cache files this code wrote itself.
                clip_infos, *other_info = torch.load(path_clip_infos)
                print('Dataset loaded')
            clip_ids = list(clip_infos.keys())
            num_clips = len(clip_ids)
            last_i = 0
            clip_ids_splits = {}
            for split_, proportion in self.proportions.items():
                # 'all' always spans the whole list and does not advance the running offset
                is_all = split_ == 'all'
                start = 0 if is_all else last_i
                end = min(start + int(np.round(proportion * num_clips)), num_clips)
                if not is_all:
                    last_i = end
                # NOTE(review): because each size is rounded independently, a few trailing clips can end up
                # in none of train/validate/test. Left as is to keep existing splits reproducible.
                ids_split = clip_ids[start:end]
                if self.max_clips_split is not None and len(ids_split) > self.max_clips_split:
                    # Mostly for visualization or debugging purposes
                    # Random so that not all clips belong to the same video
                    # (uses the global ``random`` state; ``self.seed`` is not applied here)
                    ids_split = random.sample(ids_split, self.max_clips_split)
                clip_ids_splits[split_] = ids_split
            info_all_splits = [clip_infos, clip_ids_splits, other_info]

        self.info_all_splits = info_all_splits

        self.clip_infos = info_all_splits[0]
        if self.split == 'all':
            # Built from the three disjoint splits (not from the 'all' entry), so per-split caps apply
            self.clip_ids = info_all_splits[1]['train'] + info_all_splits[1]['validate'] + info_all_splits[1]['test']
        else:
            self.clip_ids = info_all_splits[1][split]

        # Subclasses may define max_steps before calling this constructor
        if not hasattr(self, 'max_steps'):
            self.max_steps = None

    @classmethod
    def _extract_segments_video(cls, clip, num_frames, video_sr):
        """
        Split one clip into consecutive sub-clips of ``num_frames * video_sr`` original frames each.

        :param clip: tuple ``(clip_id, path, clip_info)``. ``path`` can be a string or any ``os.PathLike``.
        :param num_frames: number of frames sampled per sub-clip
        :param video_sr: sampling interval (in original frames) between two sampled frames
        :return: ``(new_clip_ids, new_clip_infos)``. For a ``.jpg`` path the clip is returned unchanged.
            For a video, each sub-clip id is ``clip_id + '_{i:03d}'`` and its info is a deep copy of
            ``clip_info`` with 'start' and 'end' (in seconds) added. Returns ``(None, None)`` if the video
            does not run at 24 fps.

        WARNING: a video whose frame rate is not 24 (including files where no video stream is found) is
        printed and DELETED from disk.
        """
        clip_id, path, clip_info = clip
        # Normalize pathlib.Path (or other PathLike) inputs so that the string operations below work
        path = os.fspath(path)

        if path.endswith('.jpg'):  # Static image
            return [clip_id], {clip_id: clip_info}

        metadata = FFProbe(path)
        frame_rate_clip = duration_clip = 0
        # Only the first video stream is considered
        for stream in metadata.streams:
            if stream.is_video():
                duration_clip = stream.duration_seconds()
                frame_rate_clip = stream.framerate
                break

        if frame_rate_clip != 24:
            print(path)
            if os.path.isfile(path):
                os.remove(path)
            return None, None

        new_clip_ids = []
        new_clip_infos = {}
        duration_subclips = num_frames * video_sr / frame_rate_clip
        # The last sub-clip may extend past the end of the video ('end' is not clipped to the duration)
        for i in range(math.ceil(duration_clip / duration_subclips)):
            subclip_info = deepcopy(clip_info)
            subclip_info['start'] = i * duration_subclips
            subclip_info['end'] = (i + 1) * duration_subclips
            subclip_id = clip_id + f'_{i:03d}'

            new_clip_ids.append(subclip_id)
            new_clip_infos[subclip_id] = subclip_info

        return new_clip_ids, new_clip_infos

    @abc.abstractmethod
    def prepare_samples(self, **kwargs):
        """
        Build the clip information of the dataset from scratch. Must be implemented by subclasses.

        The constructor unpacks the result as ``clip_infos, *other_info``, calls ``clip_infos.keys()`` to
        obtain the clip ids, and saves everything with ``torch.save``; so the return value has to be an
        iterable whose first element is a mapping from clip id to clip info.
        """
        raise NotImplementedError("Please Implement this method")

    @abc.abstractmethod
    def visualize_trajectories(self, saved_tensors, save_names, reconstruct_intersection, model_id, vrnn_model):
        """
        Dataset-specific visualization. Must be implemented by subclasses.
        NOTE(review): the meaning of the arguments is not defined in this base class; see the subclasses.
        """
        raise NotImplementedError("Please Implement this method")

    @property
    def name_info(self):
        """
        File name (relative to ``self.dataset_path``) where the constructor saves/loads the clip information
        """
        return 'clip_infos.pth'

    @property
    def save_info_path(self):
        """
        Directory ``<dataset_path>/data_info`` used to store information about the dataset.
        The directory (and any missing parent) is created on access.
        """
        data_info_dir = self.dataset_path / 'data_info'
        data_info_dir.mkdir(parents=True, exist_ok=True)
        return data_info_dir

    @classmethod
    def save_info_path_class(cls, path) -> Path:
        """
        Class-level counterpart of ``save_info_path``.
        Used when not called from self (for example as pre-processing)

        :param path: dataset location (string or path-like); it is first mapped through
            ``cls.get_dataset_path``
        :return: ``Path`` of the ``data_info`` directory, which is created if it does not exist
        """
        # get_dataset_path may return a plain string, for which the "/" operator is not defined
        info_path = Path(cls.get_dataset_path(path)) / 'data_info'
        os.makedirs(info_path, exist_ok=True)
        return info_path

    @classmethod
    def get_dataset_path(cls, path) -> str:
        """
        Map ``path`` to the root directory of the dataset. This base implementation returns ``path``
        unchanged.
        NOTE(review): presumably meant to be overridden by datasets stored in a sub-directory; confirm
        against the subclasses.
        """
        return path

    def prepare_sequence(self, positions, time_indices, temporal_noise=False, noise_seed=None, max_steps=None):
        """
        Move to tensor and zero-pad if necessary
        We pad with -100, to have an easy-to-debug value. But note that -100 can still be an acceptable value, we will
        never filter by -100; we need to use the mask all the time.
        An alternative would be to use NaN, but they are very bad at backpropagating properly.

        If temporal_noise, instead of having integer values for time indices, every time index different
        from -100 is shifted by the same offset ``utils.str_to_probability(noise_seed)``.

        :param positions: sequence (tensor or anything ``torch.tensor`` accepts) with time as first dimension,
            or None
        :param time_indices: time index of every step (tensor or anything ``torch.tensor`` accepts)
        :param temporal_noise: whether to add the offset described above
        :param noise_seed: value forwarded to ``utils.str_to_probability``
        :param max_steps: output length. Defaults to ``self.max_steps``
        :return: ``(positions, time_indices, duration)``, where the first two are padded or truncated to
            ``max_steps`` along the first dimension and ``duration`` is the length before padding/truncating.
            The input tensors are never modified in place.
        """
        assert self.max_steps is not None, 'Only for datasets where we have max_steps'

        if positions is not None and not isinstance(positions, torch.Tensor):
            positions = torch.tensor(positions).float()
        if not isinstance(time_indices, torch.Tensor):
            time_indices = torch.tensor(time_indices).float()
        if positions is not None:
            assert time_indices.shape[0] == positions.shape[0]

        duration = time_indices.shape[0]

        # None of the past/future alone will be as long as self.max_steps, but we give room for inputting whole
        # sequences
        if max_steps is None:
            max_steps = self.max_steps
        if duration < max_steps:
            num_pad = max_steps - duration
            if positions is not None:
                padding_pos = torch.full((num_pad, *positions.shape[1:]), -100., device=positions.device)
                positions = torch.cat([positions, padding_pos])
            padding_time = torch.full((num_pad,), -100., device=time_indices.device)
            time_indices = torch.cat([time_indices, padding_time])
        elif duration > max_steps:
            positions = positions[:max_steps] if positions is not None else None
            time_indices = time_indices[:max_steps]

        if temporal_noise:
            # Make the sequence start a bit later (so relative increments are the same)
            # Work on a copy: without padding, time_indices is still the caller's tensor (or a view of it)
            time_indices = time_indices.clone()
            time_indices[time_indices != -100] += utils.str_to_probability(noise_seed)

        return positions, time_indices, duration

    def get_video(self, path_video, name, clip_start=0, clip_end=None):
        """
        Read a (segment of a) video and return a fixed-length, augmented clip.
        Reused code from
        https://github.com/facebookresearch/pytorchvideo/blob/master/pytorchvideo/data/labeled_video_dataset.py
        :param path_video: path of the video file
        :param name: name of the video (not used here)
        :param clip_start: start of the segment in seconds. None is treated as 0
        :param clip_end: end of the segment in seconds. None or -1 means reading until the end
        :return: tuple ``(video, video_len)``; ``video`` is the output of ``self.augment`` applied to the
            clip padded to ``self.num_frames`` frames, and ``video_len`` is the number of frames before
            padding
        """
        start_sec = 0 if clip_start is None else clip_start
        end_sec = None if clip_end == -1 else clip_end

        # Audio and stream info are read but not needed
        frames, _audio, _info = torchvision.io.read_video(path_video, start_sec, end_sec, pts_unit='sec')

        # ------------ Process video ------------- #

        # Keep one frame out of every self.video_sampling_interval
        frames = frames[::self.video_sampling_interval]

        # Select the window of (at most) self.num_frames frames
        first, last = self.get_start_end_idx(frames.shape[0])
        clip = frames[first:last]
        video_len = len(clip)

        # Zero-pad at the end of the first (time) dimension. The padding spec lists (before, after) pairs
        # starting from the last dimension, so for a 4-dimensional clip only its 8th entry is non-zero
        num_missing = self.num_frames - clip.shape[0]
        clip = torch.nn.functional.pad(clip, (0, 0, 0, 0, 0, 0, 0, num_missing))

        return self.augment(clip), video_len

    def get_start_end_idx(self, size_video):
        """
        Choose the window of ``self.num_frames`` frames to take from a video of ``size_video`` frames.

        For the "train" split the start is drawn uniformly among all valid positions ``0..delta``
        (both included), with ``delta = max(size_video - num_frames, 0)``. For any other split the window
        is centered in the video.

        :return: ``(start_idx, end_idx)`` with ``end_idx = start_idx + self.num_frames``. ``end_idx`` can
            exceed ``size_video`` when the video is shorter than ``num_frames``.
        """
        delta = max(size_video - self.num_frames, 0)
        if self.split == "train":
            # randint includes both ends; truncating a uniform float would (almost) never select `delta`,
            # i.e. the window ending at the last frame
            start_idx = random.randint(0, delta)
        else:  # Sample middle of the clip
            start_idx = delta // 2
        end_idx = start_idx + self.num_frames
        return start_idx, end_idx

    def __getitem__(self, index):
        """
        Return the sample at position ``index``. Must be implemented by subclasses.

        :raises NotImplementedError: always, in this base class
        """
        # `NotImplemented` is a sentinel for binary operators; returning it here would hand a bogus sample to
        # the data loader instead of failing
        raise NotImplementedError("Please Implement this method")

    def __len__(self):
        """Number of clip ids in the split selected at construction time"""
        return len(self.clip_ids)

    def augment_trajectories(self, original, j):
        """
        Return the ``j``-th augmented version of ``original``. To be implemented by subclasses that use
        ``pca_augmentation``: ``collate_fn`` calls it for the 'past', 'future' and 'all' entries of each
        sample.
        """
        raise NotImplementedError("Please Implement this method")

    def augment_time_inputs(self, original, j, total_j):
        """
        Return the ``j``-th (out of ``total_j``) augmented version of the time indices ``original``. To be
        implemented by subclasses that use ``pca_augmentation``: ``collate_fn`` calls it for the
        'time_indices_past', 'time_indices_future' and 'time_indices_all' entries of each sample.
        """
        raise NotImplementedError("Please Implement this method")

    def collate_fn(self, batch):
        """
        Collate a list of sample dictionaries with the default PyTorch collate function.

        If ``self.pca_augmentation`` is set ('speed', 'temporal_offset' or 'spatial_flip'), the batch is
        first augmented with similar segments for the MDS experiment: each sample is replaced by 2
        ('spatial_flip') or 10 (otherwise) shallow copies whose trajectory and time entries, when present,
        are augmented, and whose 'index' becomes the pair (original index, repetition number).
        """
        default_collate = torch.utils.data._utils.collate.default_collate
        if self.pca_augmentation is None:
            return default_collate(batch)

        assert self.pca_augmentation in ['speed', 'temporal_offset', 'spatial_flip']
        trajectory_keys = ('past', 'future', 'all')
        time_keys = ('time_indices_past', 'time_indices_future', 'time_indices_all')
        repetitions = 10
        if self.pca_augmentation == 'spatial_flip':
            repetitions = 2

        augmented_batch = []
        for sample in batch:
            for rep in range(repetitions):
                variant = sample.copy()
                for key in trajectory_keys:
                    if key in sample:
                        variant[key] = self.augment_trajectories(sample[key], rep)
                for key in time_keys:
                    if key in sample:
                        variant[key] = self.augment_time_inputs(sample[key], rep, repetitions)
                original_index = torch.tensor(sample['index'])
                variant['index'] = torch.stack([original_index, torch.tensor(rep)])
                augmented_batch.append(variant)
        return default_collate(augmented_batch)
