from utils.operations import scharr_convolution
import numpy as np
import torch
from utils.wavelets import idwt2

class MultiScaleGenerator:
    """Coarse-to-fine field generator built on a wavelet-style reconstruction.

    A low-resolution Gaussian field is drawn first (see ``initialize``). Then, for
    each generator in ``generator_list`` (in list order), detail coefficients are
    generated conditionally on the current field and merged with it by
    ``reconstruct``.
    """

    def __init__(self,
                 list_of_generators_by_scale,
                 reconstruction_function=idwt2,
                 gamma=1.,
                 d=2,
                 ) -> None:
        """
        Inputs:
        - list_of_generators_by_scale: the generators producing the details
          conditionally on phi_j, one per scale, applied in list order
                g : arr(n,1,L,...,L) -> the 2^d - 1 detail components at that scale
                (in whatever format reconstruction_function accepts)
        - reconstruction_function: rebuilds a field from low freqs + details;
          it is called with the single tuple argument (x, y), see ``reconstruct``
                f : (arr(n,1,L,...,L), details) -> arr(n,1,2L,...,2L)
        - gamma: renormalization factor multiplied onto every reconstruction (default 1)
        - d: number of spatial dimensions of the generated fields (default 2)
        """
        self.generator_list = list_of_generators_by_scale
        self.reconstruction_function = reconstruction_function
        self.renormalization = gamma
        self.d = d

    def reconstruct(self, x, y):
        """
        input:
            x = (n, 1, L, ..., L) array (d dimensions of length L)
            y = the 2^d - 1 details, e.g. an (n, 2^d - 1, L, ..., L) array
        output:
            the field reconstructed from x and y with wavelets, multiplied by gamma.
            dimension (n, 1, 2L, ..., 2L) (d dimensions of length 2L)
        """
        return self.reconstruction_function((x, y)) * self.renormalization

    def initialize(self, n_batch, j0=1):
        """Draw the coarsest field: i.i.d. standard Gaussian of shape (n_batch, 1, 2**j0, ..., 2**j0).

        For the moment, only Gaussian generation is supported at the lowest scale.
        """
        dims = [n_batch, 1] + [2**j0 for _ in range(self.d)]
        return np.random.randn(*dims)

    def generate_batch(self, n_batch, j0=2):
        """Generate ``n_batch`` fields by refining Gaussian noise through every scale.

        Inputs:
        - n_batch: number of fields to generate
        - j0: the initial Gaussian field has side 2**j0 along each of the d axes
          (default 2)
        Output:
            the final field with all size-1 axes squeezed out; note that this also
            drops the batch axis when n_batch == 1.
        """
        phi = self.initialize(n_batch, j0)
        for generator in self.generator_list:
            details = generator(phi)
            phi = self.reconstruct(phi, details)
        return phi.squeeze()

    def generate_lower_scale(self, j, phi):
        """Apply the j-th generator to ``phi`` and reconstruct the next, finer field."""
        psi = self.generator_list[j](phi)
        return self.reconstruct(phi, psi)



def f(phi):
    """Sample 3 Gaussian detail channels whose local amplitude is |phi|.

    Input:
        phi: array of shape (n, 1, *spatial), e.g. (n, 1, L, L); the spatial
             axes need not be square.
    Output:
        array of shape (n, 3, *spatial): standard Gaussian noise multiplied
        element-wise by |phi| (broadcast over the channel axis).
    """
    n = phi.shape[0]
    # Use the actual spatial shape instead of assuming a square (L, L) grid,
    # which previously broke broadcasting for non-square inputs.
    spatial_shape = phi.shape[2:]
    return np.random.randn(n, 3, *spatial_shape) * np.abs(phi)

def gaussian_filtered_variance_2d(x):
    """Return a 3-tuple of noise fields modulated by the Scharr-filtered input.

    The filter is applied once. Each of the three outputs is independent standard
    Gaussian noise of shape ``x.shape`` multiplied element-wise by that filtered
    field (assumes ``scharr_convolution`` returns something broadcastable to
    ``x.shape`` -- TODO confirm).
    """
    shape = x.shape
    filtered = scharr_convolution(x)
    return tuple(np.random.randn(*shape) * filtered for _ in range(3))

