from deel.lip.layers import (
    SpectralConv2D,
    Identity,
    SpectralDense,
    FrobeniusDense,
    ScaledAveragePooling2D,
    ScaledGlobalL2NormPooling2D,
    ScaledL2NormPooling2D,
    InvertibleDownSampling,
    CircularPadding,
    SymmetricPadding,
    InvertibleDownSampling,
    ScaledGlobalAveragePooling2D)


from deel.lip.extra_layers import SpectralDepthwiseConv2D, BatchCentering, LayerCentering

from deel.lip.normalizers import DEFAULT_NITER_BJORCK, DEFAULT_NITER_SPECTRAL, DEFAULT_NITER_SPECTRAL_INIT

from deel.lip.initializers import SpectralInitializer
from deel.utils.lip_model import get_lip_dense, get_lipConv2D
from deel.lip.model import Model as LipModel
import tensorflow as tf
from tensorflow.keras.layers import ReLU, AveragePooling2D, MaxPooling2D, Input, Flatten, Conv2D, MaxPool2D, Add, BatchNormalization, Dense, Lambda, Activation, Concatenate, LayerNormalization
from tensorflow.keras import backend as K
#from tensorflow.python.keras.models import Model
from tensorflow.keras.models import Model
import numpy as np
from xplique.metrics import MuFidelity

# Module-level default for the ``add_coeff`` argument of ``resnet50_block``.
# It can be overwritten at runtime with ``set_add_coeff``.
DEFAULT_ADD_COEFF = 0.5


def set_add_coeff(value: float):
    """Overwrite the module-level ``DEFAULT_ADD_COEFF`` constant.

    Only code that reads the global when it runs sees the new value;
    default argument values captured in function signatures at import time
    are not affected.

    Args:
        value: new coefficient stored in ``DEFAULT_ADD_COEFF``.
    """
    global DEFAULT_ADD_COEFF
    DEFAULT_ADD_COEFF = value


def identity_block(x, filters):
    """Apply a ResNet bottleneck block whose shortcut is the raw input.

    The residual branch is 1x1 -> 3x3 -> 1x1 convolutions (``filters``,
    ``filters`` and ``4 * filters`` channels), each followed by batch
    normalization, with ReLU after the first two. The branch is added to the
    untouched input and passed through a final ReLU, so the input must already
    have ``4 * filters`` channels and the same spatial size.

    Args:
        x: input Keras tensor.
        filters: bottleneck width.

    Returns:
        The output Keras tensor.
    """
    shortcut = x

    # (output channels, kernel size, padding, followed by ReLU?)
    layer_specs = (
        (filters, (1, 1), 'valid', True),
        (filters, (3, 3), 'same', True),
        (filters * 4, (1, 1), 'valid', False),
    )
    for n_channels, kernel, pad_mode, with_relu in layer_specs:
        x = Conv2D(n_channels, kernel_size=kernel, strides=(1, 1), padding=pad_mode)(x)
        x = BatchNormalization()(x)
        if with_relu:
            x = Activation('relu')(x)

    merged = Add()([x, shortcut])
    return Activation('relu')(merged)


def conv_block(x, filters, strides):
    """Apply a ResNet bottleneck block with a projection shortcut.

    The residual branch is 1x1 (strided) -> 3x3 -> 1x1 convolutions with
    ``filters``, ``filters`` and ``4 * filters`` channels, each followed by
    batch normalization, with ReLU after the first two. The shortcut is a
    strided 1x1 convolution to ``4 * filters`` channels plus batch
    normalization. Both branches are summed and passed through a final ReLU.

    Args:
        x: input Keras tensor.
        filters: bottleneck width.
        strides: stride of the first residual convolution and of the shortcut.

    Returns:
        The output Keras tensor.
    """
    block_input = x

    # (output channels, kernel size, strides, padding, followed by ReLU?)
    layer_specs = (
        (filters, (1, 1), strides, 'valid', True),
        (filters, (3, 3), (1, 1), 'same', True),
        (filters * 4, (1, 1), (1, 1), 'valid', False),
    )
    for n_channels, kernel, step, pad_mode, with_relu in layer_specs:
        x = Conv2D(n_channels, kernel_size=kernel, strides=step, padding=pad_mode)(x)
        x = BatchNormalization()(x)
        if with_relu:
            x = Activation('relu')(x)

    # Projection so that the shortcut matches the residual branch's shape.
    projection = Conv2D(filters * 4, kernel_size=(1, 1),
                        strides=strides, padding='valid')(block_input)
    projection = BatchNormalization()(projection)

    merged = Add()([x, projection])
    return Activation('relu')(merged)


