import numpy as np
import mxnet as mx
import math


@mx.init.register
class Identity(mx.init.Initializer):
    """Initializer that copies a precomputed value into a parameter.

    Despite the name, this is not an identity-matrix initializer: the
    parameter is set to exactly ``init_value``. ``init_value`` is passed to
    the base class as a keyword argument, so it is included in ``dumps()``.
    A nested list (e.g. from ``ndarray.tolist()``) is JSON-serializable.

    The class is registered with ``mx.init.register``. When a variable is
    declared with ``init=Identity(...)``, MXNet stores ``dumps()`` in the
    variable's ``__init__`` attribute. At initialization time it re-creates
    the initializer by looking up the lower-cased class name in the
    initializer registry, and then calls ``_init_weight``.

    Args:
        init_value: Array-like value (scalar or nested list). Its shape must
            be compatible with the parameter being initialized.
    """

    def __init__(self, init_value=None):
        super(Identity, self).__init__(init_value=init_value)
        self.init_value = init_value

    def _init_weight(self, _, arr):
        """Fill ``arr`` in place with ``init_value``.

        Raises:
            ValueError: if no ``init_value`` was supplied.
        """
        if self.init_value is None:
            raise ValueError("Identity initializer requires an init_value")
        # Slice assignment writes into the existing NDArray in place.
        arr[:] = np.asarray(self.init_value)


class mx_dnn_layer(object):
    """Fully connected layer with optional BatchNorm and PReLU, built as MXNet symbols.

    The constructor only records configuration. ``call`` creates the
    variables and operators and returns the output symbol. Every variable
    and operator name is suffixed with ``version``, so layers with distinct
    versions can coexist in one graph.

    Args:
        input_dim: Number of input features. Used for the shape and scaling
            of the explicitly sampled weight matrix.
        output_dim: Number of hidden units (``num_hidden`` of the FC op).
        active_op: ``'prelu'`` appends a PReLU activation. Any other value
            leaves the output linear.
        use_batch_norm: If true, apply BatchNorm after the FC op.
        version: Suffix used in every symbol name.
    """

    def __init__(self, input_dim, output_dim, active_op='prelu', use_batch_norm=False, version="default"):

        self.active_op = active_op
        self.use_batch_norm = use_batch_norm
        self.version = version
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Output of the FullyConnected op, and of the whole layer (set by call()).
        self.fc_out = None
        self.out = None

        # FC weight / bias and PReLU slope variables (set by call()).
        self.w = None
        self.b = None
        self.alpha = None

        # BatchNorm Param
        self.bn_gamma = None
        self.bn_bias = None

        # Moving mean/var (BatchNorm auxiliary states)
        self.bn_moving_mean = None
        self.bn_moving_var = None

    def call(self, bottom_data):
        """Build the layer on top of ``bottom_data`` and return the output symbol.

        Args:
            bottom_data: Input ``mx.sym.Symbol``. Presumably shaped
                (batch, input_dim) so that it matches the weight's sampled
                shape. TODO: confirm with callers.

        Returns:
            The output symbol, which is also stored in ``self.out``.
        """
        # The weight matrix is sampled here with NumPy: (randn * stddev + mean)
        # / sqrt(input_dim). It is embedded in the variable's init attribute,
        # so the values are fixed at graph-construction time.
        init_mean = 0.0
        init_stddev = 1.   # 0.001
        init_value = (init_stddev * np.random.randn(self.output_dim, self.input_dim).astype(np.float32) + init_mean) / np.sqrt(self.input_dim)
        self.w = mx.sym.var(name='fc_w_%s' % self.version, init=Identity(init_value=init_value.tolist()))
        self.b = mx.sym.var(name='fc_b_%s' % self.version, init=mx.init.Constant(0.1))
        self.fc_out = mx.symbol.FullyConnected(data=bottom_data, name=('fc_%s' % self.version), num_hidden=self.output_dim, weight=self.w, bias=self.b)
        self.out = self.fc_out

        if self.use_batch_norm:
            bn_name = 'bn_%s' % self.version
            # BatchNorm Param
            self.bn_gamma = mx.sym.var(name='bn_gamma_1_%s' % self.version, shape=(self.output_dim, ), init=mx.init.Constant(1.))
            self.bn_bias = mx.sym.var(name='bn_bias_1_%s' % self.version, shape=(self.output_dim, ), init=mx.init.Constant(0.))

            # Moving mean/var are BatchNorm auxiliary states. They are declared
            # explicitly, with the same names BatchNorm would otherwise
            # auto-generate ("<op name>_moving_mean" / "_moving_var"), so saved
            # parameter names are unchanged. No explicit init is attached, which
            # leaves initialization to the global initializer exactly as before.
            self.bn_moving_mean = mx.sym.var(name='%s_moving_mean' % bn_name, shape=(self.output_dim, ))
            self.bn_moving_var = mx.sym.var(name='%s_moving_var' % bn_name, shape=(self.output_dim, ))
            self.out = mx.symbol.BatchNorm(data=self.out, fix_gamma=False, name=bn_name,
                                           gamma=self.bn_gamma, beta=self.bn_bias,
                                           moving_mean=self.bn_moving_mean, moving_var=self.bn_moving_var)

        if self.active_op == 'prelu':
            self.alpha = mx.sym.var(name='alpha_1_%s' % self.version, shape=(self.output_dim, ), init=mx.init.Constant(0.25))
            # With act_type='prelu' the learned ``gamma`` is the negative slope.
            # ``slope`` only applies to 'leaky'/'elu', so it is not passed.
            self.out = mx.symbol.LeakyReLU(data=self.out, act_type='prelu', name=('prelu_%s' % self.version), gamma=self.alpha)

        return self.out
