import json
from typing import Any, Dict, List
import random
import os

import torch
from torch.utils.data import Dataset
import numpy as np
import _pickle as cPickle

from pytorch_pretrained_bert.tokenization import BertTokenizer
from ._image_features_reader import ImageFeaturesH5Reader
import jsonlines
import sys

import pdb
def assert_eq(real, expected):
    """Assert that ``real == expected``.

    On mismatch, raises AssertionError with the message
    ``"<real> (true) vs <expected> (expected)"``. The message is only built
    when the comparison fails. Like any ``assert``, the check is skipped
    under ``python -O``.
    """
    assert real == expected, "{!s} (true) vs {!s} (expected)".format(real, expected)

def _load_annotations(annotations_jsonpath):
    """Load caption annotations from a JSON-lines file and index them by image.

    Each line is read as an object with an ``'id'`` (the image id) and a
    ``'sentences'`` list of captions.

    Args:
        annotations_jsonpath: path to the JSON-lines annotation file.

    Returns:
        A tuple ``(entries, imgid2entry)``:
          - ``entries``: flat list of ``{"caption": ..., "image_id": ...}``
            dicts, one per caption, in file order.
          - ``imgid2entry``: maps each image id, in order of first appearance,
            to the list of indices of its captions in ``entries``.
    """
    entries = []
    imgid2entry = {}

    with jsonlines.open(annotations_jsonpath) as reader:
        for annotation in reader:
            image_id = annotation['id']
            # setdefault (instead of resetting to []) keeps the caption indices
            # already collected if the same image id appears on several lines.
            caption_indices = imgid2entry.setdefault(image_id, [])
            for sentence in annotation['sentences']:
                caption_indices.append(len(entries))
                entries.append({"caption": sentence, 'image_id': image_id})

    return entries, imgid2entry


class COCORetreivalDatasetTrain(Dataset):
    """Training dataset for COCO caption/image retrieval.

    Each item is one caption paired with the region features of its image.
    Tokenized and tensorized captions are cached on disk at
    ``data/cocoRetreival/cache/train_cap_<split>.pkl`` and reused on later runs.
    """

    # Length the per-image region mask is padded to.
    _MAX_REGIONS = 37

    def __init__(
        self,
        split: str,
        annotations_jsonpath: str,
        image_features_reader: ImageFeaturesH5Reader,
        tokenizer: BertTokenizer,
        padding_index: int = 0,
        max_caption_length: int = 20,
    ):
        # Assumes every image id in the annotations is present in
        # `image_features_reader`.
        self._entries, self.imgid2entry = _load_annotations(annotations_jsonpath)
        self.image_id_list = [*self.imgid2entry]

        self._image_features_reader = image_features_reader
        self._tokenizer = tokenizer

        self._padding_index = padding_index
        self._max_caption_length = max_caption_length

        # NOTE(review): the cache key only encodes `split`. A cache built with a
        # different tokenizer, padding_index or max_caption_length is reused
        # silently. The cache is unpickled, so only load files you produced.
        cap_cache_path = "data/cocoRetreival/cache/train_cap_" + split + ".pkl"
        if not os.path.exists(cap_cache_path):
            self.tokenize()
            self.tensorize()
            os.makedirs(os.path.dirname(cap_cache_path), exist_ok=True)
            with open(cap_cache_path, 'wb') as cache_file:
                cPickle.dump(self._entries, cache_file)
        else:
            print('loading entries from %s' % (cap_cache_path))
            with open(cap_cache_path, "rb") as cache_file:
                self._entries = cPickle.load(cache_file)

    def tokenize(self):
        """Tokenize every caption into fixed-length id lists stored on its entry.

        Adds ``token``, ``input_mask`` and ``segment_ids`` to each entry, each a
        list of length ``max_caption_length``. Each caption is processed as follows:
          - It is wrapped as ``[CLS] ... [SEP]`` and mapped to vocab ids
            (unknown words become ``[UNK]``).
          - It is truncated to ``max_caption_length``, which can drop the
            trailing ``[SEP]``.
          - It is padded at the end. ``token`` is padded with ``padding_index``;
            ``input_mask`` and ``segment_ids`` are padded with 0.
        """
        vocab = self._tokenizer.vocab
        unk_id = vocab["[UNK]"]
        max_len = self._max_caption_length

        for entry in self._entries:
            sentence_tokens = ["[CLS]"] + self._tokenizer.tokenize(entry["caption"]) + ["[SEP]"]
            tokens = [vocab.get(w, unk_id) for w in sentence_tokens][:max_len]
            input_mask = [1] * len(tokens)
            segment_ids = [0] * len(tokens)

            # Only the token ids use padding_index. The mask and segment ids
            # must pad with 0, otherwise padding would be marked as attended
            # whenever padding_index != 0.
            num_pad = max_len - len(tokens)
            tokens += [self._padding_index] * num_pad
            input_mask += [0] * num_pad
            segment_ids += [0] * num_pad

            assert_eq(len(tokens), max_len)
            entry["token"] = tokens
            entry["input_mask"] = input_mask
            entry["segment_ids"] = segment_ids

    def tensorize(self):
        """Convert each entry's ``token``, ``input_mask`` and ``segment_ids`` lists to tensors."""
        for entry in self._entries:
            for key in ("token", "input_mask", "segment_ids"):
                entry[key] = torch.from_numpy(np.array(entry[key]))

    def __getitem__(self, index):
        """Return one caption with its image's region features.

        Returns:
            ``(features, spatials, image_mask, caption, input_mask, segment_ids)``.
            ``features`` and ``spatials`` are float tensors built from the
            reader output. ``image_mask`` is a long tensor with 1 for each of
            the ``num_boxes`` regions, followed by 0s up to ``_MAX_REGIONS``.
            There is no truncation if ``num_boxes`` exceeds ``_MAX_REGIONS``.
        """
        entry = self._entries[index]
        features, num_boxes, boxes, _ = self._image_features_reader[entry["image_id"]]

        num_boxes = int(num_boxes)
        image_mask = [1] * num_boxes + [0] * (self._MAX_REGIONS - num_boxes)

        features = torch.tensor(features).float()
        image_mask = torch.tensor(image_mask).long()
        spatials = torch.tensor(boxes).float()

        return (
            features,
            spatials,
            image_mask,
            entry["token"],
            entry["input_mask"],
            entry["segment_ids"],
        )

    def __len__(self):
        return len(self._entries)


