from deel.lip.layers import (
    SpectralConv2D, 
    Identity,
    SpectralDense, 
    FrobeniusDense, 
    ScaledAveragePooling2D, 
    ScaledGlobalL2NormPooling2D,
    ScaledL2NormPooling2D, 
    InvertibleDownSampling,
    CircularPadding,
    SymmetricPadding,
    ScaledGlobalAveragePooling2D)

from deel.lip.extra_layers import SpectralDepthwiseConv2D,OrthoDepthwiseConv2D,BatchCentering
from deel.lip.cayley_layers import CayleyConv1x1
from deel.lip.normalizers import DEFAULT_NITER_BJORCK, DEFAULT_NITER_SPECTRAL, DEFAULT_NITER_SPECTRAL_INIT
from deel.lip.activations import MaxMin, GroupSort, GroupSort2, FullSort
from deel.lip.initializers import SpectralInitializer
from deel.utils.lip_model import get_lip_dense,get_lipConv2D
from deel.lip.model import Model as LipModel
import tensorflow as tf
from tensorflow.keras.layers import ReLU, Input, Flatten, MaxPool2D, Add,BatchNormalization,Dense,Lambda
from tensorflow.keras import backend as K
#from tensorflow.python.keras.models import Model

import numpy as np



def get_ortho_conv2D(filters, kernel_size=(3, 3),
                     padding='circular',
                     activation=GroupSort2,
                     use_bias=True,
                     ortho=True,
                     strides=(1, 1),
                     fft=True,
                     batch_centering=False,
                     niter_spectral=DEFAULT_NITER_SPECTRAL,
                     niter_bjorck=DEFAULT_NITER_BJORCK):
    """Return a callable applying a depthwise + 1x1 convolution block.

    The returned callable chains, in order: optional explicit padding, a
    bias-free depthwise convolution, a 1x1 ``SpectralConv2D`` with
    ``filters`` filters, optional ``BatchCentering`` and an optional
    activation.

    Args:
        filters: filter count handed to the 1x1 ``SpectralConv2D``.
        kernel_size: kernel size of the depthwise convolution; it must be
            indexable when ``padding`` is ``'circular'`` or ``'symmetric'``.
        padding: ``'circular'`` / ``'symmetric'`` insert the matching padding
            layer and run the depthwise convolution with ``'valid'``; any
            other value is forwarded to the depthwise convolution as is.
        activation: layer class instantiated without arguments at the end of
            the block, or ``None`` to skip it.
        use_bias: bias flag of the 1x1 convolution only.
        ortho: choose ``OrthoDepthwiseConv2D`` (True) or
            ``SpectralDepthwiseConv2D`` (False) for the depthwise step.
        strides: strides of the depthwise convolution.
        fft: forwarded to ``SpectralDepthwiseConv2D``; unused when ``ortho``.
        batch_centering: insert ``BatchCentering`` before the activation.
        niter_spectral: forwarded to the 1x1 ``SpectralConv2D``.
        niter_bjorck: forwarded to the 1x1 ``SpectralConv2D``.

    Returns:
        A function mapping an input tensor to the block output.
    """
    def apply_block(tensor):
        conv_padding = padding
        # Explicit padding layers replace the convolution's own padding.
        # NOTE(review): both pad amounts come from kernel_size[0]; confirm
        # this is intended for non-square kernels.
        if padding == 'circular':
            half = kernel_size[0] // 2
            tensor = CircularPadding(padding=(half, half))(tensor)
            conv_padding = 'valid'
        elif padding == 'symmetric':
            half = kernel_size[0] // 2
            tensor = SymmetricPadding(padding=(half, half))(tensor)
            conv_padding = 'valid'

        if ortho:
            depthwise = OrthoDepthwiseConv2D(kernel_size,
                                             strides=strides,
                                             use_bias=False,
                                             padding=conv_padding)
        else:
            depthwise = SpectralDepthwiseConv2D(kernel_size,
                                                use_bias=False,
                                                strides=strides,
                                                fft=fft,
                                                padding=conv_padding)
        tensor = depthwise(tensor)

        # Pointwise convolution mixes channels and sets the output width.
        pointwise = SpectralConv2D(filters,
                                   kernel_size=(1, 1),
                                   kernel_initializer="orthogonal",
                                   regul_type='trans_bjork_coeff',
                                   use_bias=use_bias,
                                   padding="valid",
                                   niter_spectral=niter_spectral,
                                   niter_bjorck=niter_bjorck)
        tensor = pointwise(tensor)

        if batch_centering:
            tensor = BatchCentering()(tensor)
        if activation is None:
            return tensor
        return activation()(tensor)

    return apply_block

