import torch
import torch.utils.data
import torch.nn.functional as F

from librosa.core import load
from librosa.util import normalize


from pathlib import Path
from tqdm import tqdm
import numpy as np
import random
import h5py


def files_to_list(filename):
    """
    Read a UTF-8 text file and return its lines as a list of strings.

    Each entry has its trailing whitespace (including the newline) stripped.
    Blank lines are kept as empty strings.
    """
    with open(filename, encoding="utf-8") as handle:
        return [line.rstrip() for line in handle]


class AudioDataset(torch.utils.data.Dataset):
    """
    Map-style dataset yielding fixed-length waveform segments from audio files.

    ``training_files`` is a text file listing one audio path per line. Each
    path is resolved relative to the directory that contains the list. An item
    is a random ``segment_length``-sample crop of the loaded waveform, or the
    whole waveform zero-padded at the end when it is shorter. The result is
    returned with a leading dimension of size 1 (shape ``(1, segment_length)``
    assuming a 1-D / mono waveform).
    """

    def __init__(self, training_files, segment_length, sampling_rate, augment=True):
        self.sampling_rate = sampling_rate
        self.segment_length = segment_length
        names = files_to_list(training_files)
        list_dir = Path(training_files).parent
        self.audio_files = [list_dir / name for name in names]
        # Fixed-seed shuffle of the file order. NOTE: this reseeds Python's
        # global ``random`` module, which __getitem__ also draws from.
        random.seed(1234)
        random.shuffle(self.audio_files)
        self.augment = augment

    def __getitem__(self, index):
        """Load file ``index`` and return a crop or zero-padded copy of it."""
        waveform, _ = self.load_wav_to_torch(self.audio_files[index])
        shortfall = self.segment_length - waveform.size(0)
        if shortfall > 0:
            # Too short: append zeros up to segment_length.
            waveform = F.pad(waveform, (0, shortfall), "constant").data
        else:
            # Long enough: take a window starting at a random offset.
            offset = random.randint(0, -shortfall)
            waveform = waveform[offset : offset + self.segment_length]
        return waveform.unsqueeze(0)

    def __len__(self):
        return len(self.audio_files)

    def load_wav_to_torch(self, full_path):
        """
        Load ``full_path`` at ``self.sampling_rate`` as a float32 tensor.

        The signal is peak-normalized to 0.95. When ``self.augment`` is set, it
        is also scaled by a random gain drawn from ``np.random.uniform(0.3, 1.0)``.
        Returns ``(tensor, sampling_rate)``, where the rate is whatever ``load``
        reports.
        """
        samples, rate = load(full_path, sr=self.sampling_rate)
        samples = 0.95 * normalize(samples)
        if self.augment:
            samples = samples * np.random.uniform(low=0.3, high=1.0)
        return torch.from_numpy(samples).float(), rate


class HDF5Dataset(torch.utils.data.Dataset):
    """
    Map-style dataset serving waveforms stored in an HDF5 ``raw_audio`` dataset.

    Each item is read from disk and divided by 32768 (presumably the stored
    data is int16 PCM -- confirm against the writer), then peak-normalized to
    0.95. Unless ``inference`` is set, it is also gain-augmented (if
    ``augment``) and randomly cropped or zero-padded to ``segment_length``
    samples.
    """

    def __init__(self, hdf5_file, segment_length, augment=True, inference=False):
        """
        Args:
            hdf5_file: path to an HDF5 file containing a ``raw_audio`` dataset.
            segment_length: samples per returned item when not in inference mode.
            augment: apply a random gain in [0.3, 1.0) when not in inference mode.
            inference: if true, skip augmentation and cropping and return the
                full stored waveform.
        """
        super().__init__()

        # Only the sample count is needed up front; close the handle right away
        # so no open file is kept alive by this object.
        with h5py.File(hdf5_file, "r") as hf:
            self._n_samples = len(hf["raw_audio"])
        self.hdf5_file = hdf5_file
        self.segment_length = segment_length
        self.augment = augment
        self.inference = inference

    def __len__(self):
        return self._n_samples

    def __getitem__(self, index):
        """
        Return item ``index`` as a float32 tensor with a leading dim of size 1.

        Assuming each stored item is 1-D, the shape is ``(1, segment_length)``
        during training and ``(1, stored_length)`` in inference mode.
        """
        # The file is re-opened per item, but held only for the read itself.
        with h5py.File(self.hdf5_file, "r") as hf:
            raw = hf["raw_audio"][index]

        audio = 0.95 * normalize(raw.astype(np.float32) / 32768.0)

        if self.augment and not self.inference:
            audio = audio * np.random.uniform(low=0.3, high=1.0)

        # ``.float()`` pins the dtype. Under NumPy >= 2 (NEP 50), multiplying a
        # float32 array by the float64 scalar from np.random.uniform promotes
        # the result to float64, so augmented items would otherwise be double.
        audio = torch.from_numpy(audio).float()

        if not self.inference:
            if audio.size(0) >= self.segment_length:
                max_audio_start = audio.size(0) - self.segment_length
                audio_start = random.randint(0, max_audio_start)
                audio = audio[audio_start : audio_start + self.segment_length]
            else:
                # Too short: append zeros up to segment_length.
                audio = F.pad(
                    audio, (0, self.segment_length - audio.size(0)), "constant"
                )

        return audio.unsqueeze(0)


