import torch
import torch.nn as nn
import torch.fft as fft
import numpy as np
import torch.nn.functional as F


def cubic_interp(_x, _y, _dy, x, norm=False, index=None, t=None):
    """Evaluate a cubic Hermite spline with caller-supplied knot slopes.

    On each interval the cubic matches the knot values ``_y`` and slopes
    ``_dy`` at both ends. The result is therefore C1, but its second
    derivative is not guaranteed to be continuous at the knots.

    Args:
        _x: (Nx,) evenly spaced, increasing knot positions.
        _y: (..., Nx) knot values; leading dims are batch dims.
        _dy: (..., Nx) knot slopes.
        x: (Nx1,) evaluation points, expected in [_x[0], _x[-1]]. Points
            outside are extrapolated with the first/last interval's cubic.
        norm: if True, divide ``_x`` and ``x`` by the knot spacing
            ``_x[1] - _x[0]`` and multiply ``_dy`` by it. If False, the Hermite
            formula uses the raw offset ``t`` as if the spacing were 1, so the
            caller must already work in unit-spacing coordinates.
        index: optional precomputed index of each point's right knot (values in
            [1, Nx-1]). When None it is found with ``torch.searchsorted``.
        t: optional precomputed offset ``x - _x[index-1]``; computed when None.

    Returns:
        (..., Nx1) interpolated values.
    """
    if norm:
        spacing = _x[1] - _x[0]
        _x, x, _dy = _x / spacing, x / spacing, _dy * spacing

    if index is None:
        index = torch.searchsorted(_x.detach(), x.contiguous().detach())
        # Keep every point inside a valid interval [index-1, index].
        index = torch.clamp(index, 1, len(_x) - 1)

    left = index - 1
    y_lo, y_hi = _y[..., left], _y[..., index]
    s_lo, s_hi = _dy[..., left], _dy[..., index]

    if t is None:
        t = x - _x[left]

    # Power-basis Hermite cubic p(t) = y_lo + s_lo*t + quad*t^2 + cub*t^3,
    # evaluated with Horner's scheme.
    quad = 3.*(y_hi - y_lo) - 2.*s_lo - s_hi
    cub = 2.*(y_lo - y_hi) + s_lo + s_hi
    return y_lo + t*(s_lo + t*(quad + cub*t))
    

def power_spectrum(xk2, k, sidelength, kedge):
    """Average Fourier-space weights in shells of |k|.

    Args:
        xk2: per-mode weights on an rfftn grid (presumably |x_k|^2). Its
            trailing dims must match ``k``; any extra leading (batch) dims are
            averaged over as well.
        k: |k| at each rfftn grid point; the last axis is the half-length
            rfft axis.
        sidelength: grid side length. Its parity decides whether the last rfft
            column is a Nyquist column (even) or an ordinary one (odd).
        kedge: 1-D increasing bin edges. Output bin p collects
            ``kedge[p] <= k < kedge[p+1]``; modes outside the edges are dropped.

    Returns:
        ``(power_k, power)``: mean |k| and mean weight per bin, each of length
        ``len(kedge) - 1``. Empty bins are 0.
    """
    nbin = len(kedge)
    # right=True: index j means kedge[j-1] <= k < kedge[j]; j == 0 lies below
    # the first edge and j == nbin at/above the last one.
    index = torch.searchsorted(kedge.detach(), k.contiguous().detach(), right=True)
    index1 = torch.ones_like(xk2, dtype=torch.int64) * index
    # Number of weight samples per k grid point (e.g. the batch size).
    fac = index1.numel() / index.numel()

    def _bin(idx, weights=None):
        # Fixed minlength keeps every histogram the same length (nbin + 1) no
        # matter which bins are populated, so the sums below never mismatch and
        # the [1:nbin] slice always yields nbin - 1 bins.
        return torch.bincount(
            idx.flatten(),
            weights=None if weights is None else weights.flatten(),
            minlength=nbin + 1,
        )

    # rfftn keeps only half of the last axis: columns other than k_last = 0
    # (and, for even sidelength, the Nyquist column) each stand for two
    # conjugate modes, so they are counted a second time.
    if sidelength % 2 == 0:
        inner = slice(1, -1)
    else:
        inner = slice(1, None)

    power = _bin(index1, xk2) + _bin(index1[..., inner], xk2[..., inner])
    Nmode = _bin(index) + _bin(index[..., inner])
    power_k = _bin(index, k) + _bin(index[..., inner], k[..., inner])

    power = power[1:nbin]
    Nmode = Nmode[1:nbin]
    power_k = power_k[1:nbin]
    select = Nmode > 0
    power[select] = power[select] / (fac*Nmode[select])
    power_k[select] = power_k[select] / Nmode[select]

    return power_k, power