def _load_annotationsVal(annotations_jsonpath):
    """Load evaluation caption annotations from a JSON-lines file.

    Each line is read as an object with an ``'id'`` (the image id) and a
    ``'sentences'`` list of captions.

    Args:
        annotations_jsonpath: path to the JSON-lines annotation file.

    Returns:
        A tuple ``(image_entries, caption_entries)``:
          - ``image_entries``: list of unique image ids, in order of first
            appearance.
          - ``caption_entries``: flat list of ``{"caption": ..., "image_id": ...}``
            dicts, one per caption, in file order.
    """
    # A dict is used as an insertion-ordered set of image ids.
    image_ids = {}
    caption_entries = []

    with jsonlines.open(annotations_jsonpath) as reader:
        for annotation in reader:
            image_id = annotation['id']
            image_ids[image_id] = None
            caption_entries.extend(
                {"caption": sentence, 'image_id': image_id}
                for sentence in annotation['sentences']
            )

    return list(image_ids), caption_entries


class COCORetreivalDatasetVal(Dataset):
    """Evaluation dataset for COCO caption/image retrieval over a fixed image pool.

    Region features for all images are loaded into memory once, into buffers
    sized for ``_NUM_IMAGES`` images. Every caption is served twice:
    ``index // 2`` picks the caption, and ``index % 2`` picks which half of the
    pool (the first or the last ``_POOL_SIZE`` images) it is paired with. The
    returned target marks the positions in that half whose image id matches
    the caption's.
    """

    # Size of the in-memory image pool and of each half served per item.
    _NUM_IMAGES = 1000
    _POOL_SIZE = _NUM_IMAGES // 2
    # Number of regions per image in the feature/mask buffers.
    _MAX_REGIONS = 37

    def __init__(
        self,
        annotations_jsonpath: str,
        image_features_reader: ImageFeaturesH5Reader,
        tokenizer: BertTokenizer,
        padding_index: int = 0,
        max_caption_length: int = 20,
    ):
        # Assumes every image id in the annotations is present in
        # `image_features_reader`.
        self._image_entries, self._caption_entries = _load_annotationsVal(annotations_jsonpath)
        self._image_features_reader = image_features_reader
        self._tokenizer = tokenizer

        self._padding_index = padding_index
        self._max_caption_length = max_caption_length

        # Unlike the training set, captions are not cached on disk.
        self.tokenize()
        self.tensorize()

        num_images = len(self._image_entries)
        if num_images > self._NUM_IMAGES:
            raise ValueError(
                "expected at most %d images, found %d in %s"
                % (self._NUM_IMAGES, num_images, annotations_jsonpath)
            )

        # Allocate directly in the final tensor dtypes (float32 / int64).
        # torch.from_numpy can then wrap the buffers without copying, and no
        # float64 buffer of twice the size is needed. Values are identical to
        # allocating float64 and converting afterwards.
        features_all = np.zeros((self._NUM_IMAGES, self._MAX_REGIONS, 2048), dtype=np.float32)
        spatials_all = np.zeros((self._NUM_IMAGES, self._MAX_REGIONS, 5), dtype=np.float32)
        image_mask_all = np.zeros((self._NUM_IMAGES, self._MAX_REGIONS), dtype=np.int64)

        for i, image_id in enumerate(self._image_entries):
            features, num_boxes, boxes, _ = self._image_features_reader[image_id]
            num_boxes = int(num_boxes)

            features_all[i] = features
            image_mask_all[i] = [1] * num_boxes + [0] * (self._MAX_REGIONS - num_boxes)
            spatials_all[i] = boxes

            sys.stdout.write('%d/%d\r' % (i, num_images))
            sys.stdout.flush()

        # NOTE(review): with fewer than _NUM_IMAGES images, the trailing rows
        # stay all-zero (mask 0) and are never marked as targets.
        self.features_all = torch.from_numpy(features_all)
        self.image_mask_all = torch.from_numpy(image_mask_all)
        self.spatials_all = torch.from_numpy(spatials_all)

    def tokenize(self):
        """Tokenize every caption into fixed-length id lists stored on its entry.

        Adds ``token``, ``input_mask`` and ``segment_ids`` to each caption
        entry, each a list of length ``max_caption_length``. Each caption is
        processed as follows:
          - It is wrapped as ``[CLS] ... [SEP]`` and mapped to vocab ids
            (unknown words become ``[UNK]``).
          - It is truncated to ``max_caption_length``, which can drop the
            trailing ``[SEP]``.
          - It is padded at the end. ``token`` is padded with ``padding_index``;
            ``input_mask`` and ``segment_ids`` are padded with 0.
        """
        vocab = self._tokenizer.vocab
        unk_id = vocab["[UNK]"]
        max_len = self._max_caption_length

        for entry in self._caption_entries:
            sentence_tokens = ["[CLS]"] + self._tokenizer.tokenize(entry["caption"]) + ["[SEP]"]
            tokens = [vocab.get(w, unk_id) for w in sentence_tokens][:max_len]
            input_mask = [1] * len(tokens)
            segment_ids = [0] * len(tokens)

            # Only the token ids use padding_index. The mask and segment ids
            # must pad with 0, otherwise padding would be marked as attended
            # whenever padding_index != 0.
            num_pad = max_len - len(tokens)
            tokens += [self._padding_index] * num_pad
            input_mask += [0] * num_pad
            segment_ids += [0] * num_pad

            assert_eq(len(tokens), max_len)
            entry["token"] = tokens
            entry["input_mask"] = input_mask
            entry["segment_ids"] = segment_ids

    def tensorize(self):
        """Convert each caption entry's ``token``, ``input_mask`` and ``segment_ids`` lists to tensors."""
        for entry in self._caption_entries:
            for key in ("token", "input_mask", "segment_ids"):
                entry[key] = torch.from_numpy(np.array(entry[key]))

    def __getitem__(self, index):
        """Return caption ``index // 2`` paired with half ``index % 2`` of the image pool.

        Returns:
            A tuple ``(features_all, spatials_all, image_mask_all, caption,
            input_mask, segment_ids, target_all, caption_idx, image_idx)``.
            ``target_all`` is a length-``_POOL_SIZE`` float tensor with 1 at
            each position in the selected half whose image id equals the
            caption's image id.
        """
        # Integer division keeps negative indices consistent with index % 2,
        # so -1 maps to the last caption paired with the second half.
        caption_idx = index // 2
        image_idx = index % 2

        if image_idx == 0:
            pool = slice(None, self._POOL_SIZE)
        else:
            pool = slice(self._POOL_SIZE, None)

        image_entries = self._image_entries[pool]
        features_all = self.features_all[pool]
        spatials_all = self.spatials_all[pool]
        image_mask_all = self.image_mask_all[pool]

        entry = self._caption_entries[caption_idx]
        caption = entry["token"]
        input_mask = entry["input_mask"]
        segment_ids = entry["segment_ids"]

        target_all = torch.zeros(self._POOL_SIZE)
        for i, image_id in enumerate(image_entries):
            if image_id == entry["image_id"]:
                target_all[i] = 1

        return features_all, spatials_all, image_mask_all, caption, input_mask, segment_ids, target_all, caption_idx, image_idx

    def __len__(self):
        return len(self._caption_entries) * 2
