import os
import json
import scipy.io as sio
import numpy as np
from dataset_module import MA_BBox_Dataset_488, MA_BBox_Dataset_244, MA_BBox_Dataset_3Scale
from torch.utils.data import DataLoader


""" resize to 224*n, and crop 224*224 keep the images ratio
"""
class Dataset_MA_STEP1():
    """Step-1 data pipeline: training loaders built on ``MA_BBox_Dataset_244``.

    Reads the image list and class splits for ``opt.dataset`` and exposes:

    * ``dataloader`` -- shuffled training loader (``drop_last=True``) over
      ``MA_BBox_Dataset_244(imdb_train, Flip=True, RandomCrop=False)``.
    * ``dataloader_train_for_test`` -- the same training dataset, unshuffled
      and keeping the last partial batch.
    * ``train_ncls`` / ``test_ncls`` -- number of trainval (seen) classes and
      number of unseen test classes.

    ``opt`` must provide ``dataset``, ``split`` (``'SP'`` or ``'PP'``) and
    ``BATCH_SIZE``.

    Raises:
        ValueError: if ``opt.split`` is neither ``'SP'`` nor ``'PP'``.
    """

    def __init__(self, opt):
        self.Image_path = 'datasets/Images/' + opt.dataset
        self.PP_split_path = 'datasets/data_PP/' + opt.dataset
        self.SP_split_path = 'datasets/standard_split/' + opt.dataset

        if opt.split == 'SP':
            # get_imdb_SP returns (trainval, unseen-test, trainval map, unseen map).
            self.imdb_train, self.imdb_test, \
                label_map_trainval, label_map_test_unseen = get_imdb_SP(self.PP_split_path, self.SP_split_path,
                                                                        self.Image_path, opt.dataset)
            self.num_train_sample, self.num_test_sample = len(self.imdb_train), len(self.imdb_test)
            # imdb_test holds the unseen-class test images, so label it as such.
            print('Training Samples: {} Test Unseen {} '.format(self.num_train_sample, self.num_test_sample))

        elif opt.split == 'PP':
            # get_imdb_PP returns (trainval, seen-test, unseen-test, trainval map, unseen map).
            (self.imdb_train, self.imdb_test_seen, self.imdb_test,
             label_map_trainval, label_map_test_unseen) = get_imdb_PP(self.PP_split_path, self.Image_path,
                                                                      opt.dataset)
            self.num_train_sample = len(self.imdb_train)
            self.num_test_sample = len(self.imdb_test)            # unseen-class test images
            self.num_test_seen_sample = len(self.imdb_test_seen)  # seen-class test images
            # Each count is printed next to the label that actually describes it.
            print('Training Samples: {} Test Seen {} Test Unseen {}'.format(self.num_train_sample,
                                                                            self.num_test_seen_sample,
                                                                            self.num_test_sample))
        else:
            # Without this the label maps below are unbound and fail with an opaque NameError.
            raise ValueError("Unsupported split {!r}; expected 'SP' or 'PP'".format(opt.split))

        self.train_ncls = len(label_map_trainval)
        self.test_ncls = len(label_map_test_unseen)

        # Training dataloader (flip augmentation on, random crop off).
        transformed_data = MA_BBox_Dataset_244(self.imdb_train, Flip=True, RandomCrop=False)
        self.dataloader = DataLoader(transformed_data, batch_size=opt.BATCH_SIZE, shuffle=True,
                                     drop_last=True, num_workers=4)

        # Same training data in a fixed order, keeping every sample (for evaluation on the train set).
        self.dataloader_train_for_test = DataLoader(transformed_data, batch_size=opt.BATCH_SIZE, shuffle=False,
                                                    drop_last=False, num_workers=4)

