import pdb
from collections import defaultdict
import json
import os
import pickle
import zipfile

import numpy as np
from PIL import Image, ImageFile

import torch
from torchvision import datasets as t_datasets
from tqdm import tqdm

import random
import glob
from dataset.dataset_util import get_paths

def pil_loader(path):
    """Read the image file at ``path`` and return it converted to RGB mode."""
    # Pass Pillow an explicit file object instead of the path to avoid a
    # ResourceWarning (https://github.com/python-pillow/Pillow/issues/835).
    # The conversion happens while the handle is still open.
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')


class FSCDataset(torch.utils.data.Dataset):
    """Image / object-count dataset read from an FSC-147 style directory.

    ``root`` is expected to contain:

    * ``annotation_FSC147_384.json`` -- mapping of image id to an annotation
      dict that has a ``'points'`` entry,
    * ``Train_Test_Val_FSC_147.json`` -- mapping of split name to a sequence
      of image ids,
    * ``images_384_VarV2/`` -- the image files, named by image id.

    The label of a sample is the number of entries in its ``'points'``
    annotation.

    Args:
        root: Directory holding the two JSON files and the image folder.
        split: Key of the split file selecting which image ids to use.
        transform: Optional callable applied to each loaded RGB image.
        ret_path: If true, ``__getitem__`` additionally returns the image path.

    Raises:
        KeyError: If ``split`` is not a key of the split file, or an image id
            listed in the split has no annotation entry.
    """

    def __init__(self, root, split, transform=None, ret_path=False):
        self.transform = transform
        self.ret_path = ret_path

        anno_file = os.path.join(root, 'annotation_FSC147_384.json')
        data_split_file = os.path.join(root, 'Train_Test_Val_FSC_147.json')
        im_dir = os.path.join(root, 'images_384_VarV2')

        with open(anno_file) as f:
            self.annotations = json.load(f)

        with open(data_split_file) as f:
            self.data_split = json.load(f)

        try:
            self.im_ids = self.data_split[split]
        except KeyError:
            # Same exception type as before, but say which splits do exist.
            raise KeyError(
                f"unknown split {split!r}; "
                f"available splits: {list(self.data_split)}"
            ) from None

        # Image ids double as file names inside the image directory.
        self.paths = [os.path.join(im_dir, im_id) for im_id in self.im_ids]
        # The count label is simply the number of annotated points; len()
        # reads it directly without building a throwaway NumPy array.
        self.labels = [
            len(self.annotations[im_id]['points']) for im_id in self.im_ids
        ]

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.im_ids)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(img, label)``, or ``(img, label, im_path)`` when the dataset
            was created with ``ret_path=True``. ``img`` is the RGB image
            after ``transform`` (if one was given); ``label`` is the point
            count of that image.
        """
        im_path = self.paths[idx]
        label = self.labels[idx]
        img = pil_loader(im_path)

        if self.transform is not None:
            img = self.transform(img)

        if self.ret_path:
            return img, label, im_path
        return img, label


