# Helper methods from Google-experimentation
import numpy as np
import tensorflow as tf
from python_speech_features import fbank
import matplotlib.pyplot as plt
import scipy.io.wavfile as r
from divideandconquer.common.utils import convertToDetectionLabels
from divideandconquer.common.utils import flatten_features


def loadAndPrepareData(DATA_DIR, GROUP_SIZE, meanVar=True, verbose=True,
                       flatmode='RNN', detection=True):
    '''
    Loads the saved train/validation arrays from DATA_DIR and regroups them
    with flatten_features.

    DATA_DIR is used as a plain string prefix, so it has to end with a path
    separator. The files read are x_train.npy, y_train.npy, x_val.npy and
    y_val.npy.

    meanVar: standardize x_train and x_val using the mean and standard
        deviation of x_train, computed per entry of its last axis.
    detection: pass both label arrays through
        convertToDetectionLabels(..., label=0) before grouping.
    GROUP_SIZE, flatmode: forwarded to flatten_features as group and mode.
    verbose: print the array shapes and the class counts of the grouped
        training labels.

    Returns the list [x_train_new, y_train_new, x_val_new, y_val_new],
    followed by mean and std when meanVar is set.
    '''
    x_train, y_train, x_val, y_val = [
        np.load(DATA_DIR + fileName)
        for fileName in ('x_train.npy', 'y_train.npy',
                         'x_val.npy', 'y_val.npy')
    ]

    if meanVar:
        # Collapse every leading axis so the statistics are per feature.
        featureRows = np.reshape(x_train, [-1, x_train.shape[-1]])
        mean = np.mean(featureRows, axis=0)
        std = np.std(featureRows, axis=0)
        # Near-constant features would blow up the division; leave them
        # unscaled instead.
        std[std < 0.000001] = 1
        x_train = (x_train - mean) / std
        x_val = (x_val - mean) / std

    if detection:
        y_train = convertToDetectionLabels(y_train, label=0)
        y_val = convertToDetectionLabels(y_val, label=0)

    x_train_new, y_train_new = flatten_features(
        x_train, y_train, group=GROUP_SIZE, mode=flatmode)
    x_val_new, y_val_new = flatten_features(
        x_val, y_val, group=GROUP_SIZE, mode=flatmode)

    if verbose:
        for tag, arr in (('x_train ', x_train), ('y_train ', y_train),
                         ('x_new', x_train_new), ('y_new', y_train_new)):
            print(tag, arr.shape)
        # In 'Stacked' mode the labels carry an extra middle axis; only its
        # first entry is used for the statistics.
        labelRows = (y_train_new[:, 0, :] if flatmode == 'Stacked'
                     else np.array(y_train_new))
        bins = np.bincount(np.argmax(labelRows, axis=1))
        print('Bin count\n', bins)
        print("Trivial acc: ", np.max(bins)/sum(bins), " for class: ",
              np.argmax(bins))

    ret = [x_train_new, y_train_new, x_val_new, y_val_new]
    return (ret + [mean, std]) if meanVar else ret


def getBrickLength(groupSize, windowWidth=400, windowStride=160):
    '''
    Number of audio samples spanned by groupSize consecutive windows.

    The first window contributes windowWidth samples and every further
    window adds one windowStride. windowWidth and windowStride are given in
    number of samples.
    '''
    additionalWindows = groupSize - 1
    return windowWidth + additionalWindows * windowStride


def divideIntoBricks(audioList, groupSize, brickStride=None, windowWidth=400,
                     windowStride=160, noZeroPad=False):
    '''
    Cuts every audio clip into fixed-length bricks of raw samples.

    The number of samples in a brick is
    getBrickLength(groupSize, windowWidth, windowStride).

    audioList: 2-D array-like, [number of clips, samples per clip].
    groupSize, windowWidth, windowStride: forwarded to getBrickLength;
        widths and strides are in number of samples.
    brickStride: number of samples between the starts of consecutive
        bricks. If not given, brickStride = brickLength (disjoint bricks).
    noZeroPad: when False (default) a brick that falls short of brickLength
        is zero padded on the right; when True such a brick is dropped.

    Bricks are taken until the first one that reaches past the end of the
    clip. A brick that would contain no samples at all is never emitted, so
    a clip whose length lands exactly on a brick boundary does not get a
    trailing all-zero brick.

    Returns an array [number of clips, bricks per clip, brickLength].
    Raises ValueError if brickStride is not positive.
    '''
    audioList = np.array(audioList)
    assert audioList.ndim == 2
    brickLength = getBrickLength(groupSize, windowWidth, windowStride)
    if brickStride is None:
        brickStride = brickLength
    if brickStride <= 0:
        # A non-positive stride would never advance past the clip.
        raise ValueError('brickStride must be positive, got %r'
                         % (brickStride,))

    brickedList = []
    for audio in audioList:
        brickList = []
        start = 0
        end = 0
        # `end` still holds the previous brick's end here: keep going while
        # the previous brick fit entirely inside the clip.
        while end <= len(audio):
            end = start + brickLength
            brick = audio[start:end]
            if len(brick) == brickLength:
                brickList.append(brick)
            elif len(brick) > 0 and not noZeroPad:
                # Partial brick: right-pad with zeros. Empty slices are
                # skipped so no brick made only of padding is produced.
                brickBed = np.zeros(brickLength)
                brickBed[:len(brick)] = brick
                brickList.append(brickBed)
            start += brickStride
        brickedList.append(np.array(brickList))
    return np.array(brickedList)


