#
# Data Processing for different data sets.
# Currently only deals with dense representations.
import numpy as np
import abc
import tensorflow as tf
import h5py
import pandas as pd
import os
import gc
import sys


class DataProcessorBase(metaclass=abc.ABCMeta):
    """Abstract base class for the data set loaders in this module.

    Subclasses implement ``getData()``, which returns the data already
    split into partitions.

    The metaclass is given with the Python 3 ``metaclass=`` keyword: the
    Python 2 ``__metaclass__`` class attribute is ignored by Python 3, which
    meant ``@abc.abstractmethod`` was never enforced. With the keyword, this
    class (and any subclass that does not override ``getData``) cannot be
    instantiated.
    """

    @abc.abstractmethod
    def getData(self, **kwargs):
        '''
        Return the data split into partitions.

        The original contract was ``train_X, train_y, test_X, test_Y`` with
        X as ``n x num_features`` and Y as ``n x -1``. The concrete loaders
        in this module return a 6-tuple
        ``(x_train, y_train, x_val, y_val, files_train, files_val)``.
        '''
        return "Should not execute"


class STCIDataMIL(DataProcessorBase):
    '''
    WARNING: This is a time series data set and hence X is expected to be a
    3D tensor.

    Loads ``<combPrefix>_X.npy``, ``<combPrefix>_Y.npy`` and
    ``<combPrefix>_files.npy`` and cuts them, in stored order, into
    train / validation / test partitions: the first ``trainSplit`` fraction
    is train and the remainder is divided evenly between validation and
    test. No reshuffling is done here; the stored data is expected to be
    shuffled already.
    '''

    def __init__(self):
        pass

    def __setRandomSeeds(self, seed):
        # Setting up the random seeds
        np.random.seed(seed)
        tf.set_random_seed(seed)

    def _loadSplits(self, combPrefix, trainSplit):
        '''
        Load X, Y and the file list and cut each into contiguous
        (train, val, test) slices.

        Returns ((x_train, x_val, x_test), (y_train, y_val, y_test),
        (files_train, files_val, files_test)).
        '''
        X = np.load(combPrefix + '_X.npy')
        Y = np.load(combPrefix + '_Y.npy')
        fileList = np.load(combPrefix + '_files.npy')
        # X and Y are sliced with the same indices, so their rows must align.
        assert len(X) == len(Y)
        trainIndex = int(trainSplit * len(X))
        # Validation takes the first half of the non-train remainder.
        valIndex = int((trainSplit + (1-trainSplit)/2) * len(X))

        def cut(arr):
            return arr[:trainIndex], arr[trainIndex:valIndex], arr[valIndex:]

        return cut(X), cut(Y), cut(fileList)

    def getData(self, combPrefix, trainSplit=0.8, seed=None):
        '''
        Return the train and validation partitions.

        Note that the data is already shuffled, there is no need
        for reshuffling.

        Returns (x_train, y_train, x_val, y_val, files_train, files_val).
        '''
        if seed is not None:
            self.__setRandomSeeds(seed)
        xs, ys, fs = self._loadSplits(combPrefix, trainSplit)
        x_train, x_val, _ = xs
        y_train, y_val, _ = ys
        files_train, files_val, _ = fs
        return x_train, y_train, x_val, y_val, files_train, files_val

    def getTest(self, combPrefix, trainSplit=0.8, seed=None):
        '''
        Return the test partition (the tail of the stored data).

        Pass the same ``trainSplit`` that was used with ``getData`` to get a
        partition disjoint from train and validation.

        Returns (x_test, y_test, files_test).
        '''
        if seed is not None:
            self.__setRandomSeeds(seed)
        xs, ys, fs = self._loadSplits(combPrefix, trainSplit)
        return xs[2], ys[2], fs[2]

    def getCombined(self, combPrefix, seed=None, trainSplit=0.8):
        '''
        Merge train and validation into one shuffled training set.

        ``trainSplit`` defaults to the value ``getData`` previously used
        implicitly (0.8) and is forwarded to it.

        Returns (x_combined, y_combined, x_val, y_val, files_combined,
        files_val); the validation partition is also returned on its own.
        '''
        x_train, y_train, x_val, y_val, files_train, files_val = self.getData(
            combPrefix, trainSplit=trainSplit, seed=seed)
        x_combined = np.concatenate((x_train, x_val), axis=0)
        y_combined = np.concatenate((y_train, y_val), axis=0)
        files_combined = None
        if files_val is not None and files_train is not None:
            files_combined = np.concatenate((files_train, files_val), axis=0)
        x_combined, y_combined, files_combined = shuffle(
            x_combined, y_combined, files_combined=files_combined, seed=seed)
        return x_combined, y_combined, x_val, y_val, files_combined, files_val