def depth_pooling(filters,
                  activation=GroupSort2,
                  use_bias=True,
                  batch_centering=False,
                  niter_spectral=DEFAULT_NITER_SPECTRAL,
                  niter_bjorck=DEFAULT_NITER_BJORCK):
    """Return a callable that downsamples then applies a 1x1 convolution.

    The returned callable applies ``InvertibleDownSampling`` with a 2x2 pool
    size, a 1x1 ``SpectralConv2D`` with ``filters`` filters, optional
    ``BatchCentering`` and an optional activation.

    Args:
        filters: filter count handed to the 1x1 ``SpectralConv2D``.
        activation: layer class instantiated without arguments at the end,
            or ``None`` to skip it.
        use_bias: bias flag of the 1x1 convolution.
        batch_centering: insert ``BatchCentering`` before the activation.
        niter_spectral: forwarded to ``SpectralConv2D``.
        niter_bjorck: forwarded to ``SpectralConv2D``.

    Returns:
        A function mapping an input tensor to the block output.
    """
    def apply_block(tensor):
        tensor = InvertibleDownSampling(pool_size=(2, 2))(tensor)
        pointwise = SpectralConv2D(filters,
                                   kernel_size=(1, 1),
                                   kernel_initializer="orthogonal",
                                   regul_type='trans_bjork_coeff',
                                   use_bias=use_bias,
                                   niter_spectral=niter_spectral,
                                   niter_bjorck=niter_bjorck)
        tensor = pointwise(tensor)
        if batch_centering:
            tensor = BatchCentering()(tensor)
        if activation is None:
            return tensor
        return activation()(tensor)

    return apply_block
def depth_init(filters,
               activation=GroupSort2,
               use_bias=True,
               batch_centering=False,
               niter_spectral=DEFAULT_NITER_SPECTRAL,
               niter_bjorck=DEFAULT_NITER_BJORCK):
    """Return a callable applying a 1x1 ``SpectralConv2D`` stem.

    The returned callable applies a 1x1 ``SpectralConv2D`` with ``filters``
    filters, optional ``BatchCentering`` and an optional activation.

    Args:
        filters: filter count handed to the 1x1 ``SpectralConv2D``.
        activation: layer class instantiated without arguments at the end,
            or ``None`` to skip it.
        use_bias: bias flag of the 1x1 convolution.
        batch_centering: insert ``BatchCentering`` before the activation.
        niter_spectral: forwarded to ``SpectralConv2D``.
        niter_bjorck: forwarded to ``SpectralConv2D``.

    Returns:
        A function mapping an input tensor to the block output.
    """
    def apply_block(tensor):
        pointwise = SpectralConv2D(filters,
                                   kernel_size=(1, 1),
                                   kernel_initializer="orthogonal",
                                   regul_type='trans_bjork_coeff',
                                   use_bias=use_bias,
                                   niter_spectral=niter_spectral,
                                   niter_bjorck=niter_bjorck)
        tensor = pointwise(tensor)
        if batch_centering:
            tensor = BatchCentering()(tensor)
        if activation is None:
            return tensor
        return activation()(tensor)

    return apply_block
