"""
----
Description:

    Implementation of a 2-d Bessel Convolutional Layer using Tensorflow/Keras.
    This layer is rotation-invariant thanks to the use of the
    Bessel functions of the first kind.
----

----
Requirements:

    * numpy
    * scipy
    * tensorflow
----

           Author: Anonymous
         Creation: 23-02-2021
Last modification: 26-02-2021
"""


import numpy as np
from scipy import special
import tensorflow as tf
from tensorflow import keras
from getTransMat import getTransMat


class BesselConv2d(keras.layers.Layer):
    """
    2-d Bessel convolutional layer.

    The trainable weights are complex coefficients (stored as separate real and
    imaginary float32 tensors) indexed by (m, j). They are turned into
    effective k x k convolution filters through a fixed transformation matrix
    returned by getTransMat(). The output of the layer is the squared modulus of
    the complex convolution, summed over m, plus a bias, followed by an
    optional activation.
    """

    def __init__(self, m_max, j_max, k, n_out, strides=1, padding='VALID', activation='relu', name=None, **kwargs):
        """
        This function initializes the layer.
        It is called only once, before any training.

        Invalid arguments are not rejected: a message is printed and the
        argument is replaced by a fallback value.

        * m_max is the maximum value of m considered for the Bessel decomposition of the image
          (integer > 0, fallback 10)
        * j_max is the maximum value of j considered for the Bessel decomposition of the image
          (integer > 0, fallback 10)
        * k is the size of the sliding window used for the convolution
          (odd number, an even value is incremented by 1)
        * n_out is the number of filters (integer > 0, fallback 16)
        * strides is forwarded as is to tf.nn.conv2d
        * padding is forwarded to tf.nn.conv2d
          Available paddings are ['VALID', 'SAME'] (fallback 'VALID')
        * activation is the activation function used on the output of the layer
          Available activations are ['relu', 'sigmoid', 'tanh', None] (fallback 'relu')
        """

        super(BesselConv2d, self).__init__(name=name, **kwargs)

        if not isinstance(m_max, int) or m_max < 1:
            print("'m_max' should be an integer > 0")
            m_max = 10
            print("'m_max' automatically set to 10")

        if not isinstance(j_max, int) or j_max < 1:
            print("'j_max' should be an integer > 0")
            j_max = 10
            print("'j_max' automatically set to 10")

        if k % 2 == 0:
            print("Kernel size 'k' should be an odd number.")
            k += 1
            print("'k' automatically set to", k)

        # The range check must be done on n_out itself (it used to test m_max).
        if not isinstance(n_out, int) or n_out < 1:
            print("'n_out' should be an integer > 0")
            n_out = 16
            print("'n_out' automatically set to 16")

        if padding not in ['VALID', 'SAME']:
            print("'padding' should be 'VALID' or 'SAME'")
            padding = 'VALID'
            print("'padding' automatically set to 'VALID'")

        if activation not in ['relu', 'sigmoid', 'tanh', None]:
            print("'activation' should be 'relu', 'sigmoid', 'tanh' or None")
            activation = 'relu'
            print("activation automatically set to 'relu'")

        self.m_max = m_max
        self.j_max = j_max
        self.k = k
        self.n_out = n_out
        self.strides = strides
        self.padding = padding
        self.activation = activation

        # Build the transformation matrix used to compute the effective filters.
        # getTransMat() must return something convertible to a complex64 tensor
        # of shape (k, k, m_max+1, j_max+1): this shape is enforced below.
        # Real and imaginary parts are split so that every computation of the
        # layer is done on float32 tensors (original author's note: complex
        # numbers are not handled on the CUDA side).
        transMat = tf.convert_to_tensor(getTransMat(m_max, j_max, k), dtype=tf.complex64)

        self.transMat_r = tf.Variable(
            initial_value=tf.math.real(transMat),
            shape=(k, k, m_max+1, j_max+1),
            dtype=tf.float32,
            trainable=False,
            name='transMat_r'
        )
        self.transMat_i = tf.Variable(
            initial_value=tf.math.imag(transMat),
            shape=(k, k, m_max+1, j_max+1),
            dtype=tf.float32,
            trainable=False,
            name='transMat_i'
        )

    def build(self, input_shape):
        """
        This function builds the model to match a particular input shape.
        It is called only once, when the __call__() method is called for the
        first time. One could also consider calling build() directly.

        input_shape is expected to be in the form (batch_axis, x, y, n_channels)
        """

        # Get the number of input channels.
        self.n_in = input_shape[3]

        # Initialize trainable weights.
        # Real and imaginary parts of the complex weights are stored separately,
        # both with shape (m_max+1, j_max+1, n_in, n_out).
        self.w_r = self.add_weight(
            shape=(self.m_max+1, self.j_max+1, self.n_in, self.n_out),
            initializer="random_normal",
            dtype=tf.float32,
            trainable=True,
            name='w_r'
        )
        self.w_i = self.add_weight(
            shape=(self.m_max+1, self.j_max+1, self.n_in, self.n_out),
            initializer="random_normal",
            dtype=tf.float32,
            trainable=True,
            name='w_i'
        )

        # Initialize the biases.
        # There are as many biases as the number of filters of the layer (n_out).
        self.b = self.add_weight(
            shape=(self.n_out,),
            initializer="random_normal",
            dtype=tf.float32,
            trainable=True,
            name='biases'
        )

    @staticmethod
    def _contract_over_j(weights, trans_mat):
        """
        Multiply weights (m_max+1, j_max+1, n_in, n_out) with trans_mat
        (k, k, m_max+1, j_max+1) and sum over j.

        The result is of shape (k, k, m_max+1, n_in, n_out).
        """

        return tf.math.reduce_sum(
            tf.math.multiply(
                weights[tf.newaxis, tf.newaxis, :, :, :, :],
                trans_mat[:, :, :, :, tf.newaxis, tf.newaxis]
            ),
            axis=3
        )

    def _pack_filters(self, filters):
        """
        Turn filters of shape (k, k, m_max+1, n_in, n_out) into a 4-d kernel of
        shape (k, k, n_in, (m_max+1)*n_out) usable by tf.nn.conv2d.

        m and n_out are wrapped together on the last axis (m varies slowest).
        They are unwrapped after the convolution in call().
        """

        filters = tf.transpose(filters, perm=[0, 1, 3, 2, 4])
        return tf.reshape(
            filters,
            shape=(self.k, self.k, self.n_in, self.n_out*(self.m_max+1))
        )

    def call(self, inputs):
        """
        This function computes the activation of the layer given (a) particular input(s).
        inputs is of shape (n_inputs, x, y, n_channels).

        Real and imaginary parts are treated separately, using float32 tensors only.
        The returned tensor has n_out channels.
        """

        m_max = self.m_max
        n_out = self.n_out

        # ----
        # Compute the effective filters.
        # Real part:      w_r * transMat_r + w_i * transMat_i
        # Imaginary part: w_r * transMat_i - w_i * transMat_r
        # Both are of shape (k, k, n_in, (m_max+1)*n_out) once packed.
        # ----

        w_r = self._pack_filters(
            tf.math.add(
                self._contract_over_j(self.w_r, self.transMat_r),
                self._contract_over_j(self.w_i, self.transMat_i)
            )
        )
        w_i = self._pack_filters(
            tf.math.subtract(
                self._contract_over_j(self.w_r, self.transMat_i),
                self._contract_over_j(self.w_i, self.transMat_r)
            )
        )

        # ----
        # Computation of the activation.
        # ----

        # Squared modulus of the complex convolution.
        a_r = tf.math.square(
            tf.nn.conv2d(inputs, w_r, padding=self.padding, strides=self.strides)
        )
        a_i = tf.math.square(
            tf.nn.conv2d(inputs, w_i, padding=self.padding, strides=self.strides)
        )
        modulus = tf.math.add(a_r, a_i)

        # Static spatial dimensions are used when they are known. They are None
        # when the input size is not fixed, in which case the dynamic shape is
        # used instead (tf.reshape does not accept None).
        height = modulus.shape[1]
        width = modulus.shape[2]
        if height is None or width is None:
            dynamic_shape = tf.shape(modulus)
            if height is None:
                height = dynamic_shape[1]
            if width is None:
                width = dynamic_shape[2]

        # Unwrap m and n_out, sum over m, then add the biases.
        a = tf.math.add(
            tf.math.reduce_sum(
                tf.reshape(modulus, shape=(-1, height, width, m_max+1, n_out)),
                axis=3
            ),
            self.b[tf.newaxis, tf.newaxis, tf.newaxis, :]
        )

        if self.activation == 'relu':
            return tf.keras.activations.relu(a)
        elif self.activation == 'sigmoid':
            return tf.keras.activations.sigmoid(a)
        elif self.activation == 'tanh':
            return tf.keras.activations.tanh(a)
        else:
            return a

    def get_config(self):
        """
        This function generates a config in order to save the model at its current state.

        The config contains the base layer config plus every argument of __init__().

        A model using BesselConv2d layer(s) can then be saved using
            >> model.save('./model.h5')
        and loaded with
            >> model = tf.keras.models.load_model('./model.h5', custom_objects={'BesselConv2d': BesselConv2d})
        """

        config = super(BesselConv2d, self).get_config()
        config.update({"m_max": self.m_max,
                       "j_max": self.j_max,
                       "k": self.k,
                       "n_out": self.n_out,
                       "strides": self.strides,
                       "padding": self.padding,
                       "activation": self.activation})

        return config