class HDF5DataLoader(object):
    """
    Batch iterator that reads random fixed-length segments from an HDF5 file.

    Indices of the ``raw_audio`` dataset are shuffled with a fixed seed. The
    last 100 shuffled indices form the validation split and the rest form the
    training split. ``create_iterator`` yields float32 tensors of shape
    ``(B, 1, segment_length)`` on ``device``.
    """

    def __init__(
        self, hdf5_file, batch_size, segment_length, split, augment=True, device="cuda"
    ):
        """
        Args:
            hdf5_file: path to an HDF5 file with a ``raw_audio`` dataset.
            batch_size: maximum number of waveforms per yielded batch.
            segment_length: samples per waveform in a batch; shorter waveforms
                are zero-padded at the end.
            split: ``"train"`` selects all but the last 100 shuffled indices;
                any other value selects the last 100.
            augment: scale each waveform by a random gain in [0.3, 1.0).
            device: target device for yielded batches (default ``"cuda"``).
        """
        super().__init__()
        # Kept open for the loader's lifetime; call close() to release it.
        self.h5 = h5py.File(hdf5_file, "r")
        self.n_samples = len(self.h5["raw_audio"])
        self.segment_length = segment_length
        self.batch_size = batch_size
        self.augment = augment
        self.device = device

        self.indices = np.arange(self.n_samples)
        # NOTE: this reseeds NumPy's *global* RNG. Every instance uses the same
        # seed, which is what makes the train and validation splits agree.
        np.random.seed(111)
        np.random.shuffle(self.indices)

        if split == "train":
            self.indices = self.indices[:-100]
        else:
            self.indices = self.indices[-100:]

    def __len__(self):
        # Ceiling division: create_iterator also yields the final partial batch.
        return (len(self.indices) + self.batch_size - 1) // self.batch_size

    def close(self):
        """Close the underlying HDF5 file handle."""
        self.h5.close()

    def _extract_segment(self, waveform, start):
        """Normalize ``waveform``; return ``segment_length`` samples from ``start``, zero-padded."""
        segment = 0.95 * normalize(waveform / 32768.0)[start : start + self.segment_length]
        shortfall = self.segment_length - segment.shape[0]
        if shortfall > 0:
            # Keep every row the same length so np.stack does not fail.
            segment = np.pad(segment, (0, shortfall))
        return segment

    def create_iterator(self):
        """
        Yield one pass over the split, ``batch_size`` items at a time.

        The final batch may hold fewer than ``batch_size`` items.
        """
        for i in range(0, len(self.indices), self.batch_size):
            # h5py fancy indexing requires indices in increasing order.
            idx = sorted(self.indices[i : i + self.batch_size])
            audio = self.h5["raw_audio"][idx]

            # Random start per waveform; always 0 if it is shorter than a segment.
            audio_starts = [
                random.randint(0, max(0, len(x) - self.segment_length)) for x in audio
            ]

            batch = np.stack(
                [
                    self._extract_segment(x, start)
                    for start, x in zip(audio_starts, audio)
                ],
                axis=0,
            ).astype("float32")

            # Augment the audio by shifting volume (per-waveform gain).
            if self.augment:
                amplitudes = np.random.uniform(low=0.3, high=1.0, size=(len(idx),))
                batch *= amplitudes[:, None]

            yield torch.from_numpy(batch).to(self.device).unsqueeze(1)
