from convnet_builder import *
import tensorflow as tf


class RTSConvNetBuilder(ConvNetBuilder):
    """ConvNetBuilder variant that stores ``rts_params`` and can record conv internals.

    The constructor forwards everything except ``rts_params`` to
    ``ConvNetBuilder`` and keeps ``rts_params`` on the instance. This class
    defines its own ``conv``. When ``self.need_record_internal_outputs`` is
    truthy, ``conv`` stores both the raw convolution output and the
    biased / batch-normalized (pre-activation) output of each conv layer in
    ``self.internal_outputs_dict``.
    """

    def __init__(self,
                 input_op,
                 input_nchan,
                 phase_train,
                 use_tf_layers,
                 data_format,
                 rts_params,
                 dtype=tf.float32,
                 variable_dtype=tf.float32,
                 use_dense_layer=False):
        """Initialize the base builder and keep ``rts_params``.

        ``rts_params`` is stored unchanged on ``self.rts_params``. Every
        other argument is passed through, by keyword, to
        ``ConvNetBuilder.__init__``.
        """
        super(RTSConvNetBuilder, self).__init__(
            input_op=input_op,
            input_nchan=input_nchan,
            phase_train=phase_train,
            use_tf_layers=use_tf_layers,
            data_format=data_format,
            dtype=dtype,
            variable_dtype=variable_dtype,
            use_dense_layer=use_dense_layer)
        self.rts_params = rts_params

    def conv(self,
             num_out_channels,
             k_height,
             k_width,
             d_height=1,
             d_width=1,
             mode='SAME',
             input_layer=None,
             num_channels_in=None,
             use_batch_norm=None,
             stddev=None,
             activation='relu',
             bias=0.0,
             kernel_initializer=None,
             specify_padding=None,
             name=None,
             name_postfix=None,
             count_convs=True,
             just_as_classic_conv=True):    # accepted for interface compatibility; not used here
        """Construct a conv2d layer on top of cnn.

        Args:
          num_out_channels: number of output channels of the convolution.
          k_height, k_width: kernel size.
          d_height, d_width: strides along height and width.
          mode: padding passed to ``_conv2d_impl``, or ``'SAME_RESNET'``.
            With ``'SAME_RESNET'``, an unstrided conv uses ``'SAME'``. A
            strided conv gets explicit padding of ``(k - 1) // 2`` before and
            ``k - 1 - (k - 1) // 2`` after each spatial dim, then ``'VALID'``.
          input_layer: input tensor; defaults to ``self.top_layer``.
          num_channels_in: input channel count; defaults to ``self.top_size``.
          use_batch_norm: defaults to ``self.use_batch_norm``. When true,
            ``self.batch_norm(**self.batch_norm_config)`` is applied instead
            of a bias.
          stddev: if given and ``kernel_initializer`` is None, a truncated
            normal initializer with this stddev is used.
          activation: one of ``'relu'``, ``'linear'`` / None, ``'tanh'``,
            ``'sigmoid'``.
          bias: initial bias value (batch-norm disabled only); None means no bias.
          kernel_initializer: initializer forwarded to ``_conv2d_impl``.
          specify_padding: forwarded to ``_conv2d_impl`` only when ``mode`` is
            not ``'SAME_RESNET'``.
          name: variable-scope name; defaults to ``'conv<counter>'``.
          name_postfix: appended to ``name`` if given.
          count_convs: whether to increment ``self.counts['conv']``.
          just_as_classic_conv: unused by this implementation.

        Returns:
          The activated output tensor. As a side effect it becomes
          ``self.top_layer``, and ``self.top_size`` is set to
          ``num_out_channels``.

        Raises:
          KeyError: if ``activation`` is not a supported name.
        """
        if input_layer is None:
            input_layer = self.top_layer
        if num_channels_in is None:
            num_channels_in = self.top_size
        if stddev is not None and kernel_initializer is None:
            kernel_initializer = tf.truncated_normal_initializer(stddev=stddev)
        if name is None:
            name = 'conv' + str(self.counts['conv'])
        if name_postfix is not None:
            name += name_postfix
        if count_convs:
            self.counts['conv'] += 1
        with tf.variable_scope(name):
            # _conv2d_impl takes 2-D [d_height, d_width] strides, so no
            # data-format-ordered 4-D stride list is built here.
            if mode != 'SAME_RESNET':
                conv = self._conv2d_impl(input_layer, num_channels_in, num_out_channels,
                    kernel_size=[k_height, k_width],
                    strides=[d_height, d_width], padding=mode,
                    kernel_initializer=kernel_initializer, specify_padding=specify_padding)
            elif d_height == 1 and d_width == 1:
                # Special padding mode for ResNet models, unstrided case:
                # ordinary 'SAME' padding.
                conv = self._conv2d_impl(input_layer, num_channels_in,
                    num_out_channels,
                    kernel_size=[k_height, k_width],
                    strides=[d_height, d_width], padding='SAME',
                    kernel_initializer=kernel_initializer)
            else:
                # Special padding mode for ResNet models, strided case: pad
                # by an amount that depends only on the kernel size, then run
                # a 'VALID' convolution.
                rate = 1  # Unused (for 'a trous' convolutions)
                kernel_height_effective = k_height + (k_height - 1) * (rate - 1)
                pad_h_beg = (kernel_height_effective - 1) // 2
                pad_h_end = kernel_height_effective - 1 - pad_h_beg
                kernel_width_effective = k_width + (k_width - 1) * (rate - 1)
                pad_w_beg = (kernel_width_effective - 1) // 2
                pad_w_end = kernel_width_effective - 1 - pad_w_beg
                padding = [[0, 0], [pad_h_beg, pad_h_end],
                           [pad_w_beg, pad_w_end], [0, 0]]
                if self.data_format == 'NCHW':
                    # Move the spatial paddings to dims 2 and 3.
                    padding = [padding[0], padding[3], padding[1], padding[2]]
                input_layer = tf.pad(input_layer, padding)
                conv = self._conv2d_impl(input_layer, num_channels_in,
                    num_out_channels,
                    kernel_size=[k_height, k_width],
                    strides=[d_height, d_width], padding='VALID',
                    kernel_initializer=kernel_initializer)

            if use_batch_norm is None:
                use_batch_norm = self.use_batch_norm
            if not use_batch_norm:
                if bias is not None:
                    biases = self.get_variable('biases', [num_out_channels],
                        self.variable_dtype, self.dtype,
                        initializer=tf.constant_initializer(bias))
                    biased = tf.reshape(
                        tf.nn.bias_add(conv, biases, data_format=self.data_format),
                        conv.get_shape())
                else:
                    biased = conv
            else:
                # batch_norm() is called without an input tensor; presumably
                # it normalizes self.top_layer, hence the assignment first.
                self.top_layer = conv
                self.top_size = num_out_channels
                biased = self.batch_norm(**self.batch_norm_config)

            # Optionally record the raw conv output ('<idx>$<name>') and the
            # biased / batch-normalized, pre-activation output ('<idx>#<name>').
            if self.need_record_internal_outputs:
                self.internal_outputs_dict['{}${}'.format(self.num_internal_conv_outputs, name)] = conv
                self.internal_outputs_dict['{}#{}'.format(self.num_internal_conv_outputs, name)] = biased
                self.num_internal_conv_outputs += 1

            if activation == 'relu':
                conv1 = tf.nn.relu(biased)
            elif activation == 'linear' or activation is None:
                conv1 = biased
            elif activation == 'tanh':
                conv1 = tf.nn.tanh(biased)
            elif activation == 'sigmoid':
                conv1 = tf.nn.sigmoid(biased)
            else:
                raise KeyError('Invalid activation type \'%s\'' % activation)
            self.top_layer = conv1
            self.top_size = num_out_channels
            return conv1

    @staticmethod
    def _relative_squared_error(reference, estimate):
        """Return ``mean((reference - estimate) ** 2) / mean(reference ** 2)`` as a tensor."""
        return tf.reduce_mean((reference - estimate) ** 2) / tf.reduce_mean(reference ** 2)

    def get_bds_metrics(self):
        """Build relative squared-error tensors between normal and shadow outputs.

        For every key ``i`` of ``self.shadow_output``, the entry stored under
        ``i - 1`` is ``mean((normal - shadow) ** 2) / mean(normal ** 2)`` for
        ``self.normal_output[i]`` and ``self.shadow_output[i]``. If
        ``self.cur_layer_idx - 1`` is in ``self.bds_params.target_layers``, the
        same metric for the first FC layer's outputs is stored under that index.
        Progress information is printed to stdout.

        NOTE(review): this reads ``self.bds_params``, ``self.normal_output``,
        ``self.shadow_output``, ``self.cur_layer_idx`` and
        ``self.first_fc_*_output``. None of these are set in this class, which
        stores ``self.rts_params`` instead. Confirm that they are provided
        elsewhere (base class or caller), or that ``rts_params`` was intended.

        Returns:
          dict mapping layer index to a scalar metric tensor.
        """
        print('the current metric is ', self.bds_params.metric_type)
        print('len of normal_output is {}, len of shadow_output is {}'.format(len(self.normal_output), len(self.shadow_output)))
        result = {}
        for idx in self.shadow_output:
            # TODO may not work for resnets
            result[idx - 1] = self._relative_squared_error(
                self.normal_output[idx], self.shadow_output[idx])
        last_idx = self.cur_layer_idx - 1
        if last_idx in self.bds_params.target_layers:
            # The squared difference is symmetric, so argument order does not
            # change the value.
            result[last_idx] = self._relative_squared_error(
                self.first_fc_normal_output, self.first_fc_shadow_output)
        print('the keys of bds metrics are {}'.format(result.keys()))
        return result