class GoogleDataMIL(DataProcessorBase):
    '''
    Loader for a data set stored as pre-split ``.npy`` files inside the
    directory ``combPrefix``: ``_<split>_X.npy``, ``_<split>_Y.npy`` and
    ``_<split>_files.npy`` for each split in {train, val, test}.
    '''

    def __init__(self):
        pass

    def __setRandomSeeds(self, seed):
        # Seed numpy's global RNG and TensorFlow's graph-level seed.
        np.random.seed(seed)
        tf.set_random_seed(seed)

    @staticmethod
    def _loadSplit(combPrefix, split):
        '''Return (X, Y, files) for one split, loaded in that order.'''
        return tuple(np.load(combPrefix + '/_' + split + '_' + kind + '.npy')
                     for kind in ('X', 'Y', 'files'))

    @staticmethod
    def _printLabelNote():
        # Remind the caller where the class-id -> word mapping lives.
        print("Note that the number->word maping for class labeles is "
              'provided in the feature extractor', file=sys.stderr)

    def getData(self, combPrefix, seed=None):
        '''
        Load the train and validation splits; shuffling and the split itself
        were done when the files were written.

        Returns (x_train, y_train, x_val, y_val, files_train, files_val).
        '''
        if seed is not None:
            self.__setRandomSeeds(seed)
        x_train, y_train, files_train = self._loadSplit(combPrefix, 'train')
        x_val, y_val, files_val = self._loadSplit(combPrefix, 'val')
        self._printLabelNote()
        return x_train, y_train, x_val, y_val, files_train, files_val

    def getTest(self, combPrefix, seed=None):
        '''
        Load the test split: returns (x_test, y_test, files_test).

        Original author's note: there was a mix-up and test and train were
        used for model tuning; final accuracy was reported on validation.
        '''
        if seed is not None:
            self.__setRandomSeeds(seed)
        x_test, y_test, files_test = self._loadSplit(combPrefix, 'test')
        self._printLabelNote()
        return x_test, y_test, files_test

    def getCombined(self, combPrefix, seed=None):
        '''
        Merge train and validation into one shuffled training set.

        Returns (x_combined, y_combined, x_val, y_val, files_combined,
        files_val); the validation split is also returned on its own.
        '''
        parts = self.getData(combPrefix, seed=seed)
        x_train, y_train, x_val, y_val, files_train, files_val = parts
        merged_x = np.concatenate((x_train, x_val), axis=0)
        merged_y = np.concatenate((y_train, y_val), axis=0)
        if files_train is None or files_val is None:
            merged_files = None
        else:
            merged_files = np.concatenate((files_train, files_val), axis=0)
        merged_x, merged_y, merged_files = shuffle(
            merged_x, merged_y, files_combined=merged_files, seed=seed)
        return merged_x, merged_y, x_val, y_val, merged_files, files_val