""" resize to 448*n, and crop 448*448 keep the images ratio
"""
class Dataset_MA_STEP2():
    """Step-2 data pipeline: train and test loaders built on ``MA_BBox_Dataset_488``.

    Exposes:

    * ``dataloader`` -- shuffled training loader (``drop_last=True``).
    * ``dataloader_train_for_test`` -- unshuffled loader over the training set.
    * SP split: ``dataloader_test`` over the unseen-class test images.
    * PP split: ``dataloader_test_seen`` over seen-class test images and
      ``dataloader_test_unseen`` over unseen-class test images.
    * ``train_ncls`` / ``test_ncls`` -- number of trainval and unseen classes.

    ``opt`` must provide ``dataset``, ``split`` (``'SP'`` or ``'PP'``),
    ``BATCH_SIZE`` and ``BATCH_SIZE_TEST``.

    Raises:
        ValueError: if ``opt.split`` is neither ``'SP'`` nor ``'PP'``.
    """

    def __init__(self, opt):
        self.Image_path = 'datasets/Images/' + opt.dataset
        self.PP_split_path = 'datasets/data_PP/' + opt.dataset
        self.SP_split_path = 'datasets/standard_split/' + opt.dataset

        if opt.split == 'SP':
            # get_imdb_SP returns (trainval, unseen-test, trainval map, unseen map).
            self.imdb_train, self.imdb_test, \
                label_map_trainval, label_map_test_unseen = get_imdb_SP(self.PP_split_path, self.SP_split_path,
                                                                        self.Image_path, opt.dataset)
            self.num_train_sample, self.num_test_sample = len(self.imdb_train), len(self.imdb_test)
            print('Training Samples: {} Test Unseen {} '.format(self.num_train_sample, self.num_test_sample))

        elif opt.split == 'PP':
            # get_imdb_PP returns (trainval, seen-test, unseen-test, ...). Unpacking in this
            # exact order keeps imdb_test_seen = seen classes and imdb_test = unseen classes;
            # the previous order silently swapped them.
            (self.imdb_train, self.imdb_test_seen, self.imdb_test,
             label_map_trainval, label_map_test_unseen) = get_imdb_PP(self.PP_split_path, self.Image_path,
                                                                      opt.dataset)
            self.num_train_sample = len(self.imdb_train)
            self.num_test_sample = len(self.imdb_test)            # unseen-class test images
            self.num_test_seen_sample = len(self.imdb_test_seen)  # seen-class test images
            print('Training Samples: {} Test Seen {} Test Unseen {}'.format(self.num_train_sample,
                                                                            self.num_test_seen_sample,
                                                                            self.num_test_sample))
        else:
            # Without this the label maps below are unbound and fail with an opaque NameError.
            raise ValueError("Unsupported split {!r}; expected 'SP' or 'PP'".format(opt.split))

        self.train_ncls = len(label_map_trainval)
        self.test_ncls = len(label_map_test_unseen)

        # Training dataloaders.
        self.transformed_dataset = MA_BBox_Dataset_488(self.imdb_train)
        self.dataloader = DataLoader(self.transformed_dataset, batch_size=opt.BATCH_SIZE,
                                     shuffle=True, drop_last=True, num_workers=4)

        self.dataloader_train_for_test = DataLoader(self.transformed_dataset, batch_size=opt.BATCH_SIZE_TEST,
                                                    shuffle=False, drop_last=False, num_workers=4)

        # Testing dataloaders (no shuffling).
        if opt.split == 'SP':
            self.transformed_dataset_test = MA_BBox_Dataset_488(self.imdb_test)
            self.dataloader_test = DataLoader(self.transformed_dataset_test, batch_size=opt.BATCH_SIZE_TEST,
                                              shuffle=False, num_workers=4)
        else:  # 'PP' -- any other value was rejected above
            self.transformed_dataset_test = MA_BBox_Dataset_488(self.imdb_test_seen)
            self.dataloader_test_seen = DataLoader(self.transformed_dataset_test, batch_size=opt.BATCH_SIZE_TEST,
                                                   shuffle=False, num_workers=4)
            self.transformed_dataset_test = MA_BBox_Dataset_488(self.imdb_test)
            self.dataloader_test_unseen = DataLoader(self.transformed_dataset_test, batch_size=opt.BATCH_SIZE_TEST,
                                                     shuffle=False, num_workers=4)