def _final_pool_size(spatial_shape, window=7):
    """Return the head pooling window for a feature map of ``spatial_shape``.

    Each dimension uses ``window`` unless the feature map is known to be
    smaller, in which case the window shrinks to the map size so that a
    ``'valid'`` pooling still produces a 1-pixel output. Unknown (``None``)
    dimensions keep ``window``.
    """
    return tuple(window if dim is None else min(window, int(dim))
                 for dim in spatial_shape)


def ResNet50(input_shape=(224, 224, 3), classes=1000):
    """Build a standard (non-Lipschitz) Keras ResNet50 classifier.

    Stem: 7x7/2 convolution + BN + ReLU + 3x3/2 max-pooling. It is followed by
    four stages of bottleneck blocks (3, 4, 6 and 3 blocks with 64, 128, 256
    and 512 base filters). The head average-pools the final feature map,
    flattens it and applies a linear ``Dense(classes)`` layer (no activation).

    Args:
        input_shape: shape of one input sample, without the batch dimension.
        classes: number of output units.

    Returns:
        An uncompiled Keras ``Model`` named ``'ResNet50'``.
    """
    x_input = Input(input_shape)

    # stage 1 (stem)
    x = Conv2D(64, kernel_size=(7, 7), strides=(2, 2), padding='same')(x_input)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same')(x)

    # stages 2-5: (base filters, number of identity blocks, conv_block strides)
    stages = (
        (64, 2, (1, 1)),
        (128, 3, (2, 2)),
        (256, 5, (2, 2)),
        (512, 2, (2, 2)),
    )
    for stage_filters, n_identity, stage_strides in stages:
        x = conv_block(x, filters=stage_filters, strides=stage_strides)
        for _ in range(n_identity):
            x = identity_block(x, filters=stage_filters)

    # output layer
    # The 7x7 window matches the 7x7 map produced by 224x224 inputs. A fixed
    # 7x7 'valid' window cannot be built on smaller maps (e.g. 32x32 inputs
    # give a 1x1 map), so clamp it to the known spatial size.
    x = AveragePooling2D(pool_size=_final_pool_size(x.shape[1:3]))(x)
    x = Flatten()(x)
    x = Dense(classes)(x)

    return Model(inputs=x_input, outputs=x, name='ResNet50')


