import tensorflow as tf
from layers import AggLayer

class OPEN_Model(tf.keras.Model):
    """Two Dense projections followed by a stack of aggregation layers.

    Forward pass (see ``call``):
    dropout -> hidden Dense -> activation -> dropout -> output Dense
    (``num_classes`` units) -> every layer in ``self.agg_layers`` in order.

    Args:
        layer_units: Sequence of hidden sizes. Only ``layer_units[0]`` is
            used; any further entries are ignored.
        num_classes: Width of the output Dense layer; also forwarded to
            ``AggLayer``.
        aggregators: Forwarded unchanged to ``AggLayer``.
        dropout_func: Layer class, instantiated once as
            ``dropout_func(in_drop)``. That single instance is reused before
            both Dense layers and is also handed to ``AggLayer``.
        in_drop: Rate passed to ``dropout_func``.
        L2: L2 regularization factor applied to both Dense kernels.
        activation: Callable applied after the hidden Dense layer.
        num_feature: Stored as ``self.num_feature``; not otherwise used in
            this class.
    """

    def __init__(self,
                 layer_units,
                 num_classes,
                 aggregators=None,
                 dropout_func=tf.keras.layers.GaussianDropout,
                 in_drop=0.1,
                 L2=0.02,
                 activation=tf.nn.relu,
                 num_feature=None):
        super().__init__()

        self.num_classes = num_classes
        self.num_feature = num_feature

        # Attribute names and creation order are kept stable: object-based
        # TF checkpoints key tracked sub-layers by these attribute names.
        self.agg_layers = []
        self.dense_layers = []
        self.dropout_layer = dropout_func(in_drop)
        self.activation = activation

        # Hidden projection (layer_units[0]) then output projection
        # (num_classes); both bias-free with an L2 kernel penalty.
        for units in (layer_units[0], num_classes):
            self.dense_layers.append(
                tf.keras.layers.Dense(
                    units,
                    use_bias=False,
                    kernel_regularizer=tf.keras.regularizers.l2(L2)))

        agg_layer = AggLayer(aggregators=aggregators,
                             dropout_layer=self.dropout_layer,
                             num_classes=num_classes)
        # The same AggLayer instance is registered twice, so call() applies
        # it two times with shared weights.
        # NOTE(review): confirm two aggregation passes are intended and this
        # is not an accidentally duplicated line.
        self.agg_layers.append(agg_layer)
        self.agg_layers.append(agg_layer)

    def call(self, x, training=True):
        """Run the forward pass.

        Args:
            x: Input tensor fed to the first dropout layer.
            training: Passed explicitly to every dropout and aggregation
                call, so the training/inference mode is honored even when
                ``call`` is invoked directly rather than via ``__call__``.
                Defaults to True, which keeps dropout active when the mode
                is not specified.

        Returns:
            The output of the last layer in ``self.agg_layers``.
        """
        x = self.dropout_layer(x, training=training)
        x = self.dense_layers[0](x)
        x = self.activation(x)
        x = self.dropout_layer(x, training=training)
        x = self.dense_layers[-1](x)
        for agg in self.agg_layers:
            x = agg(x, training=training)
        return x