def extractFeatures(stackedWav, expectedNumSteps, numFilt=32,
                    samplerate=16000, winlen=0.025, winstep=0.010):
    '''
    Computes log filterbank features for every wave in stackedWav.

    stackedWav: 2-D array or list, [number of waves, len(wave)].
    expectedNumSteps: number of frames every wave must produce; checked
        with an assertion.
    numFilt: number of filters, passed to fbank as nfilt.
    samplerate: passed to fbank unchanged.
    winlen, winstep: analysis window length and step, in seconds.

    Returns [number of waves, numSteps, numFilt].

    In the MIL setting, this is designed to work on one stacked wave at a
    time, but this assumption is not used anywhere in the program. Hence,
    this method will work with any 2-D list of waves.

    All waves are assumed to be of fixed length.
    '''
    # Accept plain (nested) lists as well as arrays.
    stackedWav = np.asarray(stackedWav)
    assert stackedWav.ndim == 2, 'Should be [number of waves, len(wav)]'
    extractedList = []
    # Keeps the logarithm finite when a filterbank output is exactly zero.
    eps = 1e-10
    for sample in stackedWav:
        # Only the first value returned by fbank is used.
        temp, _ = fbank(sample, samplerate=samplerate, winlen=winlen,
                        winstep=winstep, nfilt=numFilt,
                        winfunc=np.hamming)
        temp = np.log(temp + eps)
        assert temp.ndim == 2, 'Should be [numSteps, numFilt]'
        assert temp.shape[0] == expectedNumSteps, (
            'Expected %r steps, got %r' % (expectedNumSteps, temp.shape[0]))
        extractedList.append(temp)
    return np.array(extractedList)


def exp04_freshFeaturize(audioFileList, mean, std, maxAudioLen, numOutput, groupSize,
                         instanceWidth, instanceStride):
    '''
    Reads wav files and converts them into normalized, grouped features.

    audioFileList: wav files to read with scipy.io.wavfile.read.
    mean, std: normalization statistics applied to the extracted features.
    maxAudioLen: every clip is zero padded on the right to this many
        samples. A clip with more samples than this makes the padding
        assignment raise ValueError.
    numOutput: width of the all-zero placeholder label array.
    groupSize: forwarded to flatten_features as group.
    instanceWidth, instanceStride: forwarded to divideIntoInstances, in
        number of samples.

    Returns (x_audio, y_audio, oldShape). x_audio and y_audio come from
    flatten_features in 'Stacked' mode, with y_audio reduced to its first
    entry along axis 1. oldShape is the shape of the instance array before
    the clip and instance axes were merged.
    '''
    audioList = []
    for file in audioFileList:
        _, samples = r.read(file)
        padded = np.zeros(maxAudioLen)
        padded[:len(samples)] = samples[:]
        audioList.append(padded)

    audioList = np.array(audioList)
    # numSamples must follow maxAudioLen; otherwise divideIntoInstances
    # only accepts its default clip length.
    instancedList = divideIntoInstances(audioList, instanceWidth,
                                        instanceStride,
                                        numSamples=maxAudioLen)
    oldShape = instancedList.shape
    # Merge the clip and instance axes: one row of samples per instance.
    instancedList = np.reshape(instancedList, [-1, instancedList.shape[2]])
    # NOTE(review): 49 steps is hard-coded and presumably matches one
    # particular instanceWidth with the default windows of
    # extractFeatures -- confirm before using other widths.
    featurizedInstancedList = extractFeatures(instancedList, 49)
    # Normalize
    featurizedInstancedList = (featurizedInstancedList - mean)/std
    x_audio = featurizedInstancedList
    # Labels are unknown here; zeros act as placeholders.
    y_audio = np.zeros([len(x_audio), numOutput])
    # Now do what ever was done in main script
    x_audio, y_audio = flatten_features(x_audio, y_audio, group=groupSize,
                                        mode='Stacked')
    y_audio = y_audio[:, 0, :]
    return x_audio, y_audio, oldShape