def get_resnet50_init_lip(padding='same',
                          activation=ReLU,
                          use_bias=True,
                          filters=64,
                          pixelwise=True,
                          channelwise=False,
                          regul_type="spectral_conv",
                          centering=False,
                          niter_spectral=DEFAULT_NITER_SPECTRAL,
                          init_l2_norm_pooling=False,
                          niter_bjorck=DEFAULT_NITER_BJORCK):
    """Return a callable that builds the stem of the Lipschitz ResNet50.

    The stem is made of the following layers, in order:
      * an optional explicit padding layer;
      * a 7x7 ``SpectralConv2D`` with stride 2;
      * an optional ``BatchCentering`` followed by the activation;
      * a 2x2 downsampling layer.

    Args:
        padding: ``'circular'`` or ``'symmetric'`` insert a
            ``CircularPadding`` / ``SymmetricPadding`` of 3 pixels and make the
            convolution ``'valid'``. Any other value is passed to the
            convolution as its ``padding``.
        activation: layer class instantiated without arguments. ``None`` skips
            both the centering and the activation.
        use_bias: forwarded to the convolution.
        filters: number of output channels of the convolution.
        pixelwise, channelwise: forwarded to ``BatchCentering``.
        regul_type, niter_spectral, niter_bjorck: forwarded to
            ``SpectralConv2D``.
        centering: insert ``BatchCentering`` before the activation.
        init_l2_norm_pooling: downsample with ``ScaledL2NormPooling2D``
            instead of ``InvertibleDownSampling``.

    Returns:
        A function mapping an input Keras tensor to the stem output tensor.
    """
    half_kernel = 7 // 2

    def apply_stem(x):
        conv_padding = padding
        if padding in ('circular', 'symmetric'):
            # Explicit padding layer, then a 'valid' convolution.
            pad_layer = CircularPadding if padding == 'circular' else SymmetricPadding
            x = pad_layer(padding=(half_kernel, half_kernel))(x)
            conv_padding = 'valid'

        x = SpectralConv2D(filters, (7, 7),
                           strides=(2, 2),
                           padding=conv_padding,
                           use_bias=use_bias,
                           kernel_initializer="orthogonal",
                           conv_first=False,
                           regul_type=regul_type,
                           niter_spectral=niter_spectral,
                           niter_bjorck=niter_bjorck)(x)

        if activation is not None:
            if centering:
                x = BatchCentering(pixelwise=pixelwise, channelwise=channelwise)(x)
            x = activation()(x)

        downsample = ScaledL2NormPooling2D if init_l2_norm_pooling else InvertibleDownSampling
        return downsample(pool_size=(2, 2))(x)

    return apply_stem


@tf.custom_gradient
def sqrt_op(x):
    """Element-wise square root with a custom, stabilised gradient.

    The forward pass is ``tf.sqrt(x)``. The backward pass returns
    ``dy / (2 * (sqrt(x) + 1e-6))``. The small epsilon keeps the gradient
    finite at ``x == 0``, where the analytic derivative ``1 / (2 sqrt(x))``
    diverges.
    """
    root = tf.sqrt(x)

    def backward(upstream):
        denominator = 2 * (root + 1e-6)
        return upstream / denominator

    return root, backward


