import numpy as np
import torch

class random_inputs_sampler:
    """Draw random dataset inputs whose labels match requested targets.

    At construction the dataset's ``targets`` are scanned once and, for every
    label ``c`` in ``range(max(targets) + 1)``, the positions of the samples
    carrying that label are stored in ``class_idx[c]``.

    Attributes:
        dataset: The wrapped dataset. It must expose ``targets`` (a sequence
            of integer labels) and support ``dataset[i]``, whose first
            element is used as the input.
        class_idx: Dict mapping each label to a NumPy array of the dataset
            indices that carry that label (possibly empty when a label in
            the range never occurs).
    """

    def __init__(self, dataset):
        """Index ``dataset`` by label.

        Args:
            dataset: Object with a ``targets`` attribute and integer indexing.
        """
        self.dataset = dataset
        # Convert the labels a single time instead of once per class.
        labels = np.asarray(dataset.targets)
        # An empty dataset simply yields an empty index table.
        num_class = int(labels.max()) + 1 if labels.size else 0
        self.class_idx = {
            c: np.where(labels == c)[0] for c in range(num_class)
        }

    def sample(self, targets):
        """Return one randomly chosen input per requested target label.

        Args:
            targets: Iterable of labels. Elements may be 0-d tensors / NumPy
                scalars (anything with ``.item()``) or plain Python ints.

        Returns:
            A tensor stacking the sampled inputs along a new leading
            dimension, in the same order as ``targets``. Assumes
            ``dataset[i][0]`` is a tensor of a fixed shape.

        Raises:
            KeyError: If a label is outside the range seen at construction.
            ValueError: If a label has no samples in the dataset.
            RuntimeError: If ``targets`` is empty (nothing to stack).
        """
        picked = []
        for t in targets:
            # Tensor / NumPy scalars are unwrapped; plain ints pass through.
            label = t.item() if hasattr(t, "item") else t
            candidates = self.class_idx[label]
            if len(candidates) == 0:
                raise ValueError(
                    f"no samples with label {label!r} in the dataset"
                )
            # Same call as before so seeded runs consume the global NumPy
            # random stream identically.
            rand_idx = np.random.choice(candidates, 1, replace=False).item()
            picked.append(self.dataset[rand_idx][0])
        return torch.stack(picked, dim=0)