import numpy as np
import torch
import torch.utils.data as Data
import torchvision.transforms as transforms


class simple_dataset(Data.Dataset):
    """Map-style dataset that pairs ``X[index]`` with ``Y[index]``.

    Args:
        X: Indexable collection of inputs; its first dimension is the
            number of samples.
        Y: Indexable collection of targets, indexed in lockstep with ``X``.
        transform: Optional callable applied to each input sample when it
            is fetched. Targets are never transformed.
    """

    def __init__(self, X: torch.Tensor, Y: torch.Tensor, transform=None):
        self.X, self.Y, self.transform = X, Y, transform

    def __getitem__(self, index: int):
        sample = self.X[index]
        # Transform the input before touching the target, mirroring the
        # original evaluation order (matters if the transform is stateful).
        sample = sample if self.transform is None else self.transform(sample)
        target = self.Y[index]
        return sample, target

    def __len__(self):
        # Length is driven by the inputs' leading dimension.
        n_samples = self.X.shape[0]
        return n_samples


def tinyimagenet_dataset(data_root='./data'):
    """Build train/validation datasets from ``<data_root>/tiny200.npz``.

    The archive must contain the arrays ``trainX``, ``trainY``, ``valX`` and
    ``valY``. Image arrays are permuted with ``permute(0, 3, 1, 2)``, which
    moves the last axis to position 1. They are presumably stored channel-last
    ``(N, H, W, C)`` and become channel-first ``(N, C, H, W)`` (TODO confirm).
    Images are then converted to float and divided by 255. This assumes raw
    pixel values in ``[0, 255]``, e.g. uint8 (TODO confirm).

    Args:
        data_root: Directory containing ``tiny200.npz``.

    Returns:
        ``(trainset, testset)`` as ``simple_dataset`` instances. The training
        set applies ``RandomCrop(64, padding=8)`` followed by
        ``RandomHorizontalFlip()`` to each sample. The test (validation) set
        applies no transform. Labels are passed through ``torch.from_numpy``
        unchanged, so their dtype is whatever the archive stores.
        NOTE(review): losses such as cross-entropy expect int64 targets;
        confirm the stored dtype.
    """
    # np.load on an .npz returns an NpzFile that keeps the zip file open;
    # using it as a context manager closes that handle once we're done.
    with np.load('%s/tiny200.npz' % data_root) as data:
        # Indexing an NpzFile reads each array fully into memory, so the
        # tensors built here remain valid after the archive is closed.
        trainX = torch.from_numpy(data['trainX']).permute(0, 3, 1, 2)
        trainY = torch.from_numpy(data['trainY'])
        valX = torch.from_numpy(data['valX']).permute(0, 3, 1, 2)
        valY = torch.from_numpy(data['valY'])

    # Scale pixel intensities to [0, 1] floats.
    trainX = trainX.float().div_(255.)
    valX = valX.float().div_(255.)

    # Training-time augmentation only; evaluation data is left untouched.
    transform_train = transforms.Compose([
        transforms.RandomCrop(64, padding=8),
        transforms.RandomHorizontalFlip()
    ])

    trainset = simple_dataset(trainX, trainY, transform_train)
    testset = simple_dataset(valX, valY)
    return trainset, testset