def plotWavExp3(wav, brickLength, predictions, brickStride=None, ax=None,
                nth=1):
    '''
    Plots wav, brick separators and the prediction made for each brick.

    wav: one-dimensional audio signal; it is mean/variance normalized
        before plotting (a constant signal is only centered).
    brickLength: number of samples per brick; the first separator is drawn
        at this sample index.
    predictions: one prediction per brick, rendered with '%d'.
    brickStride: number of samples between consecutive separators. If not
        given, brickStride = brickLength (disjoint bricks).
    ax: object to draw on, offering plot, get_ylim, vlines and text. When
        None, matplotlib.pyplot and its current axes are used.
    nth: only every nth brick gets a separator and a label.
    '''
    centered = wav - np.mean(wav)
    spread = np.std(wav)
    # Dividing by a zero spread would turn the whole plot into NaN.
    wav = centered / spread if spread > 0 else centered

    ax_ = plt if ax is None else ax
    ax_.plot(wav)
    if ax is not None:
        ymin, ymax = ax.get_ylim()
    else:
        ymin, ymax = plt.gca().get_ylim()

    if brickStride is None:
        brickStride = brickLength
    end = brickLength
    for predictionIndex, prediction in enumerate(predictions):
        if predictionIndex % nth == 0:
            ax_.vlines(end, ymin, ymax)
            # Label sits just left of the separator, at the bottom edge.
            ax_.text(end - 5, ymin, '%d' % prediction)
        end += brickStride

def getLabelForKeywords(keywordList, labelDict=None):
    '''
    Maps every keyword in keywordList to its integer label.

    labelDict: mapping from keyword to label. When None, a built-in map is
        used in which twelve keywords get the labels 1 to 12 and every
        other listed word (and '_background_noise_') gets label 0.

    Returns the labels as a list, in the order of keywordList. A keyword
    that is not in the mapping trips an assertion.
    '''
    if labelDict is None:
        # Labelled keywords, in label order starting from 1.
        positives = ('go', 'no', 'on', 'up', 'bed', 'cat', 'dog', 'off',
                     'one', 'six', 'two', 'yes')
        # Everything else collapses into label 0.
        negatives = ('wow', 'bird', 'down', 'five', 'four', 'left', 'nine',
                     'stop', 'tree', 'zero', 'eight', 'happy', 'house',
                     'right', 'seven', 'three', 'marvin', 'sheila',
                     '_background_noise_')
        labelDict = {word: idx for idx, word in enumerate(positives, 1)}
        labelDict.update(dict.fromkeys(negatives, 0))

    def _labelOf(word):
        assert word in labelDict, 'missing: %s' % word
        return labelDict[word]

    return [_labelOf(word) for word in keywordList]

def divideIntoInstances(audioList, instanceWidth, instanceStride,
                        numSamples=16000):
    '''
    Breaks every fixed-length audio sample into overlapping sub-instances.

    audioList: 2-D array [number of audio samples, numSamples].
    instanceWidth, instanceStride: width of a sub-instance and the distance
        between the starts of consecutive sub-instances, in samples.
    numSamples: required length of every audio sample (asserted).

    Returns an array [number of audio samples, number of sub-instances,
    instanceWidth].
    '''
    def _sliceWindows(signal, width, stride):
        '''
        signal is a one dimensional audio sample.
        Returns [number of windows, width]. Windows are taken every stride
        samples until one reaches the end of the signal; a window that
        comes up short is zero padded at its front.
        '''
        assert signal.ndim == 1
        total = len(signal)
        windows = []
        offset = 0
        reachedEnd = False
        while not reachedEnd:
            window = signal[offset:offset + width]
            shortfall = width - len(window)
            if shortfall > 0:
                # Zeros go in front, the available samples at the back.
                padded = np.zeros(width)
                padded[shortfall:] = window
                window = padded
            windows.append(window)
            reachedEnd = offset + width >= total
            offset += stride
        return np.array(windows)

    assert audioList.ndim == 2
    assert audioList.shape[1] == numSamples
    return np.array([_sliceWindows(wave, instanceWidth, instanceStride)
                     for wave in audioList])