def alternate_filtered(x):
    """Return a 3-tuple of random spectral filterings of ``x``.

    ``x`` is transformed with an orthonormal 2D FFT (last two axes). Its spectrum
    is multiplied by an independent uniform [0, 1) random mask of shape
    ``x.shape``, and the result is transformed back with the orthonormal inverse
    FFT. The random mask is generally not Hermitian-symmetric, so the inverse
    transform is complex; only its real part is kept.
    """
    spectrum = np.fft.fft2(x, norm="ortho")
    masks = [np.random.rand(*x.shape) for _ in range(3)]
    return tuple(np.real(np.fft.ifft2(spectrum * mask, norm="ortho")) for mask in masks)

def Bernoulli_generator(dims, d=2):
    """
    Returns the Bernoulli model with sparsity parameter d.

    Each entry of an array of shape ``dims`` is kept independently with
    probability ``d / dims[-1]``. Kept entries hold a standard Gaussian value and
    the others are zero. When ``d <= dims[-1]``, this gives about ``d`` nonzero
    entries per slice along the last axis on average.
    """
    keep_probability = d / dims[-1]
    # The support is drawn before the amplitudes (RNG order matters for reproducibility).
    support = np.random.rand(*dims) < keep_probability
    amplitudes = np.random.randn(*dims)
    return support * amplitudes


# The following functions were written by T and M
def Ornstein_generator(dims, eta, xi):
    """
    Generate a batch of 2D Gaussian random fields with a prescribed power spectrum.

    Gaussian white noise is filtered in Fourier space by the square root of the
    PSD returned by ``compute_psd_analytically``, and the real part of the inverse
    FFT is returned.

    Args:
        dims (tuple): exactly three entries (n, _, l): n is the number of
            realizations, the middle entry is ignored, and l is the side length
            of each square field.
        eta (float): power spectral index
        xi (float): characteristic length
    Returns: ndarray (n, l, l)
    """
    n, _, l = dims
    white_noise = np.random.normal(0, 1., (n, l, l))
    noise_spectrum = np.fft.fft2(white_noise)
    # (l, l) amplitude filter; broadcasting applies it to every realization.
    amplitude = np.sqrt(compute_psd_analytically(l, eta, xi))
    return np.real(np.fft.ifft2(amplitude * noise_spectrum))


def compute_psd_analytically(L, eta, xi):
    """
    Compute the power spectral density of a process on an L x L grid.

    Each FFT bin (numpy ``fftfreq`` layout) has radial frequency |k|. Bins with
    |k| < 1/xi get the constant value (1/xi)^(-eta), and all other bins get
    |k|^(-eta). The amplitude |.|^(-eta/2) is computed first and then squared.

    Args:
        L: size of the process
        eta: power spectral index
        xi: characteristic length
    Returns: ndarray(L,L) power spectrum density of the process
    """
    cutoff = 1. / xi
    freq_sq = np.fft.fftfreq(L) ** 2
    radial = np.sqrt(freq_sq[:, np.newaxis] + freq_sq[np.newaxis, :])
    # Fill with the plateau value, then overwrite bins at/above the cutoff with the
    # power law. Masking keeps np.power off the zero-frequency bin when xi > 0.
    amplitude = np.full_like(radial, np.power(cutoff, -eta / 2.))
    above = ~(radial < cutoff)
    amplitude[above] = np.power(radial[above], -eta / 2.)
    return amplitude ** 2


def Imbalanced_gaussian_mixture_generator(dims, p=0.3, m_1=3., m_2=-3.):
    """
    Generates a bimodal Gaussian mixture as a torch tensor of shape ``dims``.

    Each entry independently follows N(m_2, 1) with probability p, and N(m_1, 1)
    otherwise.
    """
    offset = m_2 - m_1
    # Gaussian samples are drawn before the component switches (RNG order matters).
    samples = torch.randn(*dims) + m_1
    switch = torch.rand(*dims) < p
    return samples + offset * switch
