import os
from typing import Tuple

import cv2
import torch
import torch.utils.data as Data
import torchvision.transforms as transforms


def parse_imagenet(data_root: str,
                   num_classes: int = 1000,
                   seed: int = 2023) -> Tuple[Tuple, Tuple]:
    """Build the train/val record lists for a random subset of classes.

    Reads ``{data_root}/imagenet/meta/train.txt`` and
    ``{data_root}/imagenet/meta/val.txt``. Each non-blank line must hold
    two whitespace-separated fields: a relative image path and an integer
    class label.

    Args:
        data_root: Directory that contains the ``imagenet`` folder.
        num_classes: How many of the 1000 class ids to keep. The kept ids
            are the first ``num_classes`` entries of a seeded random
            permutation of ``range(1000)``.
        seed: Seed used for the class selection.

    Returns:
        ``(train_records, val_records)``. Each is a tuple of
        ``(image_path, new_label)`` pairs, where ``new_label`` is the
        re-indexed label in ``range(len(selected classes))``. Lines whose
        label was not selected are dropped.

    Raises:
        OSError: If a meta file cannot be opened.
        ValueError: If a non-blank line does not consist of exactly two
            fields, or its label is not an integer.

    Note:
        This calls ``torch.manual_seed(seed)`` and therefore reseeds the
        global torch RNG as a side effect.
    """
    torch.manual_seed(seed)
    selected_classes = set(torch.randperm(1000)[:num_classes].tolist())
    # New indices follow the iteration order of the set. Sorting here would
    # renumber the classes, so the order is deliberately left untouched.
    lookup = {label: idx for idx, label in enumerate(selected_classes)}

    data_lists = []
    for mode in ('train', 'val'):
        data_file = f'{data_root}/imagenet/meta/{mode}.txt'
        data_list = []
        with open(data_file) as f:
            # Stream the file instead of loading every line up front.
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                path, label = line.split()
                label = int(label)
                if label not in lookup:
                    continue
                path = f'{data_root}/imagenet/{mode}/{path}'
                # Collapse doubled separators left by a trailing or
                # leading slash in the joined parts.
                path = path.replace('//', '/')
                data_list.append((path, lookup[label]))
        data_lists.append(tuple(data_list))
    return tuple(data_lists)


class imagenet_loader(Data.Dataset):
    """Dataset over ``(image_path, label)`` records read with OpenCV.

    Images are loaded as uint8 tensors, passed through the split-specific
    transform pipeline and scaled to floats by dividing by 255.
    """

    def __init__(self, datalist: Tuple, mode: str = 'train'):
        """Store the records and pick the transform pipeline.

        Args:
            datalist: Indexable collection of ``(image_path, label)`` pairs.
            mode: ``'train'`` (resize to 256, random 224 crop, random
                horizontal flip) or ``'val'`` (resize to 256, center 224
                crop).
        """
        assert mode in ['train', 'val'], \
            f"mode must be 'train' or 'val', got {mode!r}"

        self.datalist = datalist

        if mode == 'train':
            self.transform = transforms.Compose([
                transforms.Resize(256),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
            ])
        else:
            self.transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
            ])

    def __getitem__(self, index):
        """Return ``(image, label)`` for the record at ``index``.

        ``image`` is the transformed image as a float tensor divided by
        255; ``label`` is returned exactly as stored in the record.
        """
        path, label = self.datalist[index]
        image = self.read_image(path)
        image = self.transform(image)
        image = image.float().div_(255)
        return image, label

    def __len__(self):
        """Return the number of records."""
        return len(self.datalist)

    def read_image(self, path):
        """Load ``path`` with ``cv2.imread`` as a channel-first tensor.

        Raises:
            OSError: If OpenCV cannot read or decode the file.
        """
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None instead of
            # raising; fail here with the offending path rather than with
            # an opaque error from torch.from_numpy.
            raise OSError(f'cv2.imread could not read image: {path}')
        image = torch.from_numpy(image)
        if image.ndim == 2:  # single-channel image: replicate to 3 channels
            image = image.expand(3, -1, -1)
        else:  # colour image: (H, W, C) -> (C, H, W)
            # NOTE(review): cv2.imread returns channels in BGR order, not
            # RGB. The order is left untouched so inputs stay identical for
            # models already trained with this loader - confirm whether RGB
            # is actually expected downstream.
            image = image.permute(2, 0, 1)
        return image


def imagenet_dataset(num_classes: int, data_root: str, seed: int = 2023):
    """Create the paired train and val datasets for a class subset.

    Args:
        num_classes: Number of classes forwarded to ``parse_imagenet``.
        data_root: Dataset root directory forwarded to ``parse_imagenet``.
        seed: Class-selection seed forwarded to ``parse_imagenet``.

    Returns:
        ``(train_dataset, val_dataset)``, both ``imagenet_loader``
        instances built in ``'train'`` and ``'val'`` mode respectively.
    """
    train_records, val_records = parse_imagenet(data_root, num_classes, seed)
    return (
        imagenet_loader(train_records, mode='train'),
        imagenet_loader(val_records, mode='val'),
    )
