"""Hierarchical clustering dataset."""

import logging

import numpy as np
import torch
import torch.utils.data as data

from datasets.triples import generate_triples_all_pairs


def generate_triples_all_pairs(n_nodes, num_samples):
    """Sample ``(i, j, k)`` node triples built from distinct pairs ``i < j``.

    Every unordered pair of distinct nodes is enumerated once, in row-major
    order: (0, 1), (0, 2), ..., (1, 2), ...

    - If fewer samples than pairs are requested, a random subset of pairs is
      drawn without replacement.
    - Otherwise the full pair list is repeated ``num_samples // n_pairs``
      times, and the remainder is filled with a random subset drawn without
      replacement.

    Each selected pair is then completed with a third node drawn uniformly
    from ``[0, n_nodes)``. That third node may coincide with ``i`` or ``j``.
    Randomness comes from NumPy's global ``np.random`` state.

    NOTE(review): this definition shadows the ``generate_triples_all_pairs``
    imported from ``datasets.triples`` at the top of this module.

    @param n_nodes: number of nodes; pairs are drawn from ``range(n_nodes)``
    @param num_samples: number of triples to generate (converted with ``int``)
    @return: integer np.array of shape (int(num_samples), 3), or shape (0, 3)
        when ``n_nodes < 2`` and no pair exists
    """
    num_samples = int(num_samples)
    # Upper-triangular indices (k=1 excludes the diagonal) enumerate every
    # pair i < j in row-major order without building an n x n grid.
    rows, cols = np.triu_indices(n_nodes, k=1)
    pairs = np.stack([rows, cols], axis=1)
    n_pairs = pairs.shape[0]
    if n_pairs == 0:
        # Fewer than two nodes: nothing to sample from. Without this guard the
        # superset branch below would divide by zero.
        logging.warning("No node pairs available for n_nodes=%d", n_nodes)
        return np.empty((0, 3), dtype=pairs.dtype)

    if num_samples < n_pairs:
        logging.info("Generating all pairs subset")
        subset = np.random.choice(np.arange(n_pairs), num_samples, replace=False)
        pairs = pairs[subset]
    else:
        logging.info("Generating all pairs superset")
        # Integer arithmetic: exact even for very large sample counts.
        k_base, k_rem = divmod(num_samples, n_pairs)
        subset = np.random.choice(np.arange(n_pairs), k_rem, replace=False)
        # Full copies of the pair list first, then the random remainder.
        pairs = np.concatenate([np.tile(pairs, (k_base, 1)), pairs[subset]], axis=0)

    num_samples = pairs.shape[0]
    third = np.random.randint(n_nodes, size=(num_samples, 1))
    return np.concatenate([pairs, third], axis=1)


class HCDataset(data.Dataset):
    """Hierarchical clustering dataset.

    Each item is a node triple ``(i, j, k)`` paired with the three pairwise
    similarities ``[s_ij, s_ik, s_jk]``.

    Triples are not created by the constructor. Call :meth:`generate_triples`
    before taking ``len()`` of the dataset or indexing it; otherwise
    ``self.triples`` does not exist and an ``AttributeError`` is raised.
    """

    def __init__(self, features, labels, similarities):
        """Store the datapoints and their similarity matrix.

        @param features: datapoint features; stored as ``self.features`` but
            not read by any method of this class
        @param labels: ground truth labels
        @type labels: np.array of shape (n_datapoints,)
        @param similarities: pairwise similarities between datapoints
        @type similarities: np.array of shape (n_datapoints, n_datapoints)
        """
        self.features = features
        self.labels = labels
        self.similarities = similarities
        # Number of nodes is taken from the similarity matrix's first axis.
        self.n_nodes = self.similarities.shape[0]

    def __len__(self):
        """Return the number of generated triples."""
        return len(self.triples)

    def __getitem__(self, idx):
        """Return ``(triple, [s_ij, s_ik, s_jk])`` as a pair of tensors."""
        triple = self.triples[idx]
        i, j, k = triple
        sim = self.similarities
        similarities = np.array([sim[i, j], sim[i, k], sim[j, k]])
        return torch.from_numpy(triple), torch.from_numpy(similarities)

    def generate_triples(self, num_samples):
        """Sample training triples and store them as ``self.triples`` (int64).

        @param num_samples: number of triples requested; forwarded to
            ``generate_triples_all_pairs``
        """
        logging.info("Generating triples.")
        triples = generate_triples_all_pairs(self.n_nodes, num_samples=num_samples)
        logging.info("Total of %d triples", triples.shape[0])
        # int64 so that torch.from_numpy yields LongTensors usable as indices.
        self.triples = triples.astype("int64")

