"""
This file contains imdb datasets things: 
class MovieReviewsDataset and class Gpt2ClassificationCollator
"""

import io
import os
import torch
import time # to do add time t0 and t1, then dt = t1-t0 in train loop
from torch.utils.data import Dataset

'''
pytorch dataset needs 3 components: 
- init() where we read in data and transform text into numbers
- len() where we return the number of examples
- getitem() given an int, returns the example indexed at that position
'''

'''The 3 parts of GPT2 we need to use:
- config (GPT2Config)
- tokenizer (GPT2Tokenizer)
- model (GPT2ForSequenceClassification)
'''

class MovieReviewsDataset(Dataset):
  """PyTorch Dataset of labelled movie reviews read from disk.

  The partition directory must contain one sub-folder per label (``pos`` and
  ``neg``). Every file inside a sub-folder is read as one UTF-8 review and is
  labelled with the name of the folder it was found in.

  Arguments:
    path (:obj:`str`):
        Path to the data partition.
    use_tokenizer:
        Accepted for interface compatibility; it is not used by this class.

  Raises:
    ValueError: If ``path`` is not a directory.
  """

  def __init__(self, path, use_tokenizer):
    if not os.path.isdir(path):
      raise ValueError('Invalid path! Needs to be a directory')

    self.texts = []
    self.labels = []

    # The labels are defined by the folder names, so walk one folder per label.
    for label in ['pos', 'neg']:
      sentiment_path = os.path.join(path, label)

      # os.listdir returns names in arbitrary order; sorting keeps the
      # index -> example mapping reproducible across runs and machines.
      for file_name in sorted(os.listdir(sentiment_path)):
        file_path = os.path.join(sentiment_path, file_name)

        # The context manager closes the handle right away instead of
        # leaking one open file per review until garbage collection.
        with io.open(file_path, mode='r', encoding='utf-8') as review_file:
          self.texts.append(review_file.read())
        self.labels.append(label)

    # Number of examples.
    self.n_examples = len(self.labels)

  def __len__(self):
    """Return the number of examples."""
    return self.n_examples

  def __getitem__(self, item):
    """Return the example stored at the given position.

    Arguments:
      item (:obj:`int`):
          Index position of the example to return.

    Returns:
      :obj:`Dict[str, str]`: Dictionary holding the review under ``'text'``
      and its folder-name label under ``'label'``.
    """
    return {'text': self.texts[item],
            'label': self.labels[item]}


class MovieReviewsDataset_simple(Dataset):
  """PyTorch Dataset of unlabelled texts read from a single directory.

  Every entry directly inside ``path`` is read as one UTF-8 text example.
  No labels are stored.

  Arguments:
    path (:obj:`str`):
        Path to the directory holding the text files.
    use_tokenizer:
        Accepted for interface compatibility; it is not used by this class.

  Raises:
    ValueError: If ``path`` is not a directory.
  """

  def __init__(self, path, use_tokenizer):
    if not os.path.isdir(path):
      raise ValueError('Invalid path! Needs to be a directory')

    self.texts = []

    # os.listdir returns names in arbitrary order; sorting keeps the
    # index -> example mapping reproducible across runs and machines.
    for file_name in sorted(os.listdir(path)):
      file_path = os.path.join(path, file_name)

      # The context manager closes the handle right away instead of
      # leaking one open file per example until garbage collection.
      with io.open(file_path, mode='r', encoding='utf-8') as text_file:
        self.texts.append(text_file.read())

    # Number of examples.
    self.n_examples = len(self.texts)

  def __len__(self):
    """Return the number of examples."""
    return self.n_examples

  def __getitem__(self, item):
    """Return the example stored at the given position.

    Arguments:
      item (:obj:`int`):
          Index position of the example to return.

    Returns:
      :obj:`Dict[str, str]`: Dictionary holding the text under ``'text'``
      (there is no label entry).
    """
    return {'text': self.texts[item]}

