import os

import librosa
import numpy as np
import torch
from torchvision.datasets import DatasetFolder


def spect_loader(path, window_stride=0.01, window='hamming', max_len=64):
    """Load an audio file and turn it into a quantized 3-channel spectrogram.

    The waveform is read with ``librosa.load(path, sr=None)``, transformed
    with an STFT, reduced to ``log(1 + magnitude)``, cropped or zero-padded
    along the time axis to exactly ``max_len`` frames, clipped to ``[0, 3]``,
    rescaled to ``[0, 1]``, snapped to multiples of ``1/255`` and finally
    replicated three times along a new leading axis.

    Args:
        path: Audio file location handed straight to ``librosa.load``.
        window_stride: Factor used (together with the sample rate and
            ``max_len``) to derive the STFT hop length.
        window: Window specification forwarded to ``librosa.stft``.
        max_len: Target number of time frames; it also fixes the FFT size
            as ``(max_len - 1) * 2``.

    Returns:
        A float ``torch.Tensor`` of shape ``(3, F, max_len)`` where ``F`` is
        the number of rows ``librosa.stft`` returns (per the librosa docs
        ``1 + n_fft // 2``, which would equal ``max_len`` -- TODO confirm).

    Raises:
        AssertionError: If the log-magnitude spectrogram has a minimum that
            is not ``>= 0`` (for example when it contains NaN).
    """
    samples, sr = librosa.load(path, sr=None)

    # FFT size and analysis window length are tied to the requested width.
    n_fft = (max_len - 1) * 2
    # NOTE(review): 101 presumably is the frame count of a ~1 s clip at the
    # default 10 ms stride, so the hop is rescaled to give about ``max_len``
    # frames -- confirm with the original author.
    hop_length = int(sr * window_stride * 1. / max_len * 101)

    stft = librosa.stft(
        samples,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=n_fft,
        window=window,
    )
    magnitude = librosa.magphase(stft)[0]
    log_mag = np.log1p(magnitude)

    # Force the time axis (axis 1) to exactly ``max_len`` columns.
    n_frames = log_mag.shape[1]
    if n_frames > max_len:
        log_mag = log_mag[:, :max_len]
    elif n_frames < max_len:
        log_mag = np.pad(log_mag, ((0, 0), (0, max_len - n_frames)), mode='constant')

    # Add a leading channel axis and move to a float32 torch tensor.
    tensor = torch.from_numpy(log_mag[np.newaxis, ...]).float()
    assert tensor.min() >= 0

    # Clip at 3, rescale to [0, 1], then quantize to 8 bits (256 levels).
    unit_range = torch.clamp(tensor, max=3) / 3.
    levels = (unit_range * 255).int()
    quantized = levels.float() / 255.

    # Replicate to three channels to make it easier to fit an image CNN.
    return quantized.repeat(3, 1, 1)


class GoogleCommands(DatasetFolder):
    """``DatasetFolder`` over ``.wav`` files whose samples are spectrograms.

    Files are searched under ``os.path.join(root, split)`` and each one is
    decoded by ``spect_loader`` with ``max_len`` pinned to ``resolution``.
    File discovery and labelling are left to ``DatasetFolder``.
    """

    def __init__(self, root, split, resolution, transform=None, target_transform=None):
        """Set up the dataset for one split.

        Args:
            root: Base directory containing one sub-directory per split.
            split: Name of the sub-directory of ``root`` to read from.
            resolution: Value passed to ``spect_loader`` as ``max_len``.
            transform: Forwarded unchanged to ``DatasetFolder``.
            target_transform: Forwarded unchanged to ``DatasetFolder``.
        """
        import functools

        split_dir = os.path.join(root, split)
        # Bind the spectrogram width once so the loader only needs a path.
        wav_to_spect = functools.partial(spect_loader, max_len=resolution)

        super().__init__(
            root=split_dir,
            loader=wav_to_spect,
            extensions=('.wav', '.WAV'),
            transform=transform,
            target_transform=target_transform,
        )