class HARDataMIL(DataProcessorBase):
    '''
    WARNING: This is a time series data set and hence X is expected to be a
    3D tensor.

    Reads pre-split ``x_*.npy`` / ``y_*.npy`` arrays from the directory
    ``combPrefix``. The arrays are expected to be split, shuffled and
    normalized already; no preprocessing happens here. No file lists are
    stored, so every file-list slot in the returned tuples is ``None``.
    '''

    def __init__(self):
        pass

    def __setRandomSeeds(self, seed):
        # Seed numpy's global RNG and TensorFlow's graph-level seed.
        np.random.seed(seed)
        tf.set_random_seed(seed)

    def getData(self, combPrefix, trainSplit=0.8, seed=None):
        '''
        Load the stored train and validation partitions as-is.

        ``trainSplit`` is accepted for interface compatibility and unused.

        Returns (x_train, y_train, x_val, y_val, None, None).
        '''
        # combPrefix is the data set directory (e.g. the "UCI HAR Dataset"
        # folder) that holds the .npy files.
        if seed is not None:
            self.__setRandomSeeds(seed)

        def load(name):
            return np.load(combPrefix + '/' + name)

        x_train, y_train = load('x_train.npy'), load('y_train.npy')
        assert x_train.shape[0] == y_train.shape[0]
        x_val, y_val = load('x_val.npy'), load('y_val.npy')
        assert x_val.shape[0] == y_val.shape[0]
        return x_train, y_train, x_val, y_val, None, None

    def getTest(self, combPrefix, seed=None):
        '''Load the stored test partition: returns (x_test, y_test, None).'''
        if seed is not None:
            self.__setRandomSeeds(seed)
        x_test = np.load(combPrefix + '/x_test.npy')
        y_test = np.load(combPrefix + '/y_test.npy')
        return x_test, y_test, None

    def getCombined(self, combPrefix, seed=None):
        '''
        Merge train and validation into one shuffled training set.

        Returns (x_combined, y_combined, x_val, y_val, files_combined,
        files_val); the validation partition is also returned on its own.
        '''
        x_train, y_train, x_val, y_val, files_train, files_val = \
            self.getData(combPrefix, seed=seed)
        merged_x = np.concatenate((x_train, x_val), axis=0)
        merged_y = np.concatenate((y_train, y_val), axis=0)
        if files_train is None or files_val is None:
            merged_files = None
        else:
            merged_files = np.concatenate((files_train, files_val), axis=0)
        merged_x, merged_y, merged_files = shuffle(
            merged_x, merged_y, files_combined=merged_files, seed=seed)
        return merged_x, merged_y, x_val, y_val, merged_files, files_val