class Dataset_MA_STEP3_3Scale():
    """Step-3 data pipeline: loaders over ``MA_BBox_Dataset_3Scale`` with predicted boxes.

    Each roidb entry is augmented with ``'box'`` and ``'part_boxes'`` taken from
    ``<BBox_path>/Pred_Boxes_half_<split>.json``. Per-class attribute vectors
    are read from ``att_splits.mat`` under ``PP_split_path``.

    Exposes ``dataloader_train`` (shuffled), ``dataloader_train_for_test``,
    ``dataloader_test`` (unseen-class test images) and, for the PP split only,
    ``dataloader_test_seen`` (seen-class test images). It also exposes
    ``train_ncls`` and ``att_ndim``.

    ``opt`` must provide ``dataset``, ``split`` (``'SP'`` or ``'PP'``) and
    ``BATCH_SIZE``.

    Raises:
        ValueError: if ``opt.split`` is neither ``'SP'`` nor ``'PP'``.
    """

    def __init__(self, opt):
        self.Image_path = 'datasets/Images/' + opt.dataset
        self.BBox_path = 'datasets/BBoxes/' + opt.dataset
        self.PP_split_path = 'datasets/data_PP/' + opt.dataset
        self.SP_split_path = 'datasets/standard_split/' + opt.dataset

        if opt.split == 'SP':
            self.imdb_train, self.imdb_test, \
                label_map_trainval, label_map_test_unseen = get_imdb_SP(self.PP_split_path, self.SP_split_path,
                                                                        self.Image_path, opt.dataset)
            # Previously read from an undefined self._devkit_path; the PP branch shows the
            # predicted boxes live under BBox_path.
            self.bboxes = self._load_bboxes('Pred_Boxes_half_SP.json')

            self.imdb_train = self.merge_bbox2imdb(self.imdb_train, self.bboxes)
            self.imdb_test = self.merge_bbox2imdb(self.imdb_test, self.bboxes)
            self.num_train_sample, self.num_test_sample = len(self.imdb_train), len(self.imdb_test)
            # imdb_test holds the unseen-class test images.
            print('Training Samples: {} Test Unseen {} '.format(self.num_train_sample, self.num_test_sample))

        elif opt.split == 'PP':
            self.imdb_train, self.imdb_test_seen, self.imdb_test_unseen, \
                label_map_trainval, label_map_test_unseen = get_imdb_PP(self.PP_split_path, self.Image_path,
                                                                        opt.dataset)
            self.bboxes = self._load_bboxes('Pred_Boxes_half_PP.json')
            self.imdb_train = self.merge_bbox2imdb(self.imdb_train, self.bboxes)
            self.imdb_test = self.merge_bbox2imdb(self.imdb_test_unseen, self.bboxes)
            self.imdb_test_seen = self.merge_bbox2imdb(self.imdb_test_seen, self.bboxes)

            self.num_train_sample = len(self.imdb_train)
            self.num_test_sample = len(self.imdb_test)            # unseen-class test images
            self.num_test_seen_sample = len(self.imdb_test_seen)  # seen-class test images
            print('Training Samples: {} Test Seen {} Test Unseen {}'.format(self.num_train_sample,
                                                                            self.num_test_seen_sample,
                                                                            self.num_test_sample))
        else:
            # Without this the label maps below are unbound and fail with an opaque NameError.
            raise ValueError("Unsupported split {!r}; expected 'SP' or 'PP'".format(opt.split))

        self.train_semantic_feat, self.test_semantic_feat = self.load_semantic_data(label_map_trainval,
                                                                                    label_map_test_unseen)

        self.train_ncls, self.att_ndim = self.train_semantic_feat.shape
        print('ATT dimension: {}'.format(self.att_ndim))

        # Train dataloaders.
        self.transformed_dataset = MA_BBox_Dataset_3Scale(self.train_semantic_feat, self.imdb_train)
        self.dataloader_train = DataLoader(self.transformed_dataset, batch_size=opt.BATCH_SIZE,
                                           shuffle=True, drop_last=True, num_workers=4)

        self.dataloader_train_for_test = DataLoader(self.transformed_dataset, batch_size=opt.BATCH_SIZE,
                                                    shuffle=False, drop_last=False, num_workers=4)

        # Test dataloaders: unseen classes use the unseen attribute matrix.
        self.transformed_dataset = MA_BBox_Dataset_3Scale(self.test_semantic_feat, self.imdb_test)
        self.dataloader_test = DataLoader(self.transformed_dataset, batch_size=opt.BATCH_SIZE,
                                          shuffle=False, drop_last=False, num_workers=4)
        if opt.split == 'PP':
            # Only the PP split has a seen-class test set (SP previously crashed here).
            self.transformed_dataset = MA_BBox_Dataset_3Scale(self.train_semantic_feat, self.imdb_test_seen)
            self.dataloader_test_seen = DataLoader(self.transformed_dataset, batch_size=opt.BATCH_SIZE,
                                                   shuffle=False, drop_last=False, num_workers=4)

    def _load_bboxes(self, filename):
        """Load and return the JSON file ``<self.BBox_path>/<filename>``, closing it afterwards."""
        with open(self.BBox_path + '/' + filename, 'r') as f:
            return json.load(f)

    def merge_bbox2imdb(self, roidb, bboxes):
        """Return a copy of ``roidb`` with predicted boxes attached to every entry.

        Args:
            roidb: list of dicts with ``'label'`` and ``'image'`` keys.
            bboxes: mapping from ``'<parent dir>/<file name>'`` of an image path
                to a dict holding ``'box'`` and ``'part_boxes'``.

        Returns:
            A new list of dicts with keys ``'label'``, ``'image'``, ``'box'`` and
            ``'part_boxes'``; the input entries are not modified.

        Raises:
            KeyError: if an image has no entry in ``bboxes``.
        """
        new_roidb = []
        for entry in roidb:
            # The box file is keyed by the last two path components.
            parts = entry['image'].split('/')
            _bbox = bboxes[parts[-2] + '/' + parts[-1]]
            new_roidb.append({'label': entry['label'], 'image': entry['image'],
                              'box': _bbox['box'], 'part_boxes': _bbox['part_boxes']})
        return new_roidb

    def load_semantic_data(self, label_map_trainval, label_map_test_unseen):
        """Build float32 attribute matrices for the trainval and unseen classes.

        Reads ``att`` from ``<PP_split_path>/att_splits.mat``; after transposing,
        row ``k - 1`` is the attribute vector of original (1-based) class id ``k``.

        Args:
            label_map_trainval: dict original class id -> contiguous trainval index.
            label_map_test_unseen: dict original class id -> contiguous unseen index.

        Returns:
            ``(train_feat, test_feat)`` with shapes
            ``(len(label_map_trainval), att_dim)`` and
            ``(len(label_map_test_unseen), att_dim)``.
        """
        raw_data = sio.loadmat(os.path.join(self.PP_split_path, 'att_splits.mat'))
        att = raw_data['att'].transpose()
        train_text_feature = np.zeros((len(label_map_trainval), att.shape[1]), np.float32)
        test_text_feature = np.zeros((len(label_map_test_unseen), att.shape[1]), np.float32)
        for class_id, row in label_map_trainval.items():
            train_text_feature[row] = att[class_id - 1]
        for class_id, row in label_map_test_unseen.items():
            test_text_feature[row] = att[class_id - 1]
        return train_text_feature, test_text_feature


