from torch.utils.data import Dataset
from torchvision import transforms
import torch

import numpy as np
from PIL import Image
import PIL


###################################################
# resize so the shorter side is 448 (keeping the aspect ratio), then crop the central 448*448 region
###################################################
class MA_BBox_Dataset_488(Dataset):
    """Image-classification dataset that center-crops every image to a square.

    Each image is resized so that its shorter side equals ``crop_size``
    (aspect ratio preserved, bilinear resampling). The central
    ``crop_size`` x ``crop_size`` region is then cut out and normalised.
    Despite the "488" in the class name, the default crop size is 448.

    Args:
        imdb: indexable sequence of records. Each record must provide
            ``'image'`` (a path to an image file) and ``'label'``
            (presumably an integer class index -- TODO confirm with callers).
        crop_size: side length of the square crop. Defaults to 448, the
            value this class has always used.
    """

    def __init__(self, imdb, crop_size=448):
        self._imdb = imdb
        self.crop_size = crop_size
        # After the crop in __getitem__ the image is already
        # crop_size x crop_size, so Resize is only a safeguard. The mean/std
        # values are the commonly used ImageNet channel statistics.
        self.transform = transforms.Compose([
            transforms.Resize((crop_size, crop_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self._imdb)

    @staticmethod
    def _scale_and_crop_box(width, height, size):
        """Compute the aspect-preserving resize and the central crop box.

        Args:
            width, height: original image size in pixels (PIL order).
            size: target length of the shorter side and of the square crop.

        Returns:
            ``((new_width, new_height), (left, top, right, bottom))``: the
            integer size to pass to ``Image.resize``, and the box to pass to
            ``Image.crop``.
        """
        if width > height:
            width = width * size / height
            height = size
        else:
            height = height * size / width
            width = size
        # The offsets are derived from the un-truncated (float) scaled size,
        # exactly as the original inline code did.
        top = int((height - size) / 2)
        left = int((width - size) / 2)
        return (int(width), int(height)), (left, top, left + size, top + size)

    def __getitem__(self, idx):
        """Load, resize, center-crop and normalise the ``idx``-th image.

        Returns:
            dict with keys
              ``'raw_imgs'``: the cropped RGB image as a uint8 numpy array of
                  shape (crop_size, crop_size, 3);
              ``'imgs'``: the normalised float tensor of shape
                  (3, crop_size, crop_size);
              ``'labels'``: the label as a 0-d numpy array (int64 for an
                  integer label);
              ``'img_path'``: the record's ``'image'`` entry.
        """
        record = self._imdb[idx]
        img = self._pil_loader(record['image'])
        width, height = img.size  # PIL reports (width, height)
        new_size, box = self._scale_and_crop_box(width, height, self.crop_size)

        img = img.resize(new_size, resample=PIL.Image.BILINEAR)
        img = img.crop(box)
        im_blob = self.transform(img)

        label_blob = (record['label'] * np.ones((1), dtype=np.int64)).squeeze()

        blobs = {'raw_imgs': np.array(img),
                 'imgs': im_blob,
                 'labels': label_blob,
                 'img_path': record['image']}
        return blobs

    def _pil_loader(self, path):
        """Open ``path`` and return it as an RGB PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            # convert() forces the pixel data to load while the file is open.
            return img.convert('RGB')


###################################################
# resize so the shorter side is 224 (keeping the aspect ratio), then crop the central 224*224 region
###################################################
class MA_BBox_Dataset_244(Dataset):
    """Center-cropped image-classification dataset with optional augmentation.

    Each image is resized so that its shorter side equals ``crop_size``
    (aspect ratio preserved, bilinear resampling). The central
    ``crop_size`` x ``crop_size`` region is then cut out. After that, the
    configured transform is applied: either a RandomResizedCrop or a plain
    Resize, an optional random horizontal flip, then ToTensor and Normalize.
    Despite the "244" in the class name, the default crop size is 224.

    Args:
        imdb: indexable sequence of records with ``'image'`` (file path) and
            ``'label'`` (presumably an integer class index -- TODO confirm).
        Flip: if True, apply ``RandomHorizontalFlip``.
        RandomCrop: if True, apply ``RandomResizedCrop`` to the
            center-cropped image. Otherwise apply a fixed Resize.
        crop_size: side length of the square crop and of the output.
            Defaults to 224, the value this class has always used.
    """

    def __init__(self, imdb, Flip=True, RandomCrop=False, crop_size=224):
        self._imdb = imdb
        self.crop_size = crop_size
        trans_list = list()
        if RandomCrop:
            trans_list.append(transforms.RandomResizedCrop(size=crop_size,
                                  scale=(0.4, 1.0), ratio=(0.75, 1.333)))
        else:
            trans_list.append(transforms.Resize((crop_size, crop_size)))
        if Flip:
            trans_list.append(transforms.RandomHorizontalFlip())
        # Common tail: tensor conversion plus normalisation with the commonly
        # used ImageNet channel statistics.
        trans_list += [transforms.ToTensor(),
                       transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                            std=[0.229, 0.224, 0.225])]
        self.transform = transforms.Compose(trans_list)

    def __len__(self):
        return len(self._imdb)

    @staticmethod
    def _scale_and_crop_box(width, height, size):
        """Compute the aspect-preserving resize and the central crop box.

        Args:
            width, height: original image size in pixels (PIL order).
            size: target length of the shorter side and of the square crop.

        Returns:
            ``((new_width, new_height), (left, top, right, bottom))``: the
            integer size for ``Image.resize``, and the box for ``Image.crop``.
        """
        if width > height:
            width = width * size / height
            height = size
        else:
            height = height * size / width
            width = size
        # The offsets use the un-truncated (float) scaled size, as before.
        top = int((height - size) / 2)
        left = int((width - size) / 2)
        return (int(width), int(height)), (left, top, left + size, top + size)

    def __getitem__(self, idx):
        """Load, center-crop and transform the ``idx``-th image.

        Returns:
            dict with ``'imgs'`` (the transformed tensor) and ``'labels'``
            (a 0-d numpy array, int64 for an integer label).
        """
        record = self._imdb[idx]
        img = self._pil_loader(record['image'])
        width, height = img.size  # PIL reports (width, height)
        new_size, box = self._scale_and_crop_box(width, height, self.crop_size)

        img = img.resize(new_size, resample=PIL.Image.BILINEAR)
        img = img.crop(box)
        im_blob = self.transform(img)

        label_blob = (record['label'] * np.ones((1), dtype=np.int64)).squeeze()

        blobs = {'imgs': im_blob,
                 'labels': label_blob}
        return blobs

    def _pil_loader(self, path):
        """Open ``path`` and return it as an RGB PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            # convert() forces the pixel data to load while the file is open.
            return img.convert('RGB')

###################################################
# Full raw image plus body/part crops taken from predicted bounding boxes (MA-CNN)
###################################################
class MA_BBox_Dataset_3Scale(Dataset):
    """Dataset yielding a whole image together with body and two part crops.

    Each ``roidb`` record is read for ``'image'`` (file path), ``'box'`` (the
    body crop box), ``'part_boxes'`` (only the first two entries are used)
    and ``'label'``. The label is also used to index ``semantic_feat``,
    presumably a per-class semantic feature table (TODO confirm its type).
    """

    def __init__(self, semantic_feat, roidb):
        self._roidb = roidb
        self.semantic_feat = semantic_feat
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        # Body and part crops are squashed straight to 224x224.
        self.transform = transforms.Compose(
            [transforms.Resize((224, 224)), transforms.ToTensor(), normalize])
        # The whole image goes to 256x256, and its central 224x224 is kept.
        self.transform_centerCrop = transforms.Compose(
            [transforms.Resize((256, 256)), transforms.CenterCrop(224),
             transforms.ToTensor(), normalize])

    def __len__(self):
        return len(self._roidb)

    def __getitem__(self, idx):
        """Return the four image views, the label and its semantic feature."""
        entry = self._roidb[idx]
        full_image = self._pil_loader(entry['image'])
        body_crop = full_image.crop(entry['box'])
        part_boxes = entry['part_boxes']
        first_part, second_part = (full_image.crop(part_boxes[k])
                                   for k in (0, 1))

        views = {
            'im_org': self.transform_centerCrop(full_image),
            'im_body': self.transform(body_crop),
            'im_part1': self.transform(first_part),
            'im_part2': self.transform(second_part),
        }

        # 0-d numpy array (int64 for an integer label); also used as the index.
        label = (entry['label'] * np.ones((1), dtype=np.int64)).squeeze()
        views['labels'] = label
        views['semantic_data'] = self.semantic_feat[label]
        return views

    def _pil_loader(self, path):
        """Read ``path`` and hand back an RGB PIL image."""
        with open(path, 'rb') as handle:
            return Image.open(handle).convert('RGB')