class SmartstickDataMIL(DataProcessorBase):
    '''
    WARNING: This is a time series data set and hence X is expected to be a
    3D tensor.

    Reads pre-split ``x_*.npy`` / ``y_*.npy`` arrays from the directory
    ``combPrefix`` and can optionally drop samples labelled as noise. No
    file lists are stored, so file-list slots are ``None``.
    '''

    def __init__(self):
        pass

    def __setRandomSeeds(self, seed):
        # Setting up the random seeds
        np.random.seed(seed)
        tf.set_random_seed(seed)

    @staticmethod
    def _dropNoise(x, y):
        '''
        Remove noise samples and trim the kept labels.

        A sample is noise when ``np.argmax(y[i]) == 0`` (argmax over that
        sample's flattened label). The kept labels lose their last entry
        along the final axis.

        A boolean mask is used instead of an index list. An empty index list
        built with ``np.array([])`` has float dtype, and numpy refuses to
        index with it. With a mask, an all-noise or empty split yields empty
        arrays instead of raising ``IndexError``.
        '''
        keep = np.array([np.argmax(label) != 0 for label in y], dtype=bool)
        # NOTE(review): noise is detected as class 0, but the slice below
        # drops the *last* class column. One of the two looks inconsistent
        # with the other; confirm the label layout before relying on
        # excludeNoise.
        return x[keep], y[keep, ..., :-1]

    def getData(self, combPrefix, trainSplit=0.8, seed=None, excludeNoise=False):
        '''
        Load the stored train and validation partitions.

        The data is already shuffled on disk, so there is no reshuffling.
        ``trainSplit`` is accepted for interface compatibility and unused.
        With ``excludeNoise=True``, noise samples are removed from both
        partitions (see ``_dropNoise``).

        Returns (x_train, y_train, x_val, y_val, None, None).
        '''
        if seed is not None:
            self.__setRandomSeeds(seed)
        x_train = np.load(combPrefix + '/' + 'x_train.npy')
        y_train = np.load(combPrefix + '/' + 'y_train.npy')
        assert x_train.shape[0] == y_train.shape[0]

        x_val = np.load(combPrefix + '/' + 'x_val.npy')
        y_val = np.load(combPrefix + '/' + 'y_val.npy')
        assert x_val.shape[0] == y_val.shape[0]

        if excludeNoise:
            x_train, y_train = self._dropNoise(x_train, y_train)
            x_val, y_val = self._dropNoise(x_val, y_val)

        return x_train, y_train, x_val, y_val, None, None

    def getTest(self, combPrefix, seed=None, excludeNoise=False):
        '''
        Load the stored test partition, optionally without noise samples.

        Returns (x_test, y_test, None).
        '''
        if seed is not None:
            self.__setRandomSeeds(seed)
        x_test = np.load(combPrefix + '/' + 'x_test.npy')
        y_test = np.load(combPrefix + '/' + 'y_test.npy')
        if excludeNoise:
            x_test, y_test = self._dropNoise(x_test, y_test)
        return x_test, y_test, None

    def getCombined(self, combPrefix, seed=None, excludeNoise=False):
        '''
        Merge train and validation into one shuffled training set.

        Returns (x_combined, y_combined, x_val, y_val, files_combined,
        files_val); the validation partition is also returned on its own.
        '''
        ret_ = self.getData(combPrefix, seed=seed, excludeNoise=excludeNoise)
        x_train, y_train, x_val, y_val, files_train, files_val = ret_
        x_combined = np.concatenate((x_train, x_val), axis=0)
        y_combined = np.concatenate((y_train, y_val), axis=0)
        files_combined = None
        if files_val is not None and files_train is not None:
            files_combined = np.concatenate((files_train, files_val), axis=0)
        x_combined, y_combined, files_combined = shuffle(
            x_combined, y_combined, files_combined=files_combined, seed=seed)
        return x_combined, y_combined, x_val, y_val, files_combined, files_val


