from typing import Optional

import numpy as np
import torch
from torch_geometric.transforms import BaseTransform

from temporal_graph.data import TemporalData


class TemporalSplit(BaseTransform):
    r"""Split the events of a temporal graph chronologically into
    train / validation / test sets.

    Two cut-off times are taken as quantiles of the timestamps stored in
    ``data[key]``: ``val_time`` at ``1 - val_ratio - test_ratio`` and
    ``test_time`` at ``1 - test_ratio`` (computed with :func:`numpy.quantile`
    and its default interpolation). Each event is then assigned by comparing
    its timestamp against these cut-offs:

    * ``train_mask``: ``t < val_time``
    * ``val_mask``: ``val_time <= t < test_time``
    * ``test_mask``: ``t >= test_time``

    Because the split is made on time rather than on event position, events
    sharing a timestamp always land in the same split. The resulting split
    sizes therefore only approximate the requested ratios when timestamps
    repeat. A ``val_ratio`` or ``test_ratio`` of ``0`` yields an empty split.

    Args:
        val_ratio (float, optional): Fraction of events (by time quantile)
            used for validation. (default: :obj:`0.15`)
        test_ratio (float, optional): Fraction of events (by time quantile)
            used for testing. (default: :obj:`0.15`)
        key (str, optional): Name of the attribute of ``data`` holding the
            event timestamps. (default: :obj:`"t"`)

    Raises:
        ValueError: If either ratio is negative or NaN, or if
            ``val_ratio + test_ratio`` exceeds ``1``.
    """

    def __init__(
        self,
        val_ratio: float = 0.15,
        test_ratio: float = 0.15,
        key: Optional[str] = "t",
    ):
        # Written as `not (x >= 0)` so that NaN is rejected as well.
        if not (val_ratio >= 0. and test_ratio >= 0.):
            raise ValueError(
                f"'val_ratio' and 'test_ratio' must be non-negative "
                f"(got {val_ratio} and {test_ratio})")
        if not val_ratio + test_ratio <= 1.:
            raise ValueError(
                f"'val_ratio + test_ratio' must not exceed 1 "
                f"(got {val_ratio} + {test_ratio})")
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio
        self.key = key

    def forward(self, data: TemporalData) -> TemporalData:
        """Attach ``train_mask``, ``val_mask`` and ``test_mask`` to ``data``.

        Args:
            data: Temporal graph whose ``data[self.key]`` holds the event
                timestamps (assumed to be a numeric tensor — TODO confirm
                for non-default keys).

        Returns:
            The same ``data`` object with the three boolean masks set.
        """
        t = data[self.key]
        # Clamp so floating-point rounding of `1 - a - b` cannot go below 0
        # when val_ratio + test_ratio == 1.
        val_q = max(1. - self.val_ratio - self.test_ratio, 0.)
        test_q = 1. - self.test_ratio

        if t.numel() == 0:
            # np.quantile fails on an empty array. With no events, any
            # threshold produces empty masks.
            val_time = test_time = np.inf
        else:
            # np.quantile does not require sorted input, so no sort is done.
            val_time, test_time = np.quantile(
                t.detach().cpu().numpy(), [val_q, test_q])

        # quantile(1.0) is max(t), and `t >= max(t)` would still select the
        # latest events. Use +inf so a zero-sized upper split is truly empty.
        if test_q >= 1.:
            test_time = np.inf
        if val_q >= 1.:
            val_time = np.inf

        data.train_mask = t < val_time
        data.val_mask = (t >= val_time) & (t < test_time)
        data.test_mask = t >= test_time
        return data

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.val_ratio}, '
                f'{self.test_ratio})')