# A collator transforms raw dataset examples on the fly (tokenize, pad, truncate, etc.).
# It is applied to each batch of examples drawn from the dataset.
# It uses the model's tokenizer to convert the text data into token IDs,
# so that the input is in the format the model expects.
class Gpt2ClassificationCollator(object):
    """
    Batch collator for GPT2 sequence classification.
    Turns a list of raw examples (text plus label name) into the numeric
    inputs produced by the tokenizer, with the encoded labels attached.

    Arguments:
      use_tokenizer (:obj:`transformers.tokenization_?`):
          Transformer type tokenizer used to process raw text into numbers.
      labels_encoder (:obj:`dict`):
          Mapping from label names to the numbers associated with them.
      max_sequence_len (:obj:`int`, `optional`)
          Maximum sequence length passed to the tokenizer as ``max_length``.
          When omitted, the tokenizer's ``model_max_length`` is used.
    """

    def __init__(self, use_tokenizer, labels_encoder, max_sequence_len=None):
        self.use_tokenizer = use_tokenizer

        # Fall back to the tokenizer's own limit when no length was given.
        if max_sequence_len is None:
            self.max_sequence_len = use_tokenizer.model_max_length
        else:
            self.max_sequence_len = max_sequence_len

        self.labels_encoder = labels_encoder

    def __call__(self, sequences):
        """
        Collate a batch, so an instance can be handed to a PyTorch
        DataLoader wherever a collator function is expected.

        Arguments:
          sequences (:obj:`list`):
              List of examples, each with a ``'text'`` and a ``'label'`` entry.
        Returns:
          The tokenizer output for the batch texts, updated with a
          ``'labels'`` tensor of encoded labels; meant to be used as
          `model(**Returned Dictionary)`.
        """
        batch_texts = []
        for example in sequences:
            batch_texts.append(example['text'])

        label_names = []
        for example in sequences:
            label_names.append(example['label'])

        # Translate label names into their numeric ids.
        label_ids = []
        for name in label_names:
            label_ids.append(self.labels_encoder[name])

        # The tokenizer pads/truncates and returns tensors of token ids.
        tokenizer_options = {
            'return_tensors': "pt",
            'padding': True,
            'truncation': True,
            'max_length': self.max_sequence_len,
        }
        inputs = self.use_tokenizer(text=batch_texts, **tokenizer_options)

        # Attach the encoded labels as a tensor next to the token ids.
        inputs.update({'labels': torch.tensor(label_ids)})
        return inputs


class Gpt2ClassificationCollator_simple(object):
    """
    Batch collator for GPT2 that handles text only.
    Turns a list of raw examples into the numeric inputs produced by the
    tokenizer; no labels are read or attached.

    Arguments:
      use_tokenizer (:obj:`transformers.tokenization_?`):
          Transformer type tokenizer used to process raw text into numbers.
      labels_encoder (:obj:`dict`):
          Stored on the instance, but not used when collating a batch.
      max_sequence_len (:obj:`int`, `optional`)
          Maximum sequence length passed to the tokenizer as ``max_length``.
          When omitted, the tokenizer's ``model_max_length`` is used.
    """

    def __init__(self, use_tokenizer, labels_encoder, max_sequence_len=None):
        self.use_tokenizer = use_tokenizer

        # Fall back to the tokenizer's own limit when no length was given.
        if max_sequence_len is None:
            self.max_sequence_len = use_tokenizer.model_max_length
        else:
            self.max_sequence_len = max_sequence_len

        self.labels_encoder = labels_encoder

    def __call__(self, sequences):
        """
        Collate a batch, so an instance can be handed to a PyTorch
        DataLoader wherever a collator function is expected.

        Arguments:
          sequences (:obj:`list`):
              List of examples, each with a ``'text'`` entry.
        Returns:
          The tokenizer output for the batch texts, unchanged; meant to be
          used as `model(**Returned Dictionary)`.
        """
        batch_texts = []
        for example in sequences:
            batch_texts.append(example['text'])

        # The tokenizer pads/truncates and returns tensors of token ids.
        tokenizer_options = {
            'return_tensors': "pt",
            'padding': True,
            'truncation': True,
            'max_length': self.max_sequence_len,
        }
        return self.use_tokenizer(text=batch_texts, **tokenizer_options)