class IMDBDataMIL(DataProcessorBase):
    '''
    WARNING: This is a time series data set and hence X is expected to be a
    3D tensor.

    Reads pre-split ``x_*.npy`` / ``y_*.npy`` arrays (and, unless
    ``skipFiles`` is set, ``files_*.npy``) from the directory ``combPrefix``.
    The arrays are expected to be split, shuffled and normalized already;
    no preprocessing happens here.
    '''

    def __init__(self):
        pass

    def __setRandomSeeds(self, seed):
        # Seed numpy's global RNG and TensorFlow's graph-level seed.
        np.random.seed(seed)
        tf.set_random_seed(seed)

    def getData(self, combPrefix, trainSplit=0.8, seed=None, skipFiles=True):
        '''
        Load the stored train and validation partitions as-is.

        ``trainSplit`` is accepted for interface compatibility and unused.
        File lists are loaded only when ``skipFiles`` is False; otherwise
        they are ``None``.

        Returns (x_train, y_train, x_val, y_val, files_train, files_val).
        '''
        if seed is not None:
            self.__setRandomSeeds(seed)

        def load(name):
            return np.load(combPrefix + '/' + name)

        x_train = load('x_train.npy')
        y_train = load('y_train.npy')
        assert x_train.shape[0] == y_train.shape[0]
        x_val = load('x_val.npy')
        y_val = load('y_val.npy')
        if skipFiles:
            files_train = files_val = None
        else:
            files_train = load('files_train.npy')
            files_val = load('files_val.npy')
        assert x_val.shape[0] == y_val.shape[0]
        return x_train, y_train, x_val, y_val, files_train, files_val

    def getTest(self, combPrefix, seed=None, skipFiles=True):
        '''Load the test partition: returns (x_test, y_test, files_test).'''
        if seed is not None:
            self.__setRandomSeeds(seed)
        x_test = np.load(combPrefix + '/x_test.npy')
        y_test = np.load(combPrefix + '/y_test.npy')
        files_test = (None if skipFiles
                      else np.load(combPrefix + '/files_test.npy'))
        return x_test, y_test, files_test

    def getCombined(self, combPrefix, seed=None, skipFiles=True):
        '''
        Merge train and validation into one shuffled training set.

        Returns (x_combined, y_combined, x_val, y_val, files_combined,
        files_val); the validation partition is also returned on its own.
        '''
        x_train, y_train, x_val, y_val, files_train, files_val = \
            self.getData(combPrefix, seed=seed, skipFiles=skipFiles)
        merged_x = np.concatenate((x_train, x_val), axis=0)
        merged_y = np.concatenate((y_train, y_val), axis=0)
        if files_train is None or files_val is None:
            merged_files = None
        else:
            merged_files = np.concatenate((files_train, files_val), axis=0)
        merged_x, merged_y, merged_files = shuffle(
            merged_x, merged_y, files_combined=merged_files, seed=seed)
        return merged_x, merged_y, x_val, y_val, merged_files, files_val


class MNISTDataMIL(DataProcessorBase):
    '''
    WARNING (original author): X is treated as a time series, so it is
    expected to be a 3D tensor.

    Reads ``x_train``/``y_train``/``x_test``/``y_test`` ``.npy`` arrays from
    the directory ``combPrefix``. No validation split or file lists are
    stored for this data set.
    '''

    def __init__(self):
        pass

    def __setRandomSeeds(self, seed):
        # Seed numpy's global RNG and TensorFlow's graph-level seed.
        np.random.seed(seed)
        tf.set_random_seed(seed)

    def getData(self, combPrefix, seed=None):
        '''
        Load the train partition and the *test* partition.

        There is no stored validation split, so the test arrays fill the
        validation slots of the returned tuple:
        (x_train, y_train, x_test, y_test, None, None).
        '''
        if seed is not None:
            self.__setRandomSeeds(seed)
        x_train = np.load(combPrefix + '/x_train.npy')
        y_train = np.load(combPrefix + '/y_train.npy')
        assert x_train.shape[0] == y_train.shape[0]
        x_test = np.load(combPrefix + '/x_test.npy')
        y_test = np.load(combPrefix + '/y_test.npy')
        return x_train, y_train, x_test, y_test, None, None

    def getTest(self, combPrefix, seed=None):
        '''Load the test partition: returns (x_test, y_test, None).'''
        if seed is not None:
            self.__setRandomSeeds(seed)
        x_test = np.load(combPrefix + '/x_test.npy')
        y_test = np.load(combPrefix + '/y_test.npy')
        return x_test, y_test, None

    def getCombined(self, combPrefix, seed=None):
        '''
        Merge the two partitions returned by ``getData`` into one shuffled
        training set.

        Returns (x_combined, y_combined, x_val, y_val, files_combined,
        files_val).
        '''
        # NOTE(review): getData returns the test split in the validation
        # slots, so the combined set here is train + test. Evaluating a model
        # trained on it with getTest() would leak test data; confirm this is
        # intended.
        x_train, y_train, x_val, y_val, files_train, files_val = \
            self.getData(combPrefix, seed=seed)
        merged_x = np.concatenate((x_train, x_val), axis=0)
        merged_y = np.concatenate((y_train, y_val), axis=0)
        if files_train is None or files_val is None:
            merged_files = None
        else:
            merged_files = np.concatenate((files_train, files_val), axis=0)
        merged_x, merged_y, merged_files = shuffle(
            merged_x, merged_y, files_combined=merged_files, seed=seed)
        return merged_x, merged_y, x_val, y_val, merged_files, files_val