def ortho_cifar(shape,
                nb_classes=1,
                filter_size=64,
                padding='circular',
                batch_centering=True,
                use_bias=False,
                bias_output=True,
                niter_spectral=DEFAULT_NITER_SPECTRAL,
                niter_bjorck=DEFAULT_NITER_BJORCK,
                verbose=False):
    """Build a three-stage model out of ``get_ortho_conv2D`` blocks.

    Layout: a ``depth_init`` stem, then stages of 4, 3 and 3 blocks using
    ``filter_size``, ``2*filter_size`` and ``4*filter_size`` filters, with an
    ``InvertibleDownSampling`` (2x2) between consecutive stages, followed by
    ``ScaledGlobalAveragePooling2D`` and an output layer.

    Args:
        shape: input shape handed to ``Input``.
        nb_classes: when greater than 1 the output layer is a
            ``FrobeniusDense`` with ``nb_classes`` units; otherwise a single
            unit layer built by ``get_lip_dense``.
        filter_size: filter count of the first stage.
        padding: padding mode forwarded to every ``get_ortho_conv2D`` block.
        batch_centering: forwarded to the stem and to every block.
        use_bias: bias flag forwarded to the stem and to every block.
        bias_output: bias flag of the output layer.
        niter_spectral: forwarded to the stem, the blocks and
            ``get_lip_dense``.
        niter_bjorck: forwarded to the stem, the blocks and
            ``get_lip_dense``.
        verbose: print ``model.summary()`` when true.

    Returns:
        The assembled ``LipModel``.
    """
    # Options shared by the stem and by every convolution block.
    shared = dict(batch_centering=batch_centering,
                  use_bias=use_bias,
                  niter_spectral=niter_spectral,
                  niter_bjorck=niter_bjorck)

    inputs = Input(shape)
    x = depth_init(filter_size, **shared)(inputs)

    # (width multiplier, number of blocks) for each stage.
    stages = ((1, 4), (2, 3), (4, 3))
    for stage_index, (multiplier, n_blocks) in enumerate(stages):
        if stage_index > 0:
            x = InvertibleDownSampling(pool_size=(2, 2))(x)
        for _ in range(n_blocks):
            x = get_ortho_conv2D(multiplier * filter_size,
                                 kernel_size=(3, 3),
                                 padding=padding,
                                 **shared)(x)

    x = ScaledGlobalAveragePooling2D()(x)

    if nb_classes > 1:
        head = FrobeniusDense(nb_classes,
                              activation=None,
                              disjoint_neurons=True,
                              use_bias=bias_output,
                              kernel_initializer="orthogonal")
    else:
        head = get_lip_dense(1,
                             use_bias=bias_output,
                             by_constraint=False,
                             activation=None,
                             kCoefLip=1.,
                             niter_spectral=niter_spectral,
                             niter_bjorck=niter_bjorck)
    x = head(x)

    model = LipModel(inputs=inputs, outputs=x)
    if verbose:
        model.summary()
    return model