# This function maps the original labels to contiguous indices 0..K-1, where K is the number of distinct selected labels.
def get_label_map(labels, loc):
    """Assign a contiguous index to every distinct label found at positions ``loc``.

    Args:
        labels: array of original class labels.
        loc: indices selecting the subset of ``labels`` to consider.

    Returns:
        dict mapping each unique selected label, in the ascending order produced
        by ``np.unique``, to its position 0..K-1.
    """
    return {original: idx for idx, original in enumerate(np.unique(labels[loc]))}


def get_imdb_PP(PP_split_path, Image_path, dataset):
    """Build image databases for the PP split stored under ``PP_split_path``.

    Reads ``res101.mat`` (``image_files`` and ``labels``) and ``att_splits.mat``
    (1-based ``trainval_loc`` / ``test_seen_loc`` / ``test_unseen_loc``) and
    rewrites each stored image path so it lives under ``Image_path``.

    Labels are remapped with ``get_label_map``. Trainval and seen-test entries
    use the trainval map; unseen-test entries use a separate map built from the
    unseen-test labels.

    Returns:
        ``(trainval_roidb, test_seen_roidb, test_unseen_roidb,
        label_map_trainval, label_map_test_unseen)``; each roidb is a list of
        ``{'label': ..., 'image': ...}`` dicts.

    Raises:
        ValueError: if ``dataset`` is neither ``'CUB'`` nor ``'AWA1'``.
    """
    features = sio.loadmat(os.path.join(PP_split_path, 'res101.mat'))
    files = features['image_files'].ravel()
    labels = features['labels'].squeeze()

    if dataset == 'CUB':
        # Keep everything after the fifth '/'-separated component of the stored path.
        def relocate(stored):
            return '/'.join([Image_path] + stored.split('/')[5:])
    elif dataset == 'AWA1':
        # Keep only the last two components and place them under an 'images' folder.
        def relocate(stored):
            return '/'.join([Image_path] + ['images'] + stored.split('/')[-2:])
    else:
        raise ValueError('Not Implement Yet...')
    im_path = [relocate(entry[0]) for entry in files]

    splits = sio.loadmat(os.path.join(PP_split_path, 'att_splits.mat'))

    def to_roidb(locs, label_map):
        return [{'label': label_map[labels[k]], 'image': im_path[k]} for k in locs]

    trainval_loc = splits['trainval_loc'].squeeze() - 1  # stored indices are 1-based
    label_map_trainval = get_label_map(labels, trainval_loc)
    trainval_roidb = to_roidb(trainval_loc, label_map_trainval)

    test_seen_loc = splits['test_seen_loc'].squeeze() - 1
    test_unseen_loc = splits['test_unseen_loc'].squeeze() - 1
    label_map_test_unseen = get_label_map(labels, test_unseen_loc)
    # Seen-class test images share the trainval label space; unseen ones get their own.
    test_seen_roidb = to_roidb(test_seen_loc, label_map_trainval)
    test_unseen_roidb = to_roidb(test_unseen_loc, label_map_test_unseen)
    return trainval_roidb, test_seen_roidb, test_unseen_roidb, label_map_trainval, label_map_test_unseen