class SPORTSDataMIL(DataProcessorBase):
    '''
    WARNING: This is a time series data set and hence X is expected to be a
    3D tensor.

    Reads pre-split ``x_*.npy`` / ``y_*.npy`` arrays from the directory
    ``combPrefix``. The arrays are expected to be split, shuffled and
    normalized already; no preprocessing happens here. No file lists are
    stored, so every file-list slot in the returned tuples is ``None``.
    '''

    def __init__(self):
        pass

    def __setRandomSeeds(self, seed):
        # Seed numpy's global RNG and TensorFlow's graph-level seed.
        np.random.seed(seed)
        tf.set_random_seed(seed)

    def getData(self, combPrefix, trainSplit=0.8, seed=None):
        '''
        Load the stored train and validation partitions as-is.

        ``trainSplit`` is accepted for interface compatibility and unused.

        Returns (x_train, y_train, x_val, y_val, None, None).
        '''
        # combPrefix is the directory that holds this data set's .npy files.
        if seed is not None:
            self.__setRandomSeeds(seed)
        x_train = np.load(combPrefix + '/x_train.npy')
        y_train = np.load(combPrefix + '/y_train.npy')
        assert x_train.shape[0] == y_train.shape[0]
        x_val = np.load(combPrefix + '/x_val.npy')
        y_val = np.load(combPrefix + '/y_val.npy')
        assert x_val.shape[0] == y_val.shape[0]
        return x_train, y_train, x_val, y_val, None, None

    def getTest(self, combPrefix, seed=None):
        '''Load the stored test partition: returns (x_test, y_test, None).'''
        if seed is not None:
            self.__setRandomSeeds(seed)
        return (np.load(combPrefix + '/x_test.npy'),
                np.load(combPrefix + '/y_test.npy'),
                None)

    def getCombined(self, combPrefix, seed=None):
        '''
        Merge train and validation into one shuffled training set.

        Returns (x_combined, y_combined, x_val, y_val, files_combined,
        files_val); the validation partition is also returned on its own.
        '''
        x_train, y_train, x_val, y_val, files_train, files_val = \
            self.getData(combPrefix, seed=seed)
        merged_x = np.concatenate((x_train, x_val), axis=0)
        merged_y = np.concatenate((y_train, y_val), axis=0)
        if files_train is None or files_val is None:
            merged_files = None
        else:
            merged_files = np.concatenate((files_train, files_val), axis=0)
        merged_x, merged_y, merged_files = shuffle(
            merged_x, merged_y, files_combined=merged_files, seed=seed)
        return merged_x, merged_y, x_val, y_val, merged_files, files_val


def shuffle(x, y, files_combined=None, seed=None):
    '''
    Apply one random permutation to ``x``, ``y`` and, optionally,
    ``files_combined`` so their rows stay aligned.

    If ``seed`` is given, numpy's global RNG is re-seeded with it first.
    This makes the permutation reproducible, and it also resets the global
    state seen by later ``np.random`` calls.

    Returns (x_shuffled, y_shuffled, files_shuffled), where
    ``files_shuffled`` is None when ``files_combined`` is None.
    '''
    assert len(x) == len(y)
    if seed is not None:
        np.random.seed(seed)
    # permutation(n) is arange(n) shuffled in place by the same global RNG.
    order = np.random.permutation(len(x))
    if files_combined is None:
        shuffled_files = None
    else:
        assert len(files_combined) == len(x)
        shuffled_files = files_combined[order]
    return x[order], y[order], shuffled_files
