import os
import numpy as np
import torch
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from torch.utils import data
try:
    from .utils import TransformTwice
except ImportError:
    from utils import TransformTwice

class ImageNet100(ImageFolder):
    """``ImageFolder`` restricted to selected class indices, with images preloaded.

    At construction time every sample whose class index is in ``target_list``
    is decoded with ``self.loader`` and kept in memory in ``self.imgs``; the
    matching class indices are stored in ``self.targets`` (same order).

    Args:
        root: Directory in the layout expected by ``ImageFolder``.
        transform: Optional callable applied to each image in ``__getitem__``.
        target_transform: Optional callable applied to each target in
            ``__getitem__``.
        target_list: Iterable of class indices to keep. ``None`` (the default)
            keeps indices ``0..99``.
    """

    def __init__(self, root, transform=None, target_transform=None, target_list=None):
        super(ImageNet100, self).__init__(root, transform, target_transform)

        self.transform = transform
        self.target_transform = target_transform

        if target_list is None:
            target_list = range(100)
        # Build the lookup once: set membership is O(1) per sample instead of
        # scanning the whole list for every image in the folder.
        keep = set(target_list)

        self.imgs = []
        self.targets = []
        for path, target in self.samples:
            if target in keep:
                # ``ImageFolder.samples`` already stores paths rooted at ``root``;
                # joining ``root`` again would duplicate it for relative roots.
                self.imgs.append(self.loader(path))
                self.targets.append(target)

    def __len__(self):
        """Return the number of preloaded samples."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Return ``(img, target, index)`` for the sample at ``index``.

        The transforms are applied if set. ``target`` is always returned as a
        ``torch.long`` tensor, and ``index`` is echoed back unchanged.
        """
        img = self.imgs[index]
        target = self.targets[index]

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        # Normalise the label to a long tensor whether or not target_transform
        # already produced a tensor (clone/detach avoids aliasing its storage).
        if isinstance(target, torch.Tensor):
            target = target.clone().detach().long()
        else:
            target = torch.tensor(target).long()
        return img, target, index


def ImageNet100Data(root, split='train', aug=None, target_list=None):
    """Build an ``ImageNet100`` dataset for one split with a standard transform.

    Args:
        root: Dataset root containing ``train`` and ``val`` sub-directories.
        split: ``'train'`` reads ``root/train``; any other value reads
            ``root/val``.
        aug: Selects the image transform. Accepted values:

            - ``None``: resize to 256, center-crop to 224, to tensor and
              normalize.
            - ``'once'``: random resized crop to 224, random horizontal flip,
              to tensor and normalize.
            - ``'twice'``: the ``'once'`` pipeline wrapped in
              ``TransformTwice``. Presumably this yields two independently
              augmented views; confirm against ``utils.TransformTwice``.
        target_list: Class indices to keep. ``None`` (the default) keeps
            ``0..99``.

    Returns:
        An ``ImageNet100`` dataset with ``target_transform=None``.

    Raises:
        ValueError: If ``aug`` is not ``None``, ``'once'`` or ``'twice'``.
    """
    # Validate first: an unknown value previously fell through every branch
    # and surfaced later as an UnboundLocalError on ``transform``.
    if aug not in (None, 'once', 'twice'):
        raise ValueError("aug must be None, 'once' or 'twice', got {!r}".format(aug))
    if target_list is None:
        target_list = list(range(100))

    # ImageNet channel mean/std, shared by every pipeline below.
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    if aug is None:
        transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            normalize,
        ])
    else:
        random_transform = T.Compose([
            T.RandomResizedCrop(224),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            normalize,
        ])
        transform = random_transform if aug == 'once' else TransformTwice(random_transform)

    data_root = os.path.join(root, 'train' if split == 'train' else 'val')
    return ImageNet100(data_root, transform=transform, target_transform=None, target_list=target_list)

def ImageNet100Loader(root, batch_size, split='train', num_workers=2, aug=None, shuffle=True, target_list=None):
    """Build a ``DataLoader`` over ``ImageNet100Data(root, split, aug, target_list)``.

    Args:
        root: Dataset root containing ``train`` and ``val`` sub-directories.
        batch_size: Samples per batch.
        split: ``'train'`` or another split name, forwarded to
            ``ImageNet100Data``.
        num_workers: Worker processes for the ``DataLoader``.
        aug: Augmentation mode forwarded to ``ImageNet100Data``.
        shuffle: Whether the ``DataLoader`` shuffles each epoch.
        target_list: Class indices to keep. ``None`` (the default) keeps
            ``0..99``.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``pin_memory=True``.
    """
    if target_list is None:
        target_list = list(range(100))
    dataset = ImageNet100Data(root, split=split, aug=aug, target_list=target_list)
    return data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        # Drop the incomplete final batch only for the training split.
        drop_last=split == 'train',
    )


def ImageNet100LoaderMix(root, batch_size, split='train', num_workers=2, aug=None, shuffle=True,
                         labeled_list=None, unlabeled_list=None):
    """Build a ``DataLoader`` over the concatenation of two class subsets.

    Two ``ImageNet100Data`` datasets are built, one for ``labeled_list`` and
    one for ``unlabeled_list``. The second is appended onto the first: its
    images go onto ``imgs`` and its targets onto ``targets``. If the two
    lists overlap (the defaults are identical), every sample in the overlap
    appears twice.

    Args:
        root: Dataset root containing ``train`` and ``val`` sub-directories.
        batch_size: Samples per batch.
        split: Split name forwarded to ``ImageNet100Data``.
        num_workers: Worker processes for the ``DataLoader``.
        aug: Augmentation mode forwarded to ``ImageNet100Data``.
        shuffle: Whether the ``DataLoader`` shuffles each epoch.
        labeled_list: Class indices of the first subset. ``None`` (the
            default) keeps ``0..99``.
        unlabeled_list: Class indices of the second subset. ``None`` (the
            default) keeps ``0..99``.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``pin_memory=True``. The
        merged dataset's ``targets`` is an ``int64`` NumPy array. Unlike
        ``ImageNet100Loader``, no ``drop_last`` is set.
    """
    if labeled_list is None:
        labeled_list = list(range(100))
    if unlabeled_list is None:
        unlabeled_list = list(range(100))
    dataset_labeled = ImageNet100Data(root, split=split, aug=aug, target_list=labeled_list)
    dataset_unlabeled = ImageNet100Data(root, split=split, aug=aug, target_list=unlabeled_list)

    # Force an integer dtype: np.concatenate promotes to float64 whenever one
    # of the inputs is an empty Python list (e.g. an empty labeled_list).
    dataset_labeled.targets = np.asarray(
        list(dataset_labeled.targets) + list(dataset_unlabeled.targets), dtype=np.int64)
    dataset_labeled.imgs = dataset_labeled.imgs + dataset_unlabeled.imgs

    return data.DataLoader(dataset_labeled, batch_size=batch_size, shuffle=shuffle,
                           num_workers=num_workers, pin_memory=True)