def get_imdb_SP(PP_split_path, SP_split_path, Image_path, dataset):
    """Build image databases for the standard split.

    Image files and labels come from ``<PP_split_path>/res101.mat``. The 1-based
    ``trainval_loc`` / ``test_unseen_loc`` indices come from
    ``<SP_split_path>/att_splits.mat``. Labels are remapped with
    ``get_label_map``, using separate maps for trainval and unseen test classes.

    Returns:
        ``(trainval_roidb, test_unseen_roidb, label_map_trainval,
        label_map_test_unseen)``; each roidb is a list of
        ``{'label': ..., 'image': ...}`` dicts.

    Raises:
        ValueError: if ``dataset`` is not ``'CUB'``.
    """
    features = sio.loadmat(os.path.join(PP_split_path, 'res101.mat'))
    files = features['image_files'].ravel()
    labels = features['labels'].squeeze()

    if dataset != 'CUB':
        raise ValueError('Not Implement Yet...')
    # Keep everything after the fifth '/'-separated component of the stored path.
    im_path = ['/'.join([Image_path] + entry[0].split('/')[5:]) for entry in files]

    splits = sio.loadmat(os.path.join(SP_split_path, 'att_splits.mat'))

    def to_roidb(locs, label_map):
        return [{'label': label_map[labels[k]], 'image': im_path[k]} for k in locs]

    trainval_loc = splits['trainval_loc'].squeeze() - 1  # stored indices are 1-based
    label_map_trainval = get_label_map(labels, trainval_loc)
    trainval_roidb = to_roidb(trainval_loc, label_map_trainval)

    test_unseen_loc = splits['test_unseen_loc'].squeeze() - 1
    label_map_test_unseen = get_label_map(labels, test_unseen_loc)
    test_unseen_roidb = to_roidb(test_unseen_loc, label_map_test_unseen)
    return trainval_roidb, test_unseen_roidb, label_map_trainval, label_map_test_unseen