def resnet50_block(filters,
                   use_shortcut=True,
                   activation=ReLU,
                   padding='same',
                   max_cut=False,
                   use_bias=True,
                   first=False,
                   depth=False,
                   pixelwise=False,
                   channelwise=True,
                   activate_before=True,
                   regul_type="spectral_conv",
                   centering=False,
                   mult_dim=4,
                   add_coeff=None,
                   strides=(1, 1),
                   niter_spectral=DEFAULT_NITER_SPECTRAL,
                   niter_bjorck=DEFAULT_NITER_BJORCK):
    """Return a callable that applies a Lipschitz bottleneck residual block.

    The residual branch is made of ``SpectralConv2D`` layers:
      * 1x1 with ``filters`` channels;
      * 3x3 with ``filters`` channels;
      * 1x1 with ``mult_dim * filters`` channels and ``strides``.
    The two first convolutions are followed by the optional centering and the
    activation. The last one is followed by the optional centering and, when
    ``activate_before`` is set, the activation. The branch is summed with the
    shortcut, the sum is multiplied by ``add_coeff`` and then activated.

    Args:
        filters: bottleneck width.
        use_shortcut, max_cut, depth: accepted for interface compatibility;
            not used by this function.
        activation: layer class instantiated without arguments. ``None``
            disables every activation (and the centering tied to the first
            two convolutions).
        padding: ``'circular'`` / ``'symmetric'`` insert an explicit 1-pixel
            padding layer before the 3x3 convolution, which then uses
            ``'valid'``. Other values are passed to the 3x3 convolution. The
            1x1 convolutions always use ``'valid'``.
        use_bias: forwarded to every convolution.
        first: if true, the shortcut is a ``SpectralConv2D`` with
            ``mult_dim * filters`` channels whose kernel size and strides are
            both ``strides``. Otherwise the shortcut is the raw input, which
            must then already match the residual branch's output shape (so
            ``strides`` other than ``(1, 1)`` need ``first=True``).
        pixelwise, channelwise: forwarded to ``BatchCentering``.
        activate_before: apply the activation to the residual branch before
            the sum.
        regul_type, niter_spectral, niter_bjorck: forwarded to
            ``SpectralConv2D``.
        centering: insert ``BatchCentering`` layers as described above.
        mult_dim: channel expansion factor of the block output.
        add_coeff: scale applied to ``shortcut + residual``. ``None`` (the
            default) uses the module-level ``DEFAULT_ADD_COEFF`` read when
            this function is called, so ``set_add_coeff`` takes effect.
        strides: strides of the last residual convolution (and of the
            shortcut convolution when ``first`` is true).

    Returns:
        A function mapping an input Keras tensor to the block output tensor.
    """
    # Resolve the coefficient when the block is configured rather than at
    # import time: a default of ``DEFAULT_ADD_COEFF`` in the signature would be
    # frozen at definition time and ignore ``set_add_coeff``.
    coeff = DEFAULT_ADD_COEFF if add_coeff is None else add_coeff

    def f(x):
        pad = padding
        short_cut = x

        if first:
            # Projection shortcut: non-overlapping kernel of size == strides.
            short_cut = SpectralConv2D(mult_dim * filters, strides,
                                       regul_type=regul_type,
                                       use_bias=use_bias,
                                       conv_first=False,
                                       kernel_initializer="orthogonal",
                                       padding='valid',
                                       strides=strides,
                                       niter_spectral=niter_spectral,
                                       niter_bjorck=niter_bjorck)(short_cut)
            if centering:
                short_cut = BatchCentering(
                    pixelwise=pixelwise, channelwise=channelwise)(short_cut)

        x = SpectralConv2D(filters, (1, 1),
                           regul_type=regul_type,
                           use_bias=use_bias,
                           conv_first=False,
                           kernel_initializer="orthogonal",
                           padding='valid',
                           niter_spectral=niter_spectral,
                           niter_bjorck=niter_bjorck)(x)
        if activation is not None:
            if centering:
                x = BatchCentering(pixelwise=pixelwise, channelwise=channelwise)(x)
            x = activation()(x)

        # Explicit padding for the 3x3 convolution when requested.
        if pad == 'circular':
            x = CircularPadding(padding=(3//2, 3//2))(x)
            pad = 'valid'
        if pad == 'symmetric':
            x = SymmetricPadding(padding=(3//2, 3//2))(x)
            pad = 'valid'

        x = SpectralConv2D(filters, (3, 3),
                           regul_type=regul_type,
                           use_bias=use_bias,
                           conv_first=False,
                           kernel_initializer="orthogonal",
                           padding=pad,
                           niter_spectral=niter_spectral,
                           niter_bjorck=niter_bjorck)(x)
        if activation is not None:
            if centering:
                x = BatchCentering(pixelwise=pixelwise, channelwise=channelwise)(x)
            x = activation()(x)

        x = SpectralConv2D(mult_dim * filters, (1, 1),
                           regul_type=regul_type,
                           use_bias=use_bias,
                           conv_first=False,
                           strides=strides,
                           kernel_initializer="orthogonal",
                           padding='valid',
                           niter_spectral=niter_spectral,
                           niter_bjorck=niter_bjorck)(x)
        if centering:
            x = BatchCentering(pixelwise=pixelwise, channelwise=channelwise)(x)
        if activate_before and activation is not None:
            x = activation()(x)

        # Scale the sum (0.5 by default) instead of a hard-coded constant.
        scalar_ = Lambda(lambda y: y * coeff)
        x = scalar_(tf.keras.layers.Add()([short_cut, x]))
        if activation is not None:
            x = activation()(x)
        return x
    return f


def ResNet50_lip(shape, coeff_filters=1.,
                 nb_classes=1,
                 padding='same',
                 activation_conv=None,
                 activation_dense=None,
                 centering=False,
                 use_bias=True,
                 pixelwise=False,
                 channelwise=True,
                 conv_1x1=False,
                 out_activation=False,
                 max_cut=False,
                 depth=False,
                 bias_output=True,
                 activate_before=True,
                 regul_type="spectral_conv",
                 init_l2_norm_pooling=False,
                 niter_spectral=DEFAULT_NITER_SPECTRAL,
                 niter_bjorck=DEFAULT_NITER_BJORCK,
                 verbose=False):
    """Build a ResNet50-like model out of deel-lip layers.

    Layout:
      * the stem from ``get_resnet50_init_lip``;
      * four stages of ``resnet50_block`` (3, 4, 6 and 3 blocks, with widths
        1x, 2x, 4x and 8x the base width ``int(64 * coeff_filters)``);
      * a ``ScaledL2NormPooling2D`` 2x2 between consecutive stages;
      * ``ScaledGlobalL2NormPooling2D`` and a ``FrobeniusDense`` output layer
        with ``nb_classes`` units.

    Args:
        shape: input shape without the batch dimension.
        coeff_filters: multiplier applied to the base width of 64.
        nb_classes: number of output units.
        padding, activation_conv, centering, use_bias, pixelwise, channelwise,
        max_cut, depth, activate_before, regul_type, niter_spectral,
        niter_bjorck: forwarded to every ``resnet50_block`` (and, where
            accepted, to the stem).
        activation_dense: activation class applied before the output layer
            when ``out_activation`` is set.
        conv_1x1: force a projection shortcut (``first=True``) in every block.
        out_activation: apply optional centering + ``activation_dense``
            before the output layer.
        bias_output: ``use_bias`` of the output ``FrobeniusDense``.
        init_l2_norm_pooling: stem downsampling choice. When set, the very
            first block also gets a projection shortcut.
        verbose: print ``model.summary()``.

    Returns:
        A ``deel.lip.model.Model`` instance.
    """
    inputs = Input(shape)
    filters = int(64*coeff_filters)

    block_kwargs = dict(
        strides=(1, 1),
        use_shortcut=True,
        activation=activation_conv,
        use_bias=use_bias,
        max_cut=max_cut,
        centering=centering,
        padding=padding,
        depth=depth,
        regul_type=regul_type,
        niter_spectral=niter_spectral,
        niter_bjorck=niter_bjorck,
        activate_before=activate_before,
        pixelwise=pixelwise,
        channelwise=channelwise,
    )
    stem = get_resnet50_init_lip(filters=filters, padding=padding,
                                 activation=activation_conv, centering=centering,
                                 use_bias=use_bias, pixelwise=pixelwise,
                                 channelwise=channelwise, regul_type=regul_type,
                                 niter_spectral=niter_spectral,
                                 niter_bjorck=niter_bjorck,
                                 init_l2_norm_pooling=init_l2_norm_pooling)
    x = stem(inputs)

    # (width multiplier, number of blocks) for each stage.
    stages = ((1, 3), (2, 4), (4, 6), (8, 3))
    last_stage = len(stages) - 1
    for stage_idx, (width, n_blocks) in enumerate(stages):
        for block_idx in range(n_blocks):
            if block_idx > 0:
                project = conv_1x1
            elif stage_idx == 0:
                # The stem only leaves 4x base channels with invertible
                # downsampling; L2-norm pooling keeps 1x and needs a projection.
                project = init_l2_norm_pooling or conv_1x1
            else:
                # Width changes between stages: always project.
                project = True
            x = resnet50_block(width * filters, first=project,
                               mult_dim=4, **block_kwargs)(x)
        if stage_idx < last_stage:
            x = ScaledL2NormPooling2D(pool_size=(2, 2))(x)

    x = ScaledGlobalL2NormPooling2D()(x)
    if out_activation:
        if centering:
            x = BatchCentering()(x)
        if activation_dense is not None:
            x = activation_dense()(x)
    x = FrobeniusDense(nb_classes, activation=None, disjoint_neurons=True,
                       use_bias=bias_output, kernel_initializer="orthogonal")(x)

    model = LipModel(inputs=inputs, outputs=x)
    if verbose:
        model.summary()
    return model
