import torch.utils.data as data
import os
import json
from PIL import Image

# Public API: the only name exported by a star-import of this module.
__all__ = ['VWWDataset']


class VWWDataset(data.Dataset):
    """Visual Wake Words evaluation dataset built on top of COCO 2014 images.

    The annotation JSON is loaded once at construction time. Every entry of
    its ``'images'`` list yields one sample ``(image, label)``, where the
    label is taken from the ``'label'`` field of that image's annotation.

    Args:
        dataset_dir: Root directory holding the COCO image folders; images
            are read from ``<dataset_dir>/<split>2014/<file_name>`` with the
            ``'mini'`` prefix removed from ``split`` (``'minival'`` ->
            ``'val2014'``).
        split: Dataset split. Only ``'minival'`` is supported.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned by ``__getitem__``.
        anno_path: Optional path to the annotation JSON. When ``None`` the
            historical default
            ``'./dataset/instances_visualwakewords_<split>2014.json'``
            (relative to the current working directory) is used.

    Attributes:
        data_pairs: List of ``(file_name, label)`` tuples, in annotation order.
        bboxes: Maps ``file_name`` to the annotation's ``'object'`` entry for
            every sample whose label is truthy.

    Raises:
        AssertionError: If ``split`` is not ``'minival'`` or an image does not
            have exactly one annotation entry.
    """

    # Default annotation location, kept identical to the original hard-coded path.
    _DEFAULT_ANNO_PATH = './dataset/instances_visualwakewords_{}2014.json'

    def __init__(self, dataset_dir='/dataset/coco/', split='minival', transform=None, anno_path=None):
        super().__init__()

        # Raised explicitly instead of via `assert` so the check survives
        # `python -O`; the exception type is kept for backward compatibility.
        if split != 'minival':  # we used minival for evaluation
            raise AssertionError(
                "VWWDataset only supports split='minival', got {!r}".format(split))

        self.dataset_dir = dataset_dir
        self.transform = transform

        if anno_path is None:
            anno_path = self._DEFAULT_ANNO_PATH.format(split)
        with open(anno_path, encoding='utf-8') as anno_file:
            labels = json.load(anno_file)  # keys: 'images', 'annotations', 'categories'

        annotations = labels['annotations']
        self.data_pairs = []
        self.bboxes = {}
        for image_info in labels['images']:
            file_name = image_info['file_name']
            # The annotation key is the numeric suffix of the file name with
            # leading zeros stripped, e.g. 'COCO_val2014_000000000042.jpg' -> '42'.
            image_id = str(int(file_name.split('.')[0].split('_')[-1]))
            entries = annotations[image_id]
            if len(entries) != 1:
                raise AssertionError(
                    'expected exactly one annotation for image {}, got {}'.format(
                        image_id, len(entries)))
            entry = entries[0]
            label = entry['label']
            self.data_pairs.append((file_name, label))
            if label:
                # list of dict, {'bbox': [a, b, c, d], 'area', 'category_id': 1}
                self.bboxes[file_name] = entry['object']

        # 'minival' images live in the regular COCO 'val2014' folder.
        self.split = split.replace('mini', '')  # correct coco filename

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        The image is opened from disk, converted to RGB and, if a transform
        was supplied, passed through it.
        """
        file_name, label = self.data_pairs[index]
        image_path = os.path.join(self.dataset_dir, '{}2014'.format(self.split), file_name)

        # `convert` returns a new, fully loaded image, so it stays valid after
        # the context manager has closed the underlying file.
        with Image.open(image_path) as raw:
            img = raw.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Return the number of samples."""
        return len(self.data_pairs)