def ortho_test(shape,
               nb_classes=1,
               filter_size=64,
               padding='circular',
               batch_centering=True,
               use_bias=False,
               fft=True,
               stride=False,
               bias_output=True,
               niter_spectral=DEFAULT_NITER_SPECTRAL,
               niter_bjorck=DEFAULT_NITER_BJORCK,
               verbose=False):
    """Build the three-stage model with ``ortho=False`` convolution blocks.

    Layout: a ``depth_init`` stem, then stages of 4, 3 and 3
    ``get_ortho_conv2D(..., ortho=False)`` blocks using ``filter_size``,
    ``2*filter_size`` and ``4*filter_size`` filters, followed by
    ``ScaledGlobalAveragePooling2D`` and an output layer. The last block of
    the first two stages reduces resolution: either through ``strides=(2, 2)``
    (``stride`` true) or through a following 2x2 ``InvertibleDownSampling``
    (``stride`` false).

    Args:
        shape: input shape handed to ``Input``.
        nb_classes: when greater than 1 the output layer is a
            ``FrobeniusDense`` with ``nb_classes`` units; otherwise a single
            unit layer built by ``get_lip_dense``.
        filter_size: filter count of the first stage.
        padding: padding mode forwarded to every block.
        batch_centering: forwarded to the stem and to every block.
        use_bias: bias flag forwarded to the stem and to every block.
        fft: forwarded to every block.
        stride: select strided blocks instead of ``InvertibleDownSampling``.
        bias_output: bias flag of the output layer.
        niter_spectral: forwarded to the stem, the blocks and
            ``get_lip_dense``.
        niter_bjorck: forwarded to the stem, the blocks and
            ``get_lip_dense``.
        verbose: print ``model.summary()`` when true.

    Returns:
        The assembled ``LipModel``.
    """
    def conv_block(n_filters, block_strides=(1, 1)):
        # Single place holding the options common to all blocks.
        return get_ortho_conv2D(n_filters,
                                kernel_size=(3, 3),
                                ortho=False,
                                strides=block_strides,
                                fft=fft,
                                padding=padding,
                                batch_centering=batch_centering,
                                use_bias=use_bias,
                                niter_spectral=niter_spectral,
                                niter_bjorck=niter_bjorck)

    inputs = Input(shape)
    x = depth_init(filter_size,
                   batch_centering=batch_centering,
                   use_bias=use_bias,
                   niter_spectral=niter_spectral,
                   niter_bjorck=niter_bjorck)(inputs)

    # First two stages: some plain blocks, then one block that downsamples.
    for n_filters, n_plain in ((filter_size, 3), (2 * filter_size, 2)):
        for _ in range(n_plain):
            x = conv_block(n_filters)(x)
        if stride:
            x = conv_block(n_filters, block_strides=(2, 2))(x)
        else:
            x = conv_block(n_filters)(x)
            x = InvertibleDownSampling(pool_size=(2, 2))(x)

    # Last stage keeps its resolution.
    for _ in range(3):
        x = conv_block(4 * filter_size)(x)

    x = ScaledGlobalAveragePooling2D()(x)

    if nb_classes > 1:
        head = FrobeniusDense(nb_classes,
                              activation=None,
                              disjoint_neurons=True,
                              use_bias=bias_output,
                              kernel_initializer="orthogonal")
    else:
        head = get_lip_dense(1,
                             use_bias=bias_output,
                             by_constraint=False,
                             activation=None,
                             kCoefLip=1.,
                             niter_spectral=niter_spectral,
                             niter_bjorck=niter_bjorck)
    x = head(x)

    model = LipModel(inputs=inputs, outputs=x)
    if verbose:
        model.summary()
    return model