class TREConv(nn.Module):
    """Fourier-space convolution with a learned radial transfer function.

    A real field on a periodic ``sidelength**D`` grid goes through an
    orthonormal ``rfftn``. It is then multiplied by a transfer function
    ``tf(|k|)`` that depends only on the wavenumber magnitude, and mapped back
    with ``irfftn``. ``tf`` is a cubic Hermite spline over ``nknot`` evenly
    spaced knots in |k|. The knot values and slopes come from ``self.param``,
    or, when ``conditional=True``, from a per-sample ``param`` tensor of width
    ``self.nparam`` supplied by the caller.

    Args:
        sidelength: grid points per axis.
        D: number of spatial dimensions (1, 2 or 3).
        nknot: number of spline knots.
        zeromode: if True, the k = 0 entry of tf is taken from the last column
            of ``param``; otherwise it is fixed to 1.
        positivetf: if True, the spline output and the zero mode go through
            softplus.
        conditional: if True, parameters are passed to forward/inverse instead
            of being held as an ``nn.Parameter``.
    """

    def __init__(self, sidelength, D, nknot, zeromode=False, positivetf=True, conditional=False): # TODO: support log-spaced k bins

        super().__init__()

        assert D in [1, 2, 3]

        dtype = torch.get_default_dtype()
        if D == 1:
            kgrid = torch.as_tensor(np.fft.rfftfreq(sidelength), dtype=dtype)
        else:
            # Accumulate k^2 one axis at a time: full-FFT axes first, the
            # half-length rfft axis last.
            shape = [sidelength] * (D - 1) + [sidelength // 2 + 1]
            kgrid = torch.zeros(*shape)
            for axis in range(D):
                if axis == D - 1:
                    freq = np.fft.rfftfreq(sidelength)
                else:
                    freq = np.fft.fftfreq(sidelength)
                view = [1] * D
                view[axis] = -1
                kgrid += torch.as_tensor(freq, dtype=dtype).reshape(view) ** 2
            kgrid = kgrid ** 0.5

        self.sidelength = sidelength
        self.D = D
        self.register_buffer('k', kgrid)
        self.zeromode = zeromode

        maxk = torch.max(self.k).item()
        self.nknot = nknot
        knots = torch.linspace(1./sidelength, maxk, nknot)

        # Express |k| and the knots in units of the knot spacing, so the knots
        # are unit-spaced and cubic_interp can be called with norm=False.
        spacing = knots[1] - knots[0]
        knots /= spacing
        self.k /= spacing
        self.register_buffer('tf_k', knots)

        # Cache each grid point's spline interval and in-interval offset once,
        # so correct_tf does not redo the search on every call.
        flat_k = self.k.contiguous().reshape(-1)
        interval = torch.clamp(torch.searchsorted(self.tf_k, flat_k), 1, nknot - 1)
        self.register_buffer('index', interval)
        self.register_buffer('t', flat_k - knots[interval - 1])

        self.positivetf = positivetf
        self.conditional = conditional

        # nknot knot values, nknot knot slopes, plus an optional zero-mode entry.
        width = 2*nknot + (1 if zeromode else 0)
        if conditional:
            self.nparam = width
        else:
            self.nparam = 0
            self.param = nn.Parameter(torch.randn(1, width))


    def set_zeromode(self, tf, tf_0):
        """Write ``tf_0`` into the k = 0 entry of ``tf`` in place and return ``tf``."""
        origin = (Ellipsis,) + (0,) * self.D
        tf[origin] = tf_0
        return tf


    def correct_tf(self, param):
        """Build the transfer function, shape (batch, *self.k.shape), from ``param``.

        Columns ``[:nknot]`` of ``param`` are knot values and
        ``[nknot:2*nknot]`` are knot slopes. With ``zeromode``, column
        ``2*nknot`` gives the k = 0 entry; otherwise that entry is 1.
        """
        values = param[:, :self.nknot]
        slopes = param[:, self.nknot:2*self.nknot]
        tf = cubic_interp(_x=self.tf_k, _y=values, _dy=slopes, x=self.k.reshape(-1),
                          norm=False, index=self.index, t=self.t)
        tf = tf.reshape(-1, *self.k.shape)

        if self.positivetf:
            # clone so the in-place zero-mode write below does not touch
            # softplus' saved output.
            tf = F.softplus(tf).clone()

        if not self.zeromode:
            return self.set_zeromode(tf, 1)

        tf_0 = param[:, 2*self.nknot]
        if self.positivetf:
            tf_0 = F.softplus(tf_0)
        return self.set_zeromode(tf, tf_0)


    def transform(self, x, param=None, mode='forward'):
        """Multiply by tf ('forward') or divide by it ('inverse') in Fourier space.

        Returns ``(x_out, logj)``: ``x_out`` is flattened to (batch, -1), and
        ``logj`` holds one entry per sample, the sum of log|tf| over the real
        degrees of freedom of the rfft grid. ``logj`` is computed the same way
        in both modes; it is not negated for 'inverse'.
        """
        assert mode in ['forward', 'inverse']

        grid = [self.sidelength] * self.D
        batch = len(x)
        field_k = fft.rfftn(x.reshape(batch, *grid), s=grid, norm='ortho')

        if not self.conditional:
            tf = self.correct_tf(self.param)
        else:
            assert param is not None
            tf = self.correct_tf(param)

        if mode == 'forward':
            field_k = field_k * tf
        elif mode == 'inverse':
            field_k = field_k / tf

        axes = tuple(range(1, self.D + 1))
        # rfftn keeps only half of the last axis. Columns other than k_last = 0
        # (and, for even sidelength, the Nyquist column) stand for two
        # conjugate modes, so their log|tf| is added a second time.
        if self.sidelength % 2 == 0:
            inner = tf[..., 1:-1]
        else:
            inner = tf[..., 1:]
        logj = torch.sum(torch.log(torch.abs(tf)), dim=axes)
        logj += torch.sum(torch.log(torch.abs(inner)), dim=axes)

        out = fft.irfftn(field_k, s=grid, norm='ortho')

        if not self.conditional:
            # A single shared tf gives a (1,) logj; tile it to one entry per sample.
            logj = logj.repeat(batch)

        return out.reshape(batch, -1), logj


    def forward(self, x, param=None):
        """Apply the forward (multiply-by-tf) transform; see ``transform``."""
        return self.transform(x, param=param, mode='forward')


    def inverse(self, x, param=None):
        """Apply the inverse (divide-by-tf) transform; see ``transform``."""
        return self.transform(x, param=param, mode='inverse')



class TDNEConv(nn.Module):
    """2D Fourier-space convolution with a dihedral (D_N) angular structure.

    On the rfft grid the transfer function is

        tf(k, theta) = sum_i R_i(|k|) * T_angle[i](theta),

    Each radial profile R_i is a cubic Hermite spline over ``nknot`` knots,
    passed through softplus when m_i == 0. T_angle[i] is cos(m_i*N*theta) or
    sin(m_i*N*theta). Comparing ``input_rep`` with ``output_rep`` picks
    between cos and sin and decides whether the m_i are shifted by 1/2. The
    two are presumably D_N irrep labels (TODO: confirm against callers). The
    result is pushed away from zero with sign(tf)*sqrt(tf**2 + eps**2), and
    the k = 0 entry is then overwritten.

    Args:
        sidelength: grid points per axis; inputs are sidelength x sidelength.
        m: angular orders (default [0, 1]).
        N: dihedral order; angular terms use m * N * theta.
        input_rep, output_rep: two-element labels (default [0, 0]).
            Whether element 0 equals 0 selects cos vs sin, and whether
            element 1 equals N/2 selects half-integer orders.
        nknot: spline knots per radial profile.
        zeromode: if True, the k = 0 entry is softplus(param[:, -1]);
            otherwise it is 1.
        eps: minimum |tf| away from the origin.
        logkbin: if True, knots are evenly spaced in log|k| instead of |k|.
        conditional: if True, ``param`` (width ``self.nparam``) is passed to
            forward/inverse instead of being held as an ``nn.Parameter``.
    """

    def __init__(self, sidelength, m=None, N=4, input_rep=None, output_rep=None, nknot=8, zeromode=False, eps=1e-5, logkbin=False, conditional=False):

        super().__init__()

        # None sentinels resolve to the historical defaults. A list default
        # would be stored on self (input_rep/output_rep) and shared by every
        # instance, so mutating one instance would change later defaults.
        if m is None:
            m = [0, 1]
        if input_rep is None:
            input_rep = [0, 0]
        if output_rep is None:
            output_rep = [0, 0]

        self.sidelength = sidelength
        kx = torch.repeat_interleave(torch.fft.fftfreq(sidelength).reshape(-1,1), sidelength//2+1, 1)
        ky = torch.repeat_interleave(torch.fft.rfftfreq(sidelength).reshape(1,-1), sidelength, 0)
        k = (kx**2 + ky**2) ** 0.5
        self.register_buffer('k', k)

        # Polar angle of each grid point. ky comes from rfftfreq and is never
        # negative, so acos(kx / k) already lies in [0, pi] and needs no
        # reflection. At the origin kx / k is 0/0 = nan, so define the angle as 0.
        theta = torch.acos(kx/k)
        theta[0,0] = 0

        # Half-integer angular orders when exactly one rep has element 1 == N/2.
        halfint = (input_rep[1] == N/2) != (output_rep[1] == N/2)
        if halfint:
            m = torch.tensor([m0 + 0.5 for m0 in m])
        else:
            m = torch.tensor(m)

        if (input_rep[0] == 0) == (output_rep[0] == 0):
            T_angle = torch.cos(m.reshape(-1,1,1)*N*theta)
            if halfint:
                T_angle[:,0,0] = 0
        else:
            T_angle = torch.sin(m.reshape(-1,1,1)*N*theta)
            # NOTE(review): for even sidelength only the Nyquist row (index
            # sidelength//2) is zeroed. The odd branch additionally zeros the
            # last rfft column. Confirm whether the even Nyquist column
            # (ky = 0.5), whose entries pair up as conjugates, should be zeroed too.
            if sidelength % 2 == 0:
                T_angle[:,sidelength//2] = 0
            else:
                T_angle[:,sidelength//2] = 0
                T_angle[:,sidelength//2+1] = 0
                T_angle[:,:,sidelength//2] = 0

        self.register_buffer('T_angle', T_angle)
        self.N = N
        self.input_rep = input_rep
        self.output_rep = output_rep
        self.register_buffer('m', m)

        self.zeromode = zeromode
        self.nknot = nknot
        self.conditional = conditional
        # Per angular order: nknot knot values then nknot knot slopes; an
        # optional trailing entry sets the k = 0 mode.
        width = 2*nknot*len(m) + (1 if zeromode else 0)
        if self.conditional:
            self.nparam = width
        else:
            self.nparam = 0
            self.param = nn.Parameter(torch.randn(1, width))

        if logkbin:
            # Work in log|k|. log(0) = -inf at the origin is replaced by its
            # neighbour's value.
            self.k[:] = torch.log(self.k)
            self.k[0,0] = self.k[0,1]
            tf_k = torch.linspace(torch.min(self.k), torch.max(self.k), nknot)
        else:
            tf_k = torch.linspace(1./sidelength, torch.max(self.k), nknot)

        # Rescale so the knots are unit-spaced; cubic_interp is called with
        # norm=False, which treats the knot spacing as 1.
        delta_k = tf_k[1] - tf_k[0]
        tf_k /= delta_k
        self.k /= delta_k
        self.register_buffer('tf_k', tf_k)

        # Cache each grid point's spline interval and offset for faster interpolation.
        index = torch.searchsorted(self.tf_k, self.k.contiguous().reshape(-1))
        index[index==0] = 1
        index[index==nknot] = nknot - 1
        self.register_buffer('index', index)
        self.register_buffer('t', self.k.reshape(-1) - tf_k[index-1])

        self.eps = eps


    def correct_tf(self, param):
        """Build tf of shape (batch, sidelength, sidelength//2+1) from ``param``.

        For angular order i, ``param`` holds nknot knot values followed by
        nknot knot slopes of R_i. With ``zeromode``, its last column sets the
        k = 0 entry.
        """
        tf = torch.zeros(len(param), self.sidelength, self.sidelength//2+1, device=param.device)
        start = 0
        for i, m_i in enumerate(self.m):
            tf_ = cubic_interp(_x=self.tf_k, _y=param[..., start:start+self.nknot], _dy=param[..., start+self.nknot:start+2*self.nknot], x=self.k.reshape(-1), norm=False, index=self.index, t=self.t).reshape(len(param), *self.k.shape)
            start += 2*self.nknot
            if m_i == 0:
                # The m = 0 radial profile is made positive.
                tf_ = F.softplus(tf_).clone()
            tf = tf + tf_ * self.T_angle[i]
        # Keep |tf| >= eps while preserving its sign (a zero sign becomes +1).
        sign = torch.sign(tf)
        sign[sign==0] = 1
        tf = sign * (tf**2 + self.eps**2)**0.5
        if self.zeromode:
            tf[:,0,0] = F.softplus(param[:,-1])
        else:
            tf[:,0,0] = 1

        return tf


    def transform(self, x, param=None, mode='forward'):
        """Multiply by tf ('forward') or divide by it ('inverse') in Fourier space.

        Returns ``(x_out, logj)``: ``x_out`` is flattened to (batch, -1), and
        ``logj`` has shape (batch,). ``logj`` is the sum of log|tf| over the
        real degrees of freedom of the rfft grid. It is computed the same way
        in both modes (not negated for 'inverse').
        """
        assert mode in ['forward', 'inverse']

        x = x.reshape(len(x), self.sidelength, self.sidelength)

        xk = fft.rfftn(x, s=[self.sidelength]*2, norm='ortho')

        if self.conditional:
            assert param is not None
            tf = self.correct_tf(param)
        else:
            tf = self.correct_tf(self.param)
        if mode == 'forward':
            xk = xk * tf
        elif mode == 'inverse':
            xk = xk / tf
        # rfftn keeps only half of the last axis. Columns other than k_y = 0
        # (and, for even sidelength, the Nyquist column) stand for two
        # conjugate modes, so they are counted twice.
        logj = torch.sum(torch.log(torch.abs(tf)), dim=(1,2))
        if self.sidelength % 2 == 0:
            logj += torch.sum(torch.log(torch.abs(tf[:,:,1:-1])), dim=(1,2))
        else:
            logj += torch.sum(torch.log(torch.abs(tf[:,:,1:])), dim=(1,2))

        x = fft.irfftn(xk, s=[self.sidelength]*2, norm='ortho')

        if not self.conditional:
            # One shared tf gives a (1,) logj; tile it to one entry per sample
            # so the return shape matches the conditional case.
            logj = torch.repeat_interleave(logj, len(x), dim=0)
        x = x.reshape(len(x), -1)

        return x, logj


    def forward(self, x, param=None):
        """Apply the forward (multiply-by-tf) transform; see ``transform``."""
        return self.transform(x, param=param, mode='forward')


    def inverse(self, x, param=None):
        """Apply the inverse (divide-by-tf) transform; see ``transform``."""
        return self.transform(x, param=param, mode='inverse')

