import numpy as np


def create_subsamples(source='5m.npz',
                      subsets=((5000, '50k'), (10000, '100k'), (20000, '200k'),
                               (50000, '500k'), (100000, '1m')),
                      num_classes=10):
    """ Subsample to avoid bias by image selection.

    Reads the arrays stored under the keys 'image' and 'label' in `source`,
    splits them along the first axis into `num_classes` consecutive blocks of
    equal length and, for every `(num_samples, file_name)` pair in `subsets`,
    writes the first `num_samples` entries of each block to
    '<file_name>.npz' (again under the keys 'image' and 'label').

    NOTE(review): this presumably relies on the source being sorted by class
    with the same number of entries per class; that is not verified here.

    Args:
        source: Path of the .npz archive to subsample.
        subsets: Iterable of `(num_samples, file_name)` pairs, where
            `num_samples` is the number of entries taken per block and
            `file_name` is the output path without the '.npz' suffix.
        num_classes: Number of blocks the source is split into
            (10 for CIFAR-10).

    Raises:
        ValueError: If `num_samples` is negative or larger than one block,
            since the selection would otherwise spill into the next block.
    """
    # Close the archive as soon as both arrays have been read into memory.
    with np.load(source) as data:
        images, labels = data['image'], data['label']

    total = labels.shape[0]
    step = total // num_classes  # length of one per-class block
    for num_samples, file_name in subsets:
        if not 0 <= num_samples <= step:
            raise ValueError(
                'num_samples={} does not fit into a block of {} entries'.format(
                    num_samples, step))
        indices = np.full(total, False)
        for i in range(num_classes):
            indices[i * step:i * step + num_samples] = True
        sel_images, sel_labels = images[indices], labels[indices]
        # Report the shapes of the subsample that is about to be written.
        print(sel_images.shape, sel_labels.shape)
        np.savez('{}.npz'.format(file_name), image=sel_images, label=sel_labels)


def check_subsamples(subsets=((5000, '50k'), (10000, '100k'), (20000, '200k'),
                              (50000, '500k'), (100000, '1m')),
                     num_classes=10):
    """ Verify that certain assumptions remain valid.

    For every `(num_samples, file_name)` pair in `subsets`, loads
    '<file_name>.npz' and checks that
      * 'image' has shape `(num_classes * num_samples, 32, 32, 3)`,
      * 'label' has shape `(num_classes * num_samples,)`,
      * 'label' consists of `num_samples` zeros, then `num_samples` ones,
        and so on up to `num_classes - 1`.

    Args:
        subsets: Iterable of `(num_samples, file_name)` pairs, where
            `file_name` is the path of the archive without '.npz'.
        num_classes: Number of classes expected in each file
            (10 for CIFAR-10).

    Raises:
        AssertionError: If any of the checks fails. It is raised explicitly
            rather than through `assert` so the checks also run under
            `python -O`.
    """
    for num_samples, file_name in subsets:
        path = '{}.npz'.format(file_name)
        # Close the archive once both arrays have been read into memory.
        with np.load(path) as data:
            images, labels = data['image'], data['label']

        total = num_classes * num_samples
        if images.shape != (total, 32, 32, 3):
            raise AssertionError('{}: unexpected image shape {}'.format(
                path, images.shape))
        if labels.shape != (total,):
            raise AssertionError('{}: unexpected label shape {}'.format(
                path, labels.shape))

        # Expected layout: each class index repeated num_samples times, in order.
        reference = np.repeat(np.arange(num_classes), num_samples)
        if not np.all(labels == reference):
            raise AssertionError(
                '{}: labels are not grouped by class in ascending order'.format(
                    path))


def main():
    """ Write the subsample files, then validate what was written. """
    # Order matters: the check reads the files produced by the first stage.
    for stage in (create_subsamples, check_subsamples):
        stage()


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