def get_lip_depthconv_2D(filters, kernel_size=(3,3),
                  padding='same', 
                  activation=ReLU, 
                  use_bias=True,
                  strides=(1,1), 
                  conv_first = False,
                  by_constraint=False,
                  cayley = False,
                  double = False,
                  in_graph=True,
                  fft=True,
                  stop_gradient=True,
                  kCoefLip=1.0, regul_type = "spectral_conv",
                  niter_spectral=DEFAULT_NITER_SPECTRAL, 
                  niter_bjorck=DEFAULT_NITER_BJORCK):
    """Return a callable applying a depthwise conv followed by a 1x1 conv.

    The returned callable chains, in order: optional explicit padding, a
    ``SpectralDepthwiseConv2D`` (two concatenated ones scaled by 1/sqrt(2)
    when both ``cayley`` and ``double`` are set), the optional activation, a
    1x1 convolution (``CayleyConv1x1`` when ``cayley`` else
    ``SpectralConv2D``) and the optional activation again.

    Args:
        filters: filter count of the 1x1 convolution.
        kernel_size: kernel size of the depthwise convolution; it must be
            indexable when ``padding`` is ``'circular'`` or ``'symmetric'``.
        padding: ``'circular'`` / ``'symmetric'`` insert the matching padding
            layer and run the depthwise convolution with ``'valid'``; any
            other value (e.g. ``'same'``) is forwarded to the depthwise
            convolution. The 1x1 convolution always uses ``'valid'``.
        activation: layer class instantiated without arguments after each of
            the two convolutions, or ``None`` to skip both.
        use_bias: bias flag of both convolutions.
        strides: strides of the 1x1 convolution.
        conv_first: forwarded to ``SpectralConv2D`` (non-cayley path).
        by_constraint: forwarded to ``SpectralConv2D`` (non-cayley path).
        cayley: use ``CayleyConv1x1`` for the 1x1 convolution.
        double: with ``cayley``, run two depthwise convolutions on the same
            input and concatenate them.
        in_graph: forwarded to ``SpectralDepthwiseConv2D``.
        fft: forwarded to ``SpectralDepthwiseConv2D``.
        stop_gradient: forwarded to ``SpectralDepthwiseConv2D``.
        kCoefLip: forwarded as ``k_coef_lip`` to ``SpectralConv2D``.
        regul_type: forwarded to ``SpectralConv2D``.
        niter_spectral: forwarded to ``SpectralConv2D``.
        niter_bjorck: forwarded to ``SpectralConv2D``.

    Returns:
        A function mapping an input tensor to the block output.
    """
    def f(x):
        pad = padding
        # Explicit padding layers replace the convolution's own padding.
        # NOTE(review): both pad amounts come from kernel_size[0]; confirm
        # this is intended for non-square kernels.
        if padding == 'circular':
            x = CircularPadding(padding=(kernel_size[0]//2,kernel_size[0]//2))(x)
            pad = 'valid'
        elif padding == 'symmetric':
            x = SymmetricPadding(padding=(kernel_size[0]//2,kernel_size[0]//2))(x)
            pad = 'valid'

        def depthwise():
            # The resolved padding mode is forwarded here. It used to be
            # hard-coded to 'valid', which silently ignored padding='same'.
            return SpectralDepthwiseConv2D(kernel_size, use_bias=use_bias,
                                           in_graph=in_graph, fft=fft,
                                           stop_gradient=stop_gradient,
                                           padding=pad)

        if cayley and double:
            # Two parallel depthwise branches double the channel count; the
            # concatenation is rescaled by 1/sqrt(2).
            sqrt2_inv_ = Lambda(lambda y: y / 1.41421356237)
            x1 = depthwise()(x)
            x2 = depthwise()(x)
            x = sqrt2_inv_(tf.keras.layers.Concatenate()([x1, x2]))
        else:
            x = depthwise()(x)
        if activation is not None:
            x = activation()(x)

        if cayley:
            x = CayleyConv1x1(filters, kernel_initializer="orthogonal",
                              strides=strides, use_bias=use_bias,
                              padding='valid')(x)
        else:
            x = SpectralConv2D(filters, (1,1), kernel_initializer="orthogonal",
                               regul_type=regul_type,
                               strides=strides, use_bias=use_bias,
                               conv_first=conv_first,
                               padding='valid', k_coef_lip=kCoefLip,
                               by_constraint=by_constraint,
                               niter_spectral=niter_spectral,
                               niter_bjorck=niter_bjorck)(x)
        if activation is not None:
            x = activation()(x)
        return x
    return f



def deel_lip_depth_graph( nb_classes=1,kernel_size=(3,3),
                       coeffs=1.0,filter_size=16,
                       layers_per_depth=[],dense_layers_size=[],
                       padding='same',
                       regul_type = "spectral_conv",
                       by_constraint=False,
                       in_graph=True,
                       fft=True,
                       stop_gradient=True,
                       cayley = False,
                       activation_conv=None,activation_dense=None,use_bias=True,
                       use_stride=False,poolType="avg",batchNorm=0.0,
                       niter_spectral=DEFAULT_NITER_SPECTRAL, niter_bjorck=DEFAULT_NITER_BJORCK, 
                       splitLastLayer=False, activation_lastlayer = None):
    """Return a callable that builds the depthwise-separable network graph.

    The returned callable applies a 1x1 ``SpectralConv2D`` stem, then for
    each depth level ``layers_per_depth[pos]`` blocks from
    ``get_lip_depthconv_2D``, reducing resolution between levels (strided
    last block when ``use_stride``, otherwise a 2x2 pooling layer chosen by
    ``poolType``). It ends with ``ScaledGlobalAveragePooling2D``, the dense
    layers listed in ``dense_layers_size`` and an output layer.

    Args:
        nb_classes: number of units of the output layer.
        kernel_size: per-depth kernel sizes, iterated in lockstep with
            ``layers_per_depth`` and ``filter_size``.
        coeffs: coefficient forwarded as ``k_coef_lip`` / ``kCoefLip`` to the
            stem, the conv blocks and the hidden dense layers.
        filter_size: per-depth filter counts; it is indexed, so a sequence is
            required here (``deel_lip_depthconv`` expands scalars).
        layers_per_depth: number of conv blocks at each depth level.
        dense_layers_size: sizes of the dense layers after global pooling.
        padding: padding mode forwarded to every conv block.
        regul_type: forwarded to the stem and to every conv block.
        by_constraint: forwarded to the stem, conv blocks and dense layers.
        in_graph: forwarded to every conv block.
        fft: forwarded to every conv block.
        stop_gradient: forwarded to every conv block.
        cayley: forwarded to every conv block.
        activation_conv: activation class forwarded to every conv block.
        activation_dense: activation forwarded to the hidden dense layers.
        use_bias: bias flag of the stem, conv blocks and dense layers.
        use_stride: downsample with ``strides=(2, 2)`` on the last block of
            each non-final depth instead of pooling.
        poolType: one of ``'avg'``, ``'max'``, ``'l2norm'``, ``'inv'``.
        batchNorm: accepted but not used by this function.
        niter_spectral: forwarded to the stem, conv blocks and dense layers.
        niter_bjorck: forwarded to the stem, conv blocks and output layer.
        splitLastLayer: with ``nb_classes > 1``, use a ``FrobeniusDense``
            output layer with ``disjoint_neurons=True``.
        activation_lastlayer: layer class instantiated without arguments and
            applied to the output, or ``None``.

    Returns:
        A function mapping an input tensor to the network output tensor.
    """
    def f(x):
        poll2fct={'avg':ScaledAveragePooling2D,'max':MaxPool2D,'l2norm':ScaledL2NormPooling2D,'inv':InvertibleDownSampling}
        last_depth = len(layers_per_depth) - 1
        # Only the very first conv block is flagged as conv_first.
        conv_first = True
        x =SpectralConv2D(filter_size[0], (1,1),kernel_initializer="orthogonal", regul_type=regul_type,
                          strides=(1,1), use_bias=use_bias,conv_first=conv_first,
                          padding="valid",k_coef_lip=coeffs, by_constraint=by_constraint,
                          niter_spectral=niter_spectral,niter_bjorck=niter_bjorck)(x)
        for pos, (layers,nb,kernel) in enumerate(zip(layers_per_depth,filter_size,kernel_size)):
            for i in range(layers):
                strides=(1,1)
                # Strided variant: downsample on the last block of every
                # depth level except the final one.
                if use_stride and i == layers-1 and pos != last_depth:
                    strides=(2,2)
                x = get_lip_depthconv_2D(nb, kernel_size=kernel,
                                  kCoefLip=coeffs,
                                  padding=padding,
                                  by_constraint=by_constraint,
                                  strides = strides,
                                  cayley = cayley,
                                  # Double the depthwise branch on the first
                                  # block of a level that widens the network.
                                  double = (i==0)  and pos > 0 and filter_size[pos-1]<filter_size[pos],
                                  regul_type = regul_type,
                                  in_graph = in_graph, fft = fft, stop_gradient = stop_gradient,
                                  # niter_spectral was previously not forwarded,
                                  # so conv blocks ignored the caller's value.
                                  niter_spectral=niter_spectral,
                                  niter_bjorck=niter_bjorck,
                                  conv_first =conv_first,
                                  activation=activation_conv,
                                  use_bias=use_bias)(x)
                conv_first = False
            if not use_stride and pos != last_depth:
                x=poll2fct[poolType](pool_size=(2, 2))(x)
        x = ScaledGlobalAveragePooling2D()(x)
        for lay_size in dense_layers_size:
            # NOTE(review): niter_bjorck is hard-coded to 6 here instead of
            # using the niter_bjorck argument; confirm this is intentional.
            x = get_lip_dense(lay_size, use_bias=use_bias, 
                              by_constraint=by_constraint,
                              activation=activation_dense, kCoefLip=coeffs,
                             niter_spectral=niter_spectral, niter_bjorck=6)(x)
        if splitLastLayer and (nb_classes>1):
            # NOTE(review): use_bias is hard-coded to True on this path
            # rather than following the use_bias argument; confirm.
            x = FrobeniusDense(nb_classes, 
                               activation=None, 
                               by_constraint=by_constraint,
                               disjoint_neurons=True,
                               use_bias=True, k_coef_lip=1., 
                               kernel_initializer="orthogonal")(x)
        else:
            x=get_lip_dense(nb_classes,use_bias=use_bias,by_constraint=by_constraint,activation=None,
                            kCoefLip=1.,
                            niter_spectral= niter_spectral, niter_bjorck = niter_bjorck)(x)  
        if activation_lastlayer is not None:
            x=activation_lastlayer()(x)
        return x
    return f


def deel_lip_depthconv(shape, nb_classes=1, kernel_size=3, coeffs=1, filter_size=16,
                       layers_per_depth=[], dense_layers_size=[], regul_type="spectral_conv",
                       padding='same', activation_conv=None,
                       activation_dense=None,
                       use_bias=True, use_stride=False,
                       poolType="avg", lambdaE=0.0,
                       by_constraint=False,
                       cayley=False,
                       in_graph=True,
                       fft=True,
                       stop_gradient=True,
                       niter_spectral=DEFAULT_NITER_SPECTRAL,
                       niter_bjorck=DEFAULT_NITER_BJORCK, splitLastLayer=False,
                       activation_lastlayer=None,
                       verbose=False):
    """Build a ``LipModel`` around the graph from ``deel_lip_depth_graph``.

    The Keras session is cleared first. Scalar ``filter_size`` and
    ``kernel_size`` values are expanded to one entry per depth level before
    the graph is built.

    Args:
        shape: input shape handed to ``Input``.
        nb_classes: forwarded to ``deel_lip_depth_graph``.
        kernel_size: a list with one kernel size per depth, or a single value
            repeated for every depth; each entry ``k`` becomes ``(k, k)``.
        coeffs: forwarded to ``deel_lip_depth_graph``.
        filter_size: a list with one filter count per depth, or a single
            value that is doubled at each successive depth.
        layers_per_depth: number of conv blocks per depth level.
        dense_layers_size: sizes of the dense layers after global pooling.
        lambdaE: accepted but not used by this function.
        verbose: print ``model.summary()`` when true.
        regul_type, padding, activation_conv, activation_dense, use_bias,
        use_stride, poolType, by_constraint, cayley, in_graph, fft,
        stop_gradient, niter_spectral, niter_bjorck, splitLastLayer,
        activation_lastlayer: forwarded unchanged to
            ``deel_lip_depth_graph``.

    Returns:
        The assembled ``LipModel``.
    """
    K.clear_session()

    # Expand scalar settings into one value per depth level.
    if not isinstance(filter_size, list):
        filter_size = [filter_size * (2 ** level)
                       for level, _ in enumerate(layers_per_depth)]
    if not isinstance(kernel_size, list):
        kernel_size = [kernel_size] * len(layers_per_depth)
    square_kernels = [(size, size) for size in kernel_size]

    inputs = Input(shape)
    build_graph = deel_lip_depth_graph(
        nb_classes=nb_classes,
        kernel_size=square_kernels,
        coeffs=coeffs,
        filter_size=filter_size,
        layers_per_depth=layers_per_depth,
        dense_layers_size=dense_layers_size,
        padding=padding,
        regul_type=regul_type,
        by_constraint=by_constraint,
        in_graph=in_graph,
        fft=fft,
        stop_gradient=stop_gradient,
        cayley=cayley,
        activation_conv=activation_conv,
        activation_dense=activation_dense,
        use_bias=use_bias,
        use_stride=use_stride,
        poolType=poolType,
        batchNorm=0.0,
        niter_spectral=niter_spectral,
        niter_bjorck=niter_bjorck,
        splitLastLayer=splitLastLayer,
        activation_lastlayer=activation_lastlayer,
    )
    outputs = build_graph(inputs)

    model = LipModel(inputs=inputs, outputs=outputs)
    if verbose:
        model.summary()
    return model


