import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
import scipy 
#import tensorflow_probability as tfp
import sys
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.metrics import binary_crossentropy,categorical_crossentropy,binary_accuracy,BinaryCrossentropy
from tensorflow.keras.utils import register_keras_serializable
from tensorflow.keras.losses import Loss
from tensorflow.keras.losses import Reduction
from tensorflow.python import debug as tf_debug
def wasserstein_acc():
    """Build an accuracy metric based on the sign of the first output column.

    Returns:
        A callable ``(y_true, y_pred)`` that marks column 0 of each tensor as
        positive when strictly greater than zero and returns the
        ``binary_accuracy`` between the two resulting 0/1 vectors.
    """
    def wasserstein_acc_fct(y_true, y_pred):
        def _strictly_positive(values):
            # 1.0 where the first column is > 0, else 0.0.
            return tf.cast(tf.greater(values[:, 0], 0), dtype=tf.float32)

        return binary_accuracy(_strictly_positive(y_true), _strictly_positive(y_pred))

    return wasserstein_acc_fct



@tf.function
def soft_hinge(x,m):
    """Smoothed hinge of ``x`` with smoothing width ``m``.

    Piecewise definition, element-wise:
      * ``x < 0``       -> 0
      * ``0 <= x <= m`` -> ``x**2 / (2*m)``
      * ``x > m``       -> ``x - m/2``
    """
    rectified = tf.nn.relu(x)
    in_quadratic_zone = tf.math.logical_and(x>=0,x<=m)
    quadratic_part = (1./(2*m))*rectified**2
    linear_part = rectified-m/2
    # Elements above m take the shifted linear part; the rest fall back to
    # the quadratic zone or to the plain rectified value.
    below_or_inside = tf.where(in_quadratic_zone,quadratic_part,rectified)
    return tf.where(x>m,linear_part,below_or_inside)

#@register_keras_serializable("deel-lip", "HKR_binary")
class HKR_binary():
    """Binary hinge + Kantorovich-Rubinstein (HKR) loss with a trainable margin.

    The callable combines a hinge term built on the variable ``margins`` with
    a KR term (mean prediction where ``y_true == 1`` minus mean prediction
    elsewhere).

    Args:
        margin_coeff: weight of the ``-margin_coeff * margins`` term; stored as
            a backend variable so it can be changed with ``set_margin_coeff``.
        alpha: weight of the hinge part. ``'auto'`` selects ``1 / margin_coeff``.
        beta: weight of the KR part.
        soft_hinge: when true, the module-level ``soft_hinge`` function is used
            instead of a plain ReLU hinge.
        init_margin: initial value of ``margins``; defaults to ``min_margin``.
        min_margin: lower bound of the clipping constraint attached to
            ``margins`` (the upper bound is 200).
        verbose: when true, the margin is printed every 500 calls.
    """

    def __init__(self,  margin_coeff = 0.1, alpha = 'auto', beta =1., soft_hinge = False,init_margin = None, min_margin = 0.01,verbose = False):
        self.margin_coeff = K.variable(margin_coeff)
        self.beta = beta
        self.soft_hinge = soft_hinge
        if alpha == 'auto' :
            alpha = 1/margin_coeff
        self.alpha = K.variable(alpha, name='alpha', dtype=tf.float32)
        self.min_margin = min_margin
        if init_margin is None :
            self.init_margin = min_margin
        else :
            self.init_margin = init_margin
        self.margins = tf.Variable(self.init_margin,dtype=tf.float32,constraint=lambda z: tf.clip_by_value(z, min_margin, 200))
        self.verbose = verbose
        self.eps = 1e-7
        self.__name__ = "HKR_binary"
        # Call counter, only used to throttle the verbose print.
        self.nb = K.variable(0, name='nb')


    def set_margin_coeff(self, margin_coeff, alpha = 'auto'):
        """Update ``margin_coeff`` and the hinge weight ``alpha``.

        Args:
            margin_coeff: new value assigned to the ``margin_coeff`` variable.
            alpha: new hinge weight; ``'auto'`` assigns ``1 / margin_coeff``.
        """
        self.margin_coeff.assign(margin_coeff)
        if alpha == 'auto' :
            self.alpha.assign(1/margin_coeff)
        else :
            # An explicit alpha used to be silently dropped.
            self.alpha.assign(alpha)

    @tf.function
    def __call__(self, y_true, y_pred):
        """Compute the loss.

        Args:
            y_true: labels; entries equal to 1 form one class, every other
                value the second class.
            y_pred: model outputs, same shape as ``y_true``.

        Returns:
            ``alpha * (hinge - margin_coeff * margins) - beta * KR``. The hinge
            is element-wise and is not reduced here.
        """
        S0 = tf.equal(y_true, 1)
        S1 = tf.not_equal(y_true, 1)
        # compute the KR dual representation
        KR_loss = K.mean(tf.boolean_mask(y_pred, S0)) - K.mean(tf.boolean_mask(y_pred, S1))
        sign = K.sign(y_true-self.eps)  # subtracting a small eps makes the loss work for both (1,0) and (1,-1) labels
        var_x = self.margins - sign * y_pred

        if self.soft_hinge:
            hinge = soft_hinge(var_x,0.9*self.margins)
        else :
            hinge = tf.nn.relu(var_x)
        self.nb.assign(self.nb+1)
        if self.verbose and self.nb %500 == 0:
            # The attribute is `margins`; `self.margin` does not exist.
            tf.print("margin :",self.margins)
        return self.alpha * (hinge-self.margin_coeff*self.margins) - self.beta* KR_loss



    def get_config(self):
        """Return a dict with the ``alpha`` variable and ``min_margin``."""
        config = {"alpha": self.alpha, "min_margin": self.min_margin}
        return config
    
    
#@register_keras_serializable("deel-lip", "HKR_binary_auto")
class HKR_binary_auto():
    """Binary HKR loss whose margin tracks a percentile of the positive outputs.

    At every call the margin is moved (exponential moving average with rate
    ``learning_rate``) towards the ``perc``-th percentile of the predictions
    whose label equals 1, with a floor of 0.1.

    Args:
        perc: percentile (0-100) of the positive predictions used as target
            for the margin.
        alpha: weight of the hinge term; ``'auto'`` selects ``100 / perc``.
        min_margin: initial value of the margin variable.
        verbose: when true, diagnostics are printed every 500 calls.
    """

    def __init__(self,  perc = 5, alpha = 'auto', min_margin = 0.01,verbose = False):
        if alpha == 'auto' :
            alpha = 100/perc
        self.alpha = K.variable(alpha, name='alpha', dtype=tf.float32)
        self.min_margin = K.variable(min_margin, name='min_margin', dtype=tf.float32)
        self.learning_rate = 0.001
        # float64 because integer tensor division in __call__ yields float64.
        self.current_perc =  K.variable(perc/100, name='current_perc', dtype=tf.float64)
        self.eps = 1e-7
        self.perc = perc
        self.verbose = verbose
        self.__name__ = "HKR_auto"
        # Call counter, only used to throttle the verbose print.
        self.nb = K.variable(0, name='nb')

    @staticmethod
    def _percentile(values, q):
        """Nearest-rank ``q``-th percentile of ``values`` using TensorFlow ops only.

        Stands in for ``tfp.stats.percentile`` (whose import is commented out
        in this file). Intended to match its default 'nearest' interpolation
        -- TODO confirm against tensorflow_probability if exact parity matters.
        ``values`` must contain at least one element.
        """
        ordered = tf.sort(tf.reshape(values, [-1]))
        last_index = tf.cast(tf.size(ordered) - 1, tf.float64)
        rank = tf.cast(tf.round(last_index * (q / 100.0)), tf.int32)
        return tf.gather(ordered, rank)

    @tf.function
    def __call__(self, y_true, y_pred):
        """Compute ``alpha * mean(hinge) - KR`` and update the running margin.

        Args:
            y_true: labels; entries equal to 1 form the positive class.
            y_pred: model outputs, same shape as ``y_true``.

        Returns:
            Scalar loss value.
        """
        S0 = tf.equal(y_true, 1)
        S1 = tf.not_equal(y_true, 1)

        # compute the KR dual representation
        KR_loss = K.mean(tf.boolean_mask(y_pred, S0)) - K.mean(tf.boolean_mask(y_pred, S1))
        sign = K.sign(y_true-self.eps)  # subtracting a small eps makes the loss work for both (1,0) and (1,-1) labels

        dist_pos = tf.boolean_mask(y_pred, S0)
        # Computed once and reused by the verbose print below.
        perc_value = self._percentile(dist_pos, self.perc)

        # Moving average of the margin towards the percentile, floored at 0.1.
        min_margin = tf.math.maximum(0.1,(1-self.learning_rate)*self.min_margin + self.learning_rate * perc_value)
        self.min_margin.assign(min_margin)
        # Moving average of the fraction of positives falling below the margin.
        in_margin = tf.boolean_mask(dist_pos,tf.less(dist_pos, self.min_margin))
        current_perc = (1-self.learning_rate)*self.current_perc +self.learning_rate*(tf.size(in_margin)/tf.size(dist_pos))
        self.current_perc.assign(current_perc)
        hinge = K.maximum(0.0, self.min_margin - sign * y_pred)
        hinge = K.mean(hinge)
        self.nb.assign(self.nb+1)

        if self.verbose and self.nb %500 == 0:
            tf.print(min_margin,perc_value,self.current_perc,self.alpha)
        return self.alpha * hinge - KR_loss



    def get_config(self):
        """Return a dict with the ``alpha`` and ``min_margin`` variables."""
        config = {"alpha": self.alpha, "min_margin": self.min_margin}
        return config


def margin():
    """Build an accuracy metric based on the sign of the first output column.

    Returns:
        A ``tf.function`` ``(y_true, y_pred)`` that marks column 0 of each
        tensor as positive when greater than or equal to zero and returns the
        ``binary_accuracy`` between the two resulting 0/1 vectors.
    """
    @tf.function
    def wasserstein_acc_fct(y_true, y_pred):
        def _non_negative(values):
            # 1.0 where the first column is >= 0, else 0.0.
            return tf.cast(tf.greater_equal(values[:, 0], 0), dtype=tf.float32)

        return binary_accuracy(_non_negative(y_true), _non_negative(y_pred))

    return wasserstein_acc_fct
    
#@register_keras_serializable("deel-lip", "HKR_multiclass_auto")
class HKR_cross_ent ():
    """Categorical cross-entropy on margin-scaled logits with a margin reward.

    Args:
        nb_class: number of classes (length of the ``margins`` variable).
        MIN_MARGIN: initial value of every entry of ``margins``.
        margin_coeff: weight of the ``-mean(margins)`` term.
    """

    def __init__(self, nb_class , MIN_MARGIN = 0.02,margin_coeff=0.2):
        self.MIN_MARGIN = MIN_MARGIN
        self.nb_class = nb_class

        self.margin_coeff = margin_coeff

        # One trainable scale per class, applied to the logits in __call__.
        self.margins = tf.Variable(np.array([self.MIN_MARGIN]*nb_class),dtype=tf.float32)



    @tf.function
    def __call__(self, y_true, y_pred):
        """Return mean cross-entropy of ``y_pred * margins`` minus ``margin_coeff * mean(margins)``."""
        return tf.reduce_mean(categorical_crossentropy(y_true, y_pred*self.margins, from_logits = True))-self.margin_coeff*tf.reduce_mean(self.margins)


    def get_config(self):
        """Return the constructor arguments.

        The previous version read ``self.alpha``, which this class never
        defines, and therefore always raised ``AttributeError``.
        """
        config = {
            "nb_class": self.nb_class,
            "MIN_MARGIN": self.MIN_MARGIN,
            "margin_coeff": self.margin_coeff,
        }
        return config
    

class HKR_multiclass_auto():
    """Multiclass hinge + KR loss with per-class trainable margins.

    Args:
        nb_class: number of classes (length of the ``margins`` variable).
        alpha: weight of the hinge part. A negative value selects
            ``alpha = 1, beta = 0``.
        beta: weight of the KR part.
        t_max: upper clip of the softmax temperature used by the soft hinge.
        t_min: lower clip of the softmax temperature used by the soft hinge.
        regul_coeff: weight of the margin regularisation term.
        soft_hinge: when true, the hinge of the non-true classes is weighted by
            a temperature softmax over their predictions; otherwise every
            non-true class gets the uniform weight ``1 / (nb_class - 1)``.
        soft_KR: when true, the "other classes" expectation of the KR term is a
            softmax-weighted sum over the batch instead of a plain mean.
        init_margin: initial value of every margin; defaults to ``min_margin``.
        min_margin: stored, and used as the default for ``init_margin``.
        verbose_batch: stored; not used by the current implementation.
        verbose: stored; not used by the current implementation.
    """

    def __init__(self, nb_class ,
                 alpha = 1,
                 beta = 1,
                 t_max = 10.,
                 t_min = .3,
                 regul_coeff = 0,
                 soft_hinge = True,
                 soft_KR = False,
                 init_margin = None,
                 min_margin = 0.02,
                 verbose_batch = 100,
                 verbose = True):

        self.min_margin = min_margin
        self.nb_class = nb_class
        self.t_max = t_max
        self.t_min = t_min
        self.regul_coeff = regul_coeff
        self.verbose_batch = verbose_batch
        if init_margin is None :
            init_margin = min_margin
        self.init_margin = init_margin
        self.margins = tf.Variable(np.array([self.init_margin]*nb_class),dtype=tf.float32)
        if alpha<0 :
            alpha =1
            beta=0
        tf.print("loss start ",output_stream=sys.stdout)
        self.alpha = alpha
        self.beta = beta
        self.eps = 1e-8
        self.soft_KR = soft_KR
        # Previously accepted but dropped; now honoured in __call__.
        self.soft_hinge = soft_hinge
        self.__name__ = "HKR_multiclass_auto"
        self.verbose  = verbose



    @tf.function
    def __call__(self, y_true, y_pred):
        """Compute the loss.

        Args:
            y_true: labels; assumes one-hot rows of shape (batch, nb_class)
                -- TODO confirm against callers.
            y_pred: model outputs, same shape as ``y_true``.

        Returns:
            Scalar ``alpha * (hinge - regul_coeff * regul) - beta * KR``.
        """
        # Per-class mean prediction over the samples of that class.
        espYtrue = tf.reduce_sum(y_pred * y_true, axis=0) / (tf.reduce_sum(y_true, axis=0)+self.eps)
        H1 = tf.where(y_true==1,tf.reduce_min(y_pred), y_pred) ## set y_true at minimum on batch to avoid being the max
        maxOthers = tf.reduce_max(H1, axis=1)  # keep only not y_true max value
        H2 = tf.where(H1==tf.expand_dims(maxOthers,1),tf.reduce_min(H1), H1)
        maxOthers_2 = tf.reduce_max(H2, axis=1)  # keep only the second max value
        if self.soft_KR:
            H_soft = tf.nn.softmax(tf.where(y_true == 1,-tf.float32.max, y_pred), axis=0)
            espNotYtrue = tf.reduce_sum(H_soft * H1, axis=0)
        else:
            espNotYtrue = tf.reduce_sum(y_pred * (1 - y_true), axis=0) / (
                tf.cast(tf.shape(y_true)[0], dtype=tf.float32)
                - tf.reduce_sum(y_true, axis=0)+self.eps)

        # compute the differences to have the KR term for each output neuron, and compute the average over the classes
        KR_loss = tf.reduce_mean((-espNotYtrue + espYtrue)*tf.reduce_max(y_true, axis=0))

        vMargin = tf.reduce_sum(self.margins * y_true, axis=1, keepdims=True)
        vYtrue = tf.reduce_sum(y_pred * y_true, axis=1, keepdims=True)

        if self.soft_hinge:
            # Temperature grows when the two best "other" classes are close.
            diffs = tf.abs(maxOthers-maxOthers_2)
            temperatures = 1.2/(diffs+self.eps)
            temperatures = tf.stop_gradient(tf.clip_by_value(temperatures, clip_value_min=self.t_min, clip_value_max = self.t_max))
            y_pred_temperature = tf.multiply(y_pred, temperatures[...,tf.newaxis] )
            # The true class is excluded from the softmax.
            y_pred_temperature = tf.where(y_true==1,-tf.float32.max,y_pred_temperature)
            F_soft = tf.nn.softmax(y_pred_temperature,axis = 1)
        else:
            # Uniform weight on every non-true class.
            F_soft = tf.where(y_true==1,0.,1./(self.nb_class-1))

        hinge = tf.nn.relu(vMargin/2-vYtrue)+ tf.nn.relu(vMargin/2+y_pred)
        # The per-class weights were computed but never applied, leaving the
        # hinge as an unweighted (batch, nb_class) matrix. Weight and reduce it
        # the same way HKR_multiclass_auto_test does.
        hinge = tf.reduce_mean(tf.reduce_sum(hinge * F_soft, axis=1))

        real_classes = tf.reduce_max(y_true,axis = 0)
        # Mean margin over the classes present in the batch.
        regul = 1/tf.reduce_sum(real_classes)*tf.reduce_sum(self.margins*real_classes)
        loss_val = self.alpha *( hinge-self.regul_coeff*regul) - self.beta*KR_loss

        return loss_val


    def get_config(self):
        """Return a dict with ``alpha`` and the ``margins`` variable."""
        config = {"alpha": self.alpha, "min_margin": self.margins}
        return config
    
    
    
class HKR_multiclass_auto_test ():
    """Experimental multiclass hinge + KR loss with per-class trainable margins.

    Options select between an element-wise hinge (``old_version``), a
    softmax-weighted or uniformly weighted hinge (``soft_hinge``), a centred or
    difference-based hinge (``centered``) and a plain or softmax-weighted KR
    term (``soft_KR``). ``out_factor`` bounds the softmax temperatures and
    ``t_coeff`` scales them. A negative ``alpha`` selects ``alpha=1, beta=0``;
    ``inv_coeff`` selects ``alpha=1, beta=1/alpha``.
    """

    def __init__(self, nb_class , alpha = 1, beta = 1,out_factor = 1., t_coeff = 1.2, inv_coeff = False, old_version = False, centered = True,soft_hinge = True, soft_KR = False,init_margin = None, min_margin = 0.02,margin_coeff=0.2,verbose = True):
        self.min_margin = min_margin
        self.nb_class = nb_class
        self.margin_coeff = margin_coeff
        self.old_version = old_version
        self.t_coeff = t_coeff
        self.centered = centered
        self.out_factor = out_factor
        self.init_margin = min_margin if init_margin is None else init_margin
        # One margin per class, clipped to [min_margin, 200] by the constraint.
        self.margins = tf.Variable(
            np.array([self.init_margin]*nb_class),
            dtype=K.floatx(),
            constraint=lambda z: tf.clip_by_value(z, min_margin, 200),
        )
        if alpha < 0:
            alpha, beta = 1, 0
        if inv_coeff:
            alpha, beta = 1., 1./alpha
        self.alpha = K.variable(alpha, name='alpha', dtype=K.floatx())
        self.beta = K.variable(beta, name='beta', dtype=K.floatx())
        self.eps = 1e-8
        # Call counter incremented at every evaluation.
        self.nb = K.variable(0, name='nb')
        self.soft_KR = soft_KR
        self.soft_hinge = soft_hinge
        self.__name__ = "HKR_multiclass_auto"
        self.verbose = verbose

    @tf.function
    def __call__(self, y_true, y_pred):
        """Return ``alpha * (hinge - margin_coeff * regul) - beta * KR`` (scalar)."""
        is_true = y_true==1
        present_classes = tf.reduce_max(y_true, axis=0)

        # Per-class mean prediction over the samples of that class.
        mean_on_true = tf.reduce_sum(y_pred * y_true, axis=0) / (tf.reduce_sum(y_true, axis=0)+self.eps)

        # Predictions with the true class replaced by the batch minimum, then
        # the best and second best of the remaining classes for each sample.
        others = tf.where(is_true, tf.reduce_min(y_pred), y_pred)
        best_other = tf.reduce_max(others, axis=1)
        others_without_best = tf.where(others==tf.expand_dims(best_other,1), tf.reduce_min(others), others)
        second_other = tf.reduce_max(others_without_best, axis=1)

        if self.soft_KR:
            kr_temperature = 1. /(tf.math.reduce_std(y_pred)+self.eps)
            if self.verbose:
                tf.print("temperatures kr: max : ",kr_temperature,output_stream=sys.stdout)
            kr_temperature = tf.stop_gradient(tf.clip_by_value(kr_temperature, clip_value_min=0.1, clip_value_max = self.out_factor/4))
            # Softmax over the batch axis, true entries excluded.
            kr_weights = tf.nn.softmax(tf.where(is_true, -tf.float32.max, kr_temperature*y_pred), axis=0)
            if self.verbose:
                tf.print("softmax kr: max : ",tf.reduce_max(kr_weights)," mean: ",tf.reduce_mean(kr_weights))
            mean_on_others = tf.reduce_sum(kr_weights * others, axis=0)
        else:
            batch_size = tf.cast(tf.shape(y_true)[0], dtype=K.floatx())
            mean_on_others = tf.reduce_sum(y_pred * (1 - y_true), axis=0) / (
                batch_size - tf.reduce_sum(y_true, axis=0)+self.eps)

        # KR term per output neuron, averaged over the classes.
        KR_loss = tf.reduce_mean((-mean_on_others + mean_on_true)*present_classes)

        true_margin = tf.reduce_sum(self.margins * y_true, axis=1, keepdims=True)
        gap = tf.abs(best_other-second_other)
        temperatures = self.t_coeff /(gap+self.eps)
        if self.verbose:
            tf.print("temperatures : max : ",tf.reduce_max(temperatures), "mean: ", tf.reduce_mean(temperatures),output_stream=sys.stdout)
        temperatures = tf.stop_gradient(tf.clip_by_value(temperatures, clip_value_min=0.1, clip_value_max = self.out_factor))

        if self.old_version:
            # Element-wise hinge with +1/-1 signs.
            sign = tf.where(is_true, 1.0, -1.0)
            hinge = tf.reduce_mean(tf.nn.relu(self.margins - sign * y_pred))
        else:
            true_value = tf.reduce_sum(y_pred * y_true, axis=1, keepdims=True)
            if self.soft_hinge:
                scaled = tf.multiply(y_pred, temperatures[...,tf.newaxis] )
                scaled = tf.where(is_true, -tf.float32.max, scaled)
                class_weights = tf.stop_gradient(tf.nn.softmax(scaled,axis = 1))
                if self.verbose:
                    tf.print("softmax : max : ",tf.reduce_max(class_weights)," mean: ",tf.reduce_mean(class_weights)," mean_max ",tf.reduce_mean(tf.reduce_max(class_weights,axis=1)),output_stream=sys.stdout)
            else:
                # 0 on the true class, uniform weight where y_true == 0.
                class_weights = tf.where(y_true==0, 1./(self.nb_class-1.), tf.where(is_true, 0., y_pred))
            if self.centered:
                per_class_hinge = tf.nn.relu(true_margin/2-true_value)+ tf.nn.relu(true_margin/2+y_pred)
            else:
                per_class_hinge = tf.nn.relu(true_margin-true_value+y_pred)
            hinge = tf.reduce_mean(tf.reduce_sum(per_class_hinge *class_weights ,axis=1))

        # Mean margin over the classes present in the batch.
        regul = 1/tf.reduce_sum(present_classes)*tf.reduce_sum(self.margins*present_classes)
        loss_val = self.alpha *( hinge-self.margin_coeff*regul) - self.beta*KR_loss
        self.nb.assign(self.nb+1)

        return loss_val

    def get_config(self):
        """Return a dict with the ``alpha`` and ``margins`` variables."""
        return {"alpha": self.alpha, "min_margin": self.margins}


class HKR_multiclass_auto_testv2 ():
    """Experimental multiclass hinge + KR loss (v2) with fixed softmax scales.

    ``out_factor`` scales the predictions before the hinge softmax and
    ``KR_coeff`` scales them before the KR softmax. A negative ``alpha``
    selects ``alpha=1, beta=0``; ``inv_coeff`` selects ``alpha=1, beta=1/alpha``.
    """

    def __init__(self, nb_class , alpha = 1, beta = 1,out_factor = 1., KR_coeff = 1., inv_coeff = False, old_version = False, centered = True,soft_hinge = True, soft_KR = False,init_margin = None, min_margin = 0.02,margin_coeff=0.2,verbose = True):
        self.min_margin = min_margin
        self.nb_class = nb_class
        self.margin_coeff = margin_coeff
        self.old_version = old_version
        self.KR_coeff = KR_coeff
        self.centered = centered
        self.out_factor = out_factor
        self.init_margin = min_margin if init_margin is None else init_margin
        # One margin per class, clipped to [min_margin, 200] by the constraint.
        self.margins = tf.Variable(
            np.array([self.init_margin]*nb_class),
            dtype=K.floatx(),
            constraint=lambda z: tf.clip_by_value(z, min_margin, 200),
        )
        if alpha < 0:
            alpha, beta = 1, 0
        if inv_coeff:
            alpha, beta = 1., 1./alpha
        self.alpha = K.variable(alpha, name='alpha', dtype=K.floatx())
        self.beta = K.variable(beta, name='beta', dtype=K.floatx())
        self.eps = 1e-8
        # Call counter incremented at every evaluation.
        self.nb = K.variable(0, name='nb')
        self.soft_KR = soft_KR
        self.soft_hinge = soft_hinge
        self.__name__ = "HKR_multiclass_auto_testv2"
        self.verbose = verbose

    @tf.function
    def __call__(self, y_true, y_pred):
        """Return ``alpha * (hinge - margin_coeff * regul) - beta * KR`` (scalar)."""
        is_true = y_true==1
        present_classes = tf.reduce_max(y_true, axis=0)

        # Per-class mean prediction over the samples of that class.
        mean_on_true = tf.reduce_sum(y_pred * y_true, axis=0) / (tf.reduce_sum(y_true, axis=0)+self.eps)

        if self.soft_KR:
            # Predictions with the true class replaced by the batch minimum.
            others = tf.where(is_true, tf.reduce_min(y_pred), y_pred)
            # Softmax over the batch axis, true entries excluded.
            kr_weights = tf.stop_gradient(tf.nn.softmax(tf.where(is_true, -tf.float32.max, self.KR_coeff*y_pred), axis=0))
            if self.verbose:
                tf.print("softmax kr: max : ",tf.reduce_max(kr_weights)," mean: ",tf.reduce_mean(kr_weights))
            mean_on_others = tf.reduce_sum(kr_weights * others, axis=0)
        else:
            batch_size = tf.cast(tf.shape(y_true)[0], dtype=K.floatx())
            mean_on_others = tf.reduce_sum(y_pred * (1 - y_true), axis=0) / (
                batch_size - tf.reduce_sum(y_true, axis=0)+self.eps)

        # KR term per output neuron, averaged over the classes.
        KR_loss = tf.reduce_mean((-mean_on_others + mean_on_true)*present_classes)

        true_margin = tf.reduce_sum(self.margins * y_true, axis=1, keepdims=True)
        if self.old_version:
            # Element-wise hinge with +1/-1 signs.
            sign = tf.where(is_true, 1.0, -1.0)
            hinge = tf.reduce_mean(tf.nn.relu(self.margins - sign * y_pred))
        else:
            true_value = tf.reduce_sum(y_pred * y_true, axis=1, keepdims=True)
            if self.soft_hinge:
                scaled = tf.where(is_true, -tf.float32.max, y_pred*self.out_factor)
                class_weights = tf.stop_gradient(tf.nn.softmax(scaled,axis = 1))
                if self.verbose:
                    tf.print("softmax : max : ",tf.reduce_max(class_weights)," mean: ",tf.reduce_mean(class_weights)," mean_max ",tf.reduce_mean(tf.reduce_max(class_weights,axis=1)),output_stream=sys.stdout)
            else:
                # 0 on the true class, uniform weight everywhere else.
                class_weights = tf.where(is_true, 0., 1./(self.nb_class-1))
            if self.centered:
                per_class_hinge = tf.nn.relu(true_margin/2-true_value)+ tf.nn.relu(true_margin/2+y_pred)
            else:
                per_class_hinge = tf.nn.relu(true_margin-true_value+y_pred)
            hinge = tf.reduce_mean(tf.reduce_sum(per_class_hinge *class_weights ,axis=1))

        # Mean margin over the classes present in the batch.
        regul = 1/tf.reduce_sum(present_classes)*tf.reduce_sum(self.margins*present_classes)
        loss_val = self.alpha *( hinge-self.margin_coeff*regul) - self.beta*KR_loss
        self.nb.assign(self.nb+1)

        return loss_val

    def get_config(self):
        """Return a dict with the ``alpha`` and ``margins`` variables."""
        return {"alpha": self.alpha, "min_margin": self.margins}


class HKR_multiclass_auto_testv3():
    """Experimental multiclass hinge + KR loss (v3) computed column by column.

    Uses a single scalar trainable margin. Each class column gets its own
    weighted hinge/KR value; columns are then combined with ``class_coeff``.

    Args:
        nb_class: number of classes.
        soft_kr: when true, the KR term uses the same softmax weights as the
            hinge; otherwise it uses uniform weights per column.
        alpha: weight of the hinge part. A negative value selects
            ``alpha = 1, beta = 0``.
        beta: weight of the KR part.
        temperature: scale applied to the predictions before the softmax.
        init_margin: initial margin; defaults to ``min_margin``.
        min_margin: lower bound of the clipping constraint on the margin
            (the upper bound is 200).
        margin_coeff: weight of the margin reward term.
        verbose: stored; not used by the current implementation.
    """

    def __init__(self, nb_class , soft_kr = True, alpha = 1, beta = 1,temperature = 1.,init_margin = None, min_margin = 0.02,margin_coeff=0.2,verbose = True):
        self.min_margin = min_margin
        self.nb_class = nb_class
        self.margin_coeff = margin_coeff
        self.temperature = temperature
        self.soft_kr = soft_kr
        if init_margin is None :
            init_margin = min_margin
        self.init_margin = init_margin
        self.margins = tf.Variable(self.init_margin,dtype=K.floatx(),constraint=lambda z: tf.clip_by_value(z, min_margin, 200))
        if alpha<0 :
            alpha =1
            beta=0
        self.alpha = K.variable(alpha, name='alpha', dtype=K.floatx())
        self.beta = K.variable(beta, name='beta', dtype=K.floatx())
        self.eps = 1e-8
        # Call counter incremented at every evaluation.
        self.nb = K.variable(0, name='nb')
        self.__name__ = "HKR_multiclass_auto_testv3"
        self.verbose  = verbose



    @tf.function
    def __call__(self, y_true, y_pred):
        """Compute the scalar loss.

        Args:
            y_true: labels; assumes one-hot rows of shape (batch, nb_class)
                -- TODO confirm against callers.
            y_pred: model outputs, same shape as ``y_true``.
        """
        # number of positive examples per class/column
        row_sum = tf.reduce_sum(y_true, axis = 0)
        # number of negative examples per class/column. This used to be
        # `y_true - row_sum`, which is not a count and produced weights of
        # -1/row_sum when soft_kr is False.
        row_neg = tf.cast(tf.shape(y_true)[0], row_sum.dtype) - row_sum
        # 1 for the classes that have at least one positive example
        row_max = tf.reduce_max(y_true, axis = 0)
        # number of classes represented in the batch
        nb_rep = tf.reduce_sum(row_max)

        # column importance: non represented classes get a reduced weight
        class_coeff =  tf.where(row_max == 0,nb_rep/(2.*(self.nb_class)),1.)
        class_coeff = class_coeff/tf.reduce_sum(class_coeff)

        # true entries are excluded from the softmax
        y_pred_temperature = tf.where(y_true==1,-tf.float32.max,y_pred*self.temperature/self.margins)
        # softmax per column (over the batch axis)
        F_soft = tf.nn.softmax(y_pred_temperature,axis = 0)
        # true entries get the uniform weight 1/row_sum
        F_soft = tf.stop_gradient(tf.where(y_true==1,1./row_sum,F_soft) )
        if self.soft_kr :
            F_soft_KR = F_soft
        else :
            # uniform weight 1/row_neg on the negative entries
            F_soft_KR = tf.stop_gradient(tf.where(y_true==0,1./row_neg,F_soft) )
        # kr computation per entry: +y_pred on true entries, -y_pred elsewhere
        KR = tf.where(y_true==0,-y_pred,y_pred)
        # hinge computation per entry
        hinge = tf.nn.relu(self.margins/2 - KR)
        # loss_val per column/class
        loss_val = tf.reduce_sum(F_soft*self.alpha*hinge - self.beta*F_soft_KR*KR,axis = 0)
        # margin regularisation
        loss_val = loss_val -self.margin_coeff*self.alpha*self.margins
        # weighted sum over the classes
        loss_val = tf.reduce_sum(loss_val*class_coeff)
        self.nb.assign(self.nb+1)

        return loss_val


class HKR_multiclass_auto_testv4():
    """Experimental multiclass hinge + KR loss (v4): column loss plus a row hinge.

    Uses a single scalar trainable margin. The column-wise part weights every
    entry uniformly inside its class column; an additional per-sample hinge
    compares the true-class output with the best other class.

    Args:
        nb_class: number of classes.
        alpha: weight of the hinge parts. A negative value selects
            ``alpha = 1, beta = 0``.
        beta: weight of the KR part.
        init_margin: initial margin; defaults to ``min_margin``.
        min_margin: lower bound of the clipping constraint on the margin
            (the upper bound is 200).
        empty_const: selects the weighting of the classes absent from the
            batch and halves the margin used by the per-sample hinge.
        margin_coeff: weight of the margin reward term.
        verbose: stored; not used by the current implementation.
    """

    def __init__(self, nb_class ,  alpha = 1, beta = 1,init_margin = None, min_margin = 0.02,empty_const = False, margin_coeff=0.2,verbose = True):
        self.min_margin = min_margin
        self.nb_class = nb_class
        self.margin_coeff = margin_coeff
        if init_margin is None :
            init_margin = min_margin
        self.init_margin = init_margin
        self.margins = tf.Variable(self.init_margin,dtype=K.floatx(),constraint=lambda z: tf.clip_by_value(z, min_margin, 200))
        if alpha<0 :
            alpha =1
            beta=0
        self.empty_const = empty_const
        self.alpha = K.variable(alpha, name='alpha', dtype=K.floatx())
        self.beta = K.variable(beta, name='beta', dtype=K.floatx())
        self.eps = 1e-8
        # Call counter incremented at every evaluation.
        self.nb = K.variable(0, name='nb')
        self.__name__ = "HKR_multiclass_auto_testv4"
        self.verbose  = verbose



    @tf.function
    def __call__(self, y_true, y_pred):
        """Compute the scalar loss.

        Args:
            y_true: labels; assumes one-hot rows of shape (batch, nb_class)
                -- TODO confirm against callers.
            y_pred: model outputs, same shape as ``y_true``.
        """
        # number of positive examples per class/column
        row_sum = tf.reduce_sum(y_true, axis = 0)
        # number of negative examples per class/column
        row_neg = tf.cast(tf.shape(y_true)[0],tf.float32) - row_sum
        # 1 for the classes that have at least one positive example
        row_max = tf.reduce_max(y_true, axis = 0)
        # number of classes represented in the batch
        nb_rep = tf.reduce_sum(row_max)

        # column importance: represented classes weigh their positive count,
        # absent classes get a value depending on empty_const
        if self.empty_const:
            class_coeff =  tf.where(row_max == 0,1./(self.nb_class-nb_rep),row_sum)
        else :
            class_coeff =  tf.where(row_max == 0,nb_rep/(2.*(self.nb_class-nb_rep)),row_sum)
        class_coeff = class_coeff/tf.reduce_sum(class_coeff)
        # uniform weights per column: 1/row_sum on positives, 1/row_neg on negatives
        F_coeff = tf.where(y_true==1,1./row_sum,y_pred)
        F_coeff = tf.where(y_true==0,1./row_neg,F_coeff)
        # kr computation per entry: +y_pred on true entries, -y_pred elsewhere
        KR = tf.where(y_true==0,-y_pred,y_pred)
        # hinge computation per entry
        hinge = tf.nn.relu(self.margins/2 - KR)
        # loss_val per column/class
        loss_val = tf.reduce_sum(F_coeff*(self.alpha*hinge - self.beta*KR),axis = 0)
        # margin regularisation
        loss_val = loss_val -self.margin_coeff*self.alpha*self.margins
        # weighted sum over the classes
        loss_val = tf.reduce_sum(loss_val*class_coeff)

        class_val = tf.reduce_sum(y_pred * y_true, axis=1)  ## keep only y_true value
        # Kept at shape (batch,) like class_val and max_others. With
        # keepdims=True the (batch, 1) margin broadcast against the (batch,)
        # vectors and built a (batch, batch) matrix before the mean.
        class_margin = tf.reduce_sum(self.margins * y_true, axis=1)
        H1 = tf.where(y_true==1,tf.reduce_min(y_pred), y_pred) ## set y_true at minimum on batch to avoid being the max
        max_others = tf.reduce_max(H1, axis=1)  # keep only not y_true max value
        coeff = 1.
        if self.empty_const:
            coeff = 2.
        # per-sample hinge between the true class and the best other class
        hinge_cols = tf.reduce_mean(tf.nn.relu(class_margin/(2*coeff)-class_val)+ tf.nn.relu(class_margin/(2*coeff)+max_others))
        loss_val += hinge_cols*self.alpha*coeff
        self.nb.assign(self.nb+1)

        return loss_val


class HKR_multiclass_auto_testv5():
    """Experimental multiclass hinge + KR loss (v5): column loss plus cross-entropy.

    Uses a single scalar trainable margin. The column-wise hinge/KR value is
    completed by a categorical cross-entropy (from logits) evaluated on
    ``y_pred * temperature`` and weighted by ``alpha``. A negative ``alpha``
    selects ``alpha=1, beta=0``.
    """

    def __init__(self, nb_class ,  alpha = 1, beta = 1,init_margin = None, min_margin = 0.02, margin_coeff=0.2,temperature =100,verbose = True):
        self.min_margin = min_margin
        self.nb_class = nb_class
        self.margin_coeff = margin_coeff
        self.init_margin = min_margin if init_margin is None else init_margin
        # Scalar margin, clipped to [min_margin, 200] by the constraint.
        self.margins = tf.Variable(
            self.init_margin,
            dtype=K.floatx(),
            constraint=lambda z: tf.clip_by_value(z, min_margin, 200),
        )
        if alpha < 0:
            alpha, beta = 1, 0
        self.temperature = temperature
        self.alpha = K.variable(alpha, name='alpha', dtype=K.floatx())
        self.beta = K.variable(beta, name='beta', dtype=K.floatx())
        self.eps = 1e-8
        # Call counter incremented at every evaluation.
        self.nb = K.variable(0, name='nb')
        self.regul_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
        self.__name__ = "HKR_multiclass_auto_testv5"
        self.verbose = verbose

    @tf.function
    def __call__(self, y_true, y_pred):
        """Return the scalar loss for a batch of labels and predictions."""
        # Positive / negative example counts per class column.
        positives = tf.reduce_sum(y_true, axis = 0)
        negatives = tf.cast(tf.shape(y_true)[0],tf.float32) - positives
        # 1 for the classes present in the batch, and how many there are.
        present = tf.reduce_max(y_true, axis = 0)
        nb_present = tf.reduce_sum(present)

        # Column importance, normalised to sum to one.
        column_weight = tf.where(present == 0,1./(self.nb_class-nb_present),positives)
        column_weight = column_weight/tf.reduce_sum(column_weight)

        # Uniform weights inside each column: 1/positives on entries equal to
        # 1, 1/negatives on entries equal to 0, y_pred anywhere else.
        entry_weight = tf.where(
            y_true==0,
            1./negatives,
            tf.where(y_true==1,1./positives,y_pred),
        )

        # Signed prediction: +y_pred on true entries, -y_pred on zeros.
        signed_pred = tf.where(y_true==0,-y_pred,y_pred)
        entry_hinge = tf.nn.relu(self.margins/2 - signed_pred)

        # Per-column value, margin reward, then weighted sum over the columns.
        per_column = tf.reduce_sum(entry_weight*(self.alpha*entry_hinge - self.beta*signed_pred),axis = 0)
        per_column = per_column -self.margin_coeff*self.alpha*self.margins
        total = tf.reduce_sum(per_column*column_weight)

        # Cross-entropy on temperature-scaled logits.
        total += self.alpha*self.regul_loss(y_true,y_pred*self.temperature)
        self.nb.assign(self.nb+1)

        return total

class HKR_multiclass_auto_testv6():
    """Experimental multiclass hinge/KR loss with a single margin variable.

    Each class (column) contributes a balanced hinge/KR term; the per-class
    values are averaged with equal weight.
    """

    def __init__(self, nb_class, alpha=1, beta=1, init_margin=None, min_margin=0.02,
                 margin_coeff=0.2, temperature=100, verbose=True):
        """Store hyper-parameters and create the loss variables.

        Args:
            nb_class: number of classes.
            alpha: hinge weight. A negative value is replaced by ``alpha=1``
                and forces ``beta=0`` (the KR term is multiplied by zero).
            beta: KR weight.
            init_margin: initial margin; defaults to ``min_margin``.
            min_margin: lower bound used by the margin variable's constraint.
            margin_coeff: weight of the margin reward term.
            temperature: stored; not read by ``__call__``.
            verbose: stored; not read by ``__call__``.
        """
        self.min_margin = min_margin
        self.nb_class = nb_class
        self.margin_coeff = margin_coeff
        self.init_margin = min_margin if init_margin is None else init_margin
        # Scalar margin, declared with a constraint clipping it to [min_margin, 200].
        self.margins = tf.Variable(
            self.init_margin,
            dtype=K.floatx(),
            constraint=lambda z: tf.clip_by_value(z, min_margin, 200),
        )
        if alpha < 0:
            alpha = 1
            beta = 0
        self.temperature = temperature
        self.alpha = K.variable(alpha, name='alpha', dtype=K.floatx())
        self.beta = K.variable(beta, name='beta', dtype=K.floatx())
        self.eps = 1e-8
        # Counter incremented on every executed call.
        self.nb = K.variable(0, name='nb')
        # Kept as an attribute for compatibility; __call__ does not use it.
        self.regul_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
        self.__name__ = "HKR_multiclass_auto_testv6"
        self.verbose = verbose

    @tf.function
    def __call__(self, y_true, y_pred):
        """Return the scalar loss for one batch.

        Args:
            y_true: label tensor compared element-wise against 0 and 1
                (assumed one-hot, shape (batch, nb_class) -- TODO confirm).
            y_pred: model outputs with the same shape as ``y_true``.

        Returns:
            Scalar loss tensor (mean of the per-class losses).

        Side effects:
            Increments ``self.nb``.
        """
        # Per-class counts of positive and negative samples in the batch.
        row_sum = tf.reduce_sum(y_true, axis=0)
        row_neg = tf.cast(tf.shape(y_true)[0], tf.float32) - row_sum

        # Per-cell weights: 1/positives on positive cells, 1/negatives on
        # negative cells; any other label value falls back to y_pred.
        F_coeff = tf.where(y_true == 1, 1. / row_sum, y_pred)
        F_coeff = tf.where(y_true == 0, 1. / row_neg, F_coeff)

        # Signed prediction: sign flipped wherever the label is 0.
        KR = tf.where(y_true == 0, -y_pred, y_pred)
        hinge = tf.nn.relu(self.margins / 2 - KR)

        # Per-class loss, then the margin reward term.
        loss_val = tf.reduce_sum(F_coeff * (self.alpha * hinge - self.beta * KR), axis=0)
        loss_val = loss_val - self.margin_coeff * self.alpha * self.margins

        # Unweighted mean over classes (no class weighting is applied here).
        loss_val = tf.reduce_mean(loss_val)

        self.nb.assign(self.nb + 1)
        return loss_val


class HKR_multiclass_auto_testv7():
    """Experimental multiclass hinge/KR loss with column-wise softmax weighting.

    Negative cells are weighted by a softmax taken over the batch axis of the
    temperature-scaled predictions; positive cells get a fixed weight.
    """

    def __init__(self, nb_class, alpha=1, beta=1, temperature=1., alpha_mean=0.99,
                 balance=0.5, auto_margin=False, soft_hinge=False, init_margin=None,
                 weight_pos=1., multi_margin=True, min_margin=0.02, margin_coeff=0.2,
                 verbose=True):
        """Store hyper-parameters and create the loss variables.

        Args:
            nb_class: number of classes.
            alpha: hinge weight. A negative value is replaced by ``alpha=1``
                and forces ``beta=0``.
            beta: KR weight.
            temperature: numerator of the per-class softmax temperature.
            alpha_mean: decay of the moving mean of ``|y_pred|``.
            balance: share of the margin applied to positive cells; negatives
                get ``1 - balance``.
            auto_margin: if true, margins are overwritten on each call from the
                moving mean (this forces ``multi_margin=True``).
            soft_hinge: if true, use the module-level ``soft_hinge`` function
                instead of a plain ReLU hinge.
            init_margin: initial margin; defaults to ``min_margin``.
            weight_pos: hinge weight of positive cells; negatives are scaled by
                ``2 - weight_pos``.
            multi_margin: one margin per class if true, else a scalar margin.
            min_margin: lower bound for the margins.
            margin_coeff: margin reward weight, or the mean-to-margin factor
                when ``auto_margin`` is true.
            verbose: if true, debug values are printed with ``tf.print``.
        """
        self.min_margin = min_margin
        self.nb_class = nb_class
        self.margin_coeff = margin_coeff
        self.temperature = temperature
        self.init_margin = min_margin if init_margin is None else init_margin
        self.alpha_mean = alpha_mean
        self.auto_margin = auto_margin
        if auto_margin:
            # The automatic update assigns a per-class vector.
            multi_margin = True
        self.soft_hinge = soft_hinge
        self.weight_pos = weight_pos
        self.balance = balance
        margin_init = [self.init_margin] * nb_class if multi_margin else self.init_margin
        self.margins = tf.Variable(
            margin_init,
            dtype=K.floatx(),
            constraint=lambda z: tf.clip_by_value(z, min_margin, 200),
        )
        # Per-class moving mean of |y_pred|.
        self.moving_mean = tf.Variable(
            [self.init_margin] * nb_class,
            dtype=K.floatx(),
            constraint=lambda z: tf.clip_by_value(z, 0.005, 1000),
        )
        if alpha < 0:
            alpha = 1
            beta = 0
        self.alpha = K.variable(alpha, name='alpha', dtype=K.floatx())
        self.beta = K.variable(beta, name='beta', dtype=K.floatx())
        self.eps = 1e-8
        # Counter incremented on every executed call.
        self.nb = K.variable(0, name='nb')
        self.__name__ = "HKR_multiclass_auto_testv7"
        self.verbose = verbose

    @tf.function
    def __call__(self, y_true, y_pred):
        """Return the scalar loss for one batch.

        Args:
            y_true: label tensor compared element-wise against 0 and 1
                (assumed one-hot, shape (batch, nb_class) -- TODO confirm).
            y_pred: model outputs with the same shape as ``y_true``.

        Returns:
            Scalar loss tensor (sum of the per-class losses).

        Side effects:
            Updates ``self.moving_mean``, optionally ``self.margins``, and
            increments ``self.nb``. Prints debug values when ``verbose``.
        """
        # Per-class counts of positive and negative samples in the batch.
        row_sum = tf.reduce_sum(y_true, axis=0)
        row_neg = tf.cast(tf.shape(y_true)[0], tf.float32) - row_sum

        # Exponential moving average of the per-class mean |y_pred|.
        current_mean = tf.reduce_mean(tf.abs(y_pred), axis=0)
        current_mean = self.alpha_mean * self.moving_mean + (1 - self.alpha_mean) * current_mean
        self.moving_mean.assign(current_mean)
        if self.auto_margin:
            self.margins.assign(
                tf.clip_by_value(current_mean * self.margin_coeff, self.min_margin, 500))
        if self.verbose:
            tf.print("ecoeff ", self.soft_hinge, " mean", current_mean)

        class_coeffs = tf.where(y_true == 1, 1., 1. / (self.nb_class - 1))

        # Per-class temperature; positive cells are pushed to -max so they
        # receive (almost) no softmax weight.
        curr_temp = tf.clip_by_value(self.temperature / current_mean, 0.005, 250)
        y_pred_temperature = tf.where(y_true == 1, -tf.float32.max, y_pred * curr_temp)

        # Softmax along the batch axis, i.e. per class column.
        F_soft = tf.nn.softmax(y_pred_temperature, axis=0)
        if self.verbose:
            tf.print("softmax : ", tf.reduce_mean((tf.reduce_max(F_soft, axis=0))))
        F_soft = F_soft * class_coeffs * row_neg

        # Weights are treated as constants for differentiation.
        F_soft_hinge = tf.stop_gradient(
            tf.where(y_true == 1, self.weight_pos, (2. - self.weight_pos) * F_soft))
        F_soft_KR = tf.stop_gradient(tf.where(y_true == 1, 1., F_soft))

        # Signed prediction: sign flipped wherever the label is 0.
        KR = tf.where(y_true == 0, -y_pred, y_pred)

        # Margin is split between positives (balance) and negatives (1 - balance).
        var_x = tf.where(
            y_true == 1,
            self.margins * self.balance - KR,
            self.margins * (1 - self.balance) - KR,
        )
        if self.soft_hinge:
            hinge = soft_hinge(var_x, 0.9 * self.margins / 2.)
        else:
            hinge = tf.nn.relu(var_x)

        # Per-class loss.
        loss_val = tf.reduce_mean(
            F_soft_hinge * self.alpha * hinge - self.beta * F_soft_KR * KR, axis=0)
        # Margin reward; skipped when margins are set automatically.
        if not self.auto_margin:
            loss_val = loss_val - self.margin_coeff * self.alpha * self.margins * tf.reduce_mean(F_soft_hinge, axis=0)
        if self.verbose:
            tf.print("margin", self.margins)

        loss_val = tf.reduce_sum(loss_val)
        self.nb.assign(self.nb + 1)
        return loss_val

class HKR_multiclass_testv7_Vanilla():
    """Experimental multiclass hinge/KR loss with fixed (non-softmax) weights.

    Positive cells weigh 1 and negative cells weigh ``1 / (nb_class - 1)`` in
    the KR term; the hinge term rescales those with ``weight_pos``.
    """

    def __init__(self, nb_class, alpha=1, beta=1, alpha_mean=0.99, balance=0.5,
                 auto_margin=False, soft_hinge=False, init_margin=None, weight_pos=1.,
                 multi_margin=True, min_margin=0.02, margin_coeff=0.2, verbose=True):
        """Store hyper-parameters and create the loss variables.

        Args:
            nb_class: number of classes.
            alpha: hinge weight. A negative value is replaced by ``alpha=1``
                and forces ``beta=0``.
            beta: KR weight.
            alpha_mean: decay of the moving mean of ``|y_pred|``.
            balance: share of the margin applied to positive cells; negatives
                get ``1 - balance``.
            auto_margin: if true, margins are overwritten on each call from the
                moving mean (this forces ``multi_margin=True``).
            soft_hinge: if true, use the module-level ``soft_hinge`` function
                instead of a plain ReLU hinge.
            init_margin: initial margin; defaults to ``min_margin``.
            weight_pos: hinge weight of positive cells; negatives are scaled by
                ``2 - weight_pos``.
            multi_margin: one margin per class if true, else a scalar margin.
            min_margin: lower bound for the margins.
            margin_coeff: margin reward weight, or the mean-to-margin factor
                when ``auto_margin`` is true.
            verbose: if true, debug values are printed with ``tf.print``.
        """
        self.min_margin = min_margin
        self.nb_class = nb_class
        self.margin_coeff = margin_coeff
        self.init_margin = min_margin if init_margin is None else init_margin
        self.alpha_mean = alpha_mean
        self.auto_margin = auto_margin
        if auto_margin:
            # The automatic update assigns a per-class vector.
            multi_margin = True
        self.soft_hinge = soft_hinge
        self.weight_pos = weight_pos
        self.balance = balance
        margin_init = [self.init_margin] * nb_class if multi_margin else self.init_margin
        self.margins = tf.Variable(
            margin_init,
            dtype=K.floatx(),
            constraint=lambda z: tf.clip_by_value(z, min_margin, 200),
        )
        # Per-class moving mean of |y_pred|.
        self.moving_mean = tf.Variable(
            [self.init_margin] * nb_class,
            dtype=K.floatx(),
            constraint=lambda z: tf.clip_by_value(z, 0.005, 1000),
        )
        if alpha < 0:
            alpha = 1
            beta = 0
        self.alpha = K.variable(alpha, name='alpha', dtype=K.floatx())
        self.beta = K.variable(beta, name='beta', dtype=K.floatx())
        self.eps = 1e-8
        # Counter incremented on every executed call.
        self.nb = K.variable(0, name='nb')
        self.__name__ = "HKR_multiclass_testv7_Vanilla"
        self.verbose = verbose

    @tf.function
    def __call__(self, y_true, y_pred):
        """Return the scalar loss for one batch.

        Args:
            y_true: label tensor compared element-wise against 0 and 1
                (assumed one-hot, shape (batch, nb_class) -- TODO confirm).
            y_pred: model outputs with the same shape as ``y_true``.

        Returns:
            Scalar loss tensor (sum of the per-class losses).

        Side effects:
            Updates ``self.moving_mean``, optionally ``self.margins``, and
            increments ``self.nb``. Prints debug values when ``verbose``.
        """
        # Exponential moving average of the per-class mean |y_pred|.
        current_mean = tf.reduce_mean(tf.abs(y_pred), axis=0)
        current_mean = self.alpha_mean * self.moving_mean + (1 - self.alpha_mean) * current_mean
        self.moving_mean.assign(current_mean)
        if self.auto_margin:
            self.margins.assign(
                tf.clip_by_value(current_mean * self.margin_coeff, self.min_margin, 500))
        if self.verbose:
            tf.print("ecoeff ", self.soft_hinge, " mean", current_mean)

        # Fixed cell weights: 1 on positives, 1/(nb_class - 1) on negatives.
        class_coeffs = tf.where(y_true == 1, 1., 1. / (self.nb_class - 1))
        F_soft = class_coeffs
        F_soft_hinge = tf.where(y_true == 1, self.weight_pos, (2. - self.weight_pos) * F_soft)
        F_soft_KR = tf.where(y_true == 1, 1., F_soft)

        # Signed prediction: sign flipped wherever the label is 0.
        KR = tf.where(y_true == 0, -y_pred, y_pred)

        # Margin is split between positives (balance) and negatives (1 - balance).
        var_x = tf.where(
            y_true == 1,
            self.margins * self.balance - KR,
            self.margins * (1 - self.balance) - KR,
        )
        if self.soft_hinge:
            hinge = soft_hinge(var_x, 0.9 * self.margins / 2.)
        else:
            hinge = tf.nn.relu(var_x)

        # Per-class loss.
        loss_val = tf.reduce_mean(
            F_soft_hinge * self.alpha * hinge - self.beta * F_soft_KR * KR, axis=0)
        # Margin reward; skipped when margins are set automatically.
        if not self.auto_margin:
            loss_val = loss_val - self.margin_coeff * self.alpha * self.margins * tf.reduce_mean(F_soft_hinge, axis=0)
        if self.verbose:
            tf.print("margin", self.margins)

        loss_val = tf.reduce_sum(loss_val)

        if self.verbose:
            # Fraction of cells with a non-zero hinge value (diagnostic only).
            non_zero = tf.cast(tf.math.count_nonzero(hinge), tf.float32) / tf.cast(self.nb_class * tf.shape(y_true)[0], tf.float32)
            tf.print("non_zero :", non_zero, " hinge :", tf.reduce_sum(tf.reduce_mean(F_soft_hinge * hinge, axis=0)), "KR :", tf.reduce_sum(tf.reduce_mean(F_soft_KR * KR, axis=0)))

        self.nb.assign(self.nb + 1)
        return loss_val

class HKR_multiclass_testv7_Max():
    """Experimental multiclass hinge/KR loss mixing row-wise and column-wise hinges.

    The row-wise hinge looks, per sample, at the true-class output and at the
    largest non-true output; the column-wise hinge is a fixed-weight hinge over
    every cell.
    """

    def __init__(self, nb_class, alpha=1, max_coeff=0.5, beta=1, alpha_mean=0.99,
                 balance=0.5, auto_margin=False, soft_hinge=False, init_margin=None,
                 weight_pos=1., multi_margin=True, min_margin=0.02, margin_coeff=0.2,
                 verbose=True):
        """Store hyper-parameters and create the loss variables.

        Args:
            nb_class: number of classes.
            alpha: hinge weight. A negative value is replaced by ``alpha=1``
                and forces ``beta=0``.
            max_coeff: divisor applied (as ``2 * max_coeff``) to the margin in
                the row-wise hinge.
            beta: KR weight.
            alpha_mean: decay of the moving mean of ``|y_pred|``.
            balance: share of the margin applied to positive cells in the
                column-wise hinge; negatives get ``1 - balance``.
            auto_margin: if true, margins are overwritten on each call from the
                moving mean (this forces ``multi_margin=True``).
            soft_hinge: if true, the column-wise hinge uses the module-level
                ``soft_hinge`` function instead of a plain ReLU.
            init_margin: initial margin; defaults to ``min_margin``.
            weight_pos: column hinge weight of positive cells; negatives are
                scaled by ``2 - weight_pos``.
            multi_margin: one margin per class if true, else a scalar margin.
            min_margin: floor used by the automatic margin update.
            margin_coeff: mean-to-margin factor for the automatic update.
            verbose: if true, debug values are printed with ``tf.print``.
        """
        self.min_margin = min_margin
        self.nb_class = nb_class
        self.margin_coeff = margin_coeff
        self.init_margin = min_margin if init_margin is None else init_margin
        self.alpha_mean = alpha_mean
        self.auto_margin = auto_margin
        if auto_margin:
            # The automatic update assigns a per-class vector.
            multi_margin = True
        self.soft_hinge = soft_hinge
        self.weight_pos = weight_pos
        self.balance = balance
        self.max_coeff = max_coeff
        # Unlike sibling losses, the margin variable carries no constraint here.
        margin_init = [self.init_margin] * nb_class if multi_margin else self.init_margin
        self.margins = tf.Variable(margin_init, dtype=K.floatx())
        # Per-class moving mean of |y_pred|.
        self.moving_mean = tf.Variable(
            [self.init_margin] * nb_class,
            dtype=K.floatx(),
            constraint=lambda z: tf.clip_by_value(z, 0.005, 1000),
        )
        if alpha < 0:
            alpha = 1
            beta = 0
        self.alpha = K.variable(alpha, name='alpha', dtype=K.floatx())
        self.beta = K.variable(beta, name='beta', dtype=K.floatx())
        self.eps = 1e-8
        # Counter incremented on every executed call.
        self.nb = K.variable(0, name='nb')
        self.__name__ = "HKR_multiclass_testv7_Max"
        self.verbose = verbose

    @tf.function
    def __call__(self, y_true, y_pred):
        """Return the scalar loss for one batch.

        Args:
            y_true: label tensor compared element-wise against 0 and 1
                (assumed one-hot, shape (batch, nb_class) -- TODO confirm).
            y_pred: model outputs with the same shape as ``y_true``.

        Returns:
            Scalar loss: ``alpha * (row hinge + column hinge) - beta * KR``.

        Side effects:
            Updates ``self.moving_mean``, optionally ``self.margins``, and
            increments ``self.nb``. Prints debug values when ``verbose``.
        """
        # Exponential moving average of the per-class mean |y_pred|.
        current_mean = tf.reduce_mean(tf.abs(y_pred), axis=0)
        current_mean = self.alpha_mean * self.moving_mean + (1 - self.alpha_mean) * current_mean
        total_mean = tf.reduce_mean(current_mean)
        self.moving_mean.assign(current_mean)
        if self.auto_margin:
            # NOTE: the condition below is a tensor, so this branch relies on
            # tf.function's AutoGraph conversion.
            if total_mean * self.margin_coeff <= self.min_margin:
                # Keep the average margin at min_margin while preserving the
                # relative per-class spread of the moving mean.
                diff_mean = 1 + (current_mean - total_mean) / total_mean
                self.margins.assign(self.min_margin * diff_mean)
            else:
                self.margins.assign(current_mean * self.margin_coeff)

        if self.verbose:
            tf.print("ecoeff ", self.soft_hinge, " mean", current_mean)

        # Fixed cell weights: 1 on positives, 1/(nb_class - 1) on negatives.
        F_soft_KR = tf.where(y_true == 1, 1., 1. / (self.nb_class - 1))
        # Signed prediction: sign flipped wherever the label is 0.
        KR = tf.where(y_true == 0, -y_pred, y_pred)
        kr_loss = tf.reduce_sum(tf.reduce_mean(F_soft_KR * KR, axis=0))

        # --- Row-wise hinge -------------------------------------------------
        # True-class output must exceed margin / (2 * max_coeff).
        margin_row = tf.reduce_sum(self.margins * y_true, axis=1, keepdims=True)
        true_output = tf.reduce_sum(y_pred * y_true, axis=1, keepdims=True)
        y_pred_row = tf.nn.relu(margin_row / (2 * self.max_coeff) - true_output)

        # Largest output among the cells whose label is not 1.
        opposite_values = tf.where(y_true == 1, -tf.float32.max, y_pred)
        y_pred_opposite = tf.reduce_max(opposite_values, axis=1, keepdims=True)
        # Margin of the class(es) attaining that maximum.
        margin_opposite = tf.where(opposite_values == y_pred_opposite, 1., 0.)
        margin_opposite = tf.reduce_sum(self.margins * margin_opposite, axis=1, keepdims=True)
        y_pred_temp_row = tf.nn.relu(margin_opposite / (2 * self.max_coeff) + y_pred_opposite)

        hinge_row = tf.reduce_mean(y_pred_row + y_pred_temp_row)

        # --- Column-wise hinge ----------------------------------------------
        F_soft_col = tf.where(y_true == 1, self.weight_pos, (2. - self.weight_pos) * F_soft_KR)
        var_x = tf.where(
            y_true == 1,
            self.margins * self.balance - KR,
            self.margins * (1 - self.balance) - KR,
        )
        if self.soft_hinge:
            hinge_cells = soft_hinge(var_x, 0.9 * self.margins / 2.)
        else:
            hinge_cells = tf.nn.relu(var_x)
        hinge_col = tf.reduce_sum(tf.reduce_mean(F_soft_col * hinge_cells, axis=0))

        loss_val = hinge_row * self.alpha + hinge_col * self.alpha - self.beta * kr_loss

        if self.verbose:
            # Fractions of non-zero hinge entries (diagnostic only).
            non_zero_row = tf.cast(tf.math.count_nonzero(y_pred_row) + tf.math.count_nonzero(y_pred_temp_row), tf.float32) / (2. * tf.cast(tf.shape(y_true)[0], tf.float32))
            non_zero_col = tf.cast(tf.math.count_nonzero(hinge_cells), tf.float32) / tf.cast(self.nb_class * tf.shape(y_true)[0], tf.float32)
            tf.print("margin", self.margins)
            tf.print("non_zero col :", non_zero_col, "non_zero row :", non_zero_row, "hinge col :", hinge_col, "hinge row :", hinge_row, "KR :", kr_loss)

        self.nb.assign(self.nb + 1)
        return loss_val

class HKR_multiclass_row_grad():
    """Experimental row-wise multiclass hinge/KR loss with softmax cell weights.

    For each sample, negative cells are weighted by a softmax over the
    temperature-scaled predictions; the hinge softmax additionally ignores
    cells that already satisfy the margin.
    """

    def __init__(self, nb_class,
                 alpha=1,
                 beta=1,
                 stop_gradient=True,
                 alpha_mean=0.99,
                 temperature=1.,
                 variable_temp=False,
                 variable_margin=False,
                 init_margin=None,
                 min_margin=0.02,
                 margin_coeff=0.2,
                 verbose=True):
        """Store hyper-parameters and create the loss variables.

        Args:
            nb_class: number of classes.
            alpha: hinge weight. A negative value is replaced by ``alpha=1``
                and forces ``beta=0``.
            beta: KR weight.
            stop_gradient: if true, softmax weights are treated as constants.
            alpha_mean: decay of the moving mean of ``|y_pred|``.
            temperature: softmax temperature (or its numerator when
                ``variable_temp`` is true).
            variable_temp: if true, the temperature is divided by the global
                moving mean and clipped to [0.005, 100].
            variable_margin: if true, the margin is overwritten on each call
                from the global moving mean.
            init_margin: initial margin; defaults to ``min_margin``.
            min_margin: lower bound for the margin.
            margin_coeff: mean-to-margin factor used by ``variable_margin``.
            verbose: if true, debug values are printed with ``tf.print``.
        """
        self.min_margin = min_margin
        self.nb_class = nb_class
        self.margin_coeff = margin_coeff
        if init_margin is None:
            init_margin = min_margin
        self.variable_temp = variable_temp
        self.variable_margin = variable_margin
        self.stop_gradient = stop_gradient
        self.init_margin = init_margin
        self.alpha_mean = alpha_mean
        self.temperature = temperature
        # Scalar margin, declared with a constraint clipping it to [min_margin, 200].
        self.margins = tf.Variable(
            self.init_margin,
            dtype=K.floatx(),
            constraint=lambda z: tf.clip_by_value(z, min_margin, 200),
        )
        # Per-class moving mean of |y_pred|.
        self.moving_mean = tf.Variable(
            [self.init_margin] * nb_class,
            dtype=K.floatx(),
            constraint=lambda z: tf.clip_by_value(z, 0.005, 1000),
        )
        if alpha < 0:
            alpha = 1
            beta = 0
        self.alpha = K.variable(alpha, name='alpha', dtype=K.floatx())
        self.beta = K.variable(beta, name='beta', dtype=K.floatx())
        self.eps = 1e-8
        # Counter incremented on every executed call.
        self.nb = K.variable(0, name='nb')
        # NOTE(review): this name differs from the class name; left unchanged
        # because external code may key on it -- confirm before renaming.
        self.__name__ = "HKR_multiclass_row"
        self.verbose = verbose

    @tf.function
    def __call__(self, y_true, y_pred):
        """Return the scalar loss for one batch.

        Args:
            y_true: label tensor compared element-wise against 0 and 1
                (assumed one-hot, shape (batch, nb_class) -- TODO confirm).
            y_pred: model outputs with the same shape as ``y_true``.

        Returns:
            Scalar loss tensor (mean over samples of ``alpha * hinge - beta * KR``).

        Side effects:
            Updates ``self.moving_mean``, optionally ``self.margins``, and
            increments ``self.nb``. Prints debug values when ``verbose``.
        """
        # Exponential moving average of the per-class mean |y_pred|.
        current_mean = tf.reduce_mean(tf.abs(y_pred), axis=0)
        current_mean = self.alpha_mean * self.moving_mean + (1 - self.alpha_mean) * current_mean
        self.moving_mean.assign(current_mean)
        total_mean = tf.reduce_mean(current_mean)

        if self.variable_temp:
            curr_temp = tf.clip_by_value(self.temperature / total_mean, 0.005, 100)
        else:
            curr_temp = self.temperature

        # Positive cells are pushed to -max so they get (almost) no softmax weight.
        opposite_values = tf.where(y_true == 1, -tf.float32.max, curr_temp * y_pred)

        # Row-wise softmax over the remaining cells.
        F_soft_KR = tf.nn.softmax(opposite_values)
        if self.stop_gradient:
            F_soft_KR = tf.stop_gradient(F_soft_KR)
        if self.verbose:
            # Printed before the optional margin update below.
            tf.print("margin", self.margins, "temp_kr :", curr_temp, "temp_hinge :", curr_temp, "soft_max :", tf.reduce_mean(tf.reduce_max(F_soft_KR, axis=1)))
        F_soft_KR = tf.where(y_true == 1, 1., F_soft_KR)

        # Signed prediction: sign flipped wherever the label is 0.
        KR = tf.where(y_true == 0, -y_pred, y_pred)

        if self.variable_margin:
            self.margins.assign(
                tf.clip_by_value(total_mean * self.margin_coeff, self.min_margin, 200))

        # Hinge softmax: also mask out cells that already satisfy the margin.
        opposite_values_hinge = tf.where(KR > self.margins / 2, -tf.float32.max, opposite_values)
        F_soft_hinge = tf.nn.softmax(opposite_values_hinge)
        if self.stop_gradient:
            F_soft_hinge = tf.stop_gradient(F_soft_hinge)
        F_soft_hinge = tf.where(y_true == 1, 1., F_soft_hinge)

        hinge_cells = tf.nn.relu(self.margins / 2 - KR) * F_soft_hinge
        hinge_row = tf.reduce_sum(hinge_cells, axis=1)
        kr_row = tf.reduce_sum(F_soft_KR * KR, axis=1)

        loss_val = tf.reduce_mean(hinge_row * self.alpha - self.beta * kr_row)

        if self.verbose:
            # Fraction of cells with a non-zero weighted hinge (diagnostic only).
            non_zero_row = tf.cast(tf.math.count_nonzero(hinge_cells), tf.float32) / (self.nb_class * tf.cast(tf.shape(y_true)[0], tf.float32))
            tf.print("non_zero row :", non_zero_row, "hinge row :", tf.reduce_mean(hinge_row), "KR :", tf.reduce_mean(kr_row))

        self.nb.assign(self.nb + 1)
        return loss_val




class HKR_multiclass_hinge_auto():
    """Experimental row-wise hinge/KR loss with an automatically scaled temperature.

    The softmax temperature is divided by a moving mean of ``|y_pred|`` so the
    cell weighting follows the scale of the model outputs.
    """

    def __init__(self, nb_class,
                 alpha=1,
                 beta=1,
                 stop_gradient=False,
                 alpha_mean=0.99,
                 temperature=1.,
                 margin_coeff=0.2,
                 init_margin=None,
                 min_margin=0.02,
                 verbose=True):
        """Store hyper-parameters and create the loss variables.

        Args:
            nb_class: number of classes.
            alpha: hinge weight. A negative value is replaced by ``alpha=1``
                and forces ``beta=0``.
            beta: KR weight.
            stop_gradient: if true, softmax weights are treated as constants.
            alpha_mean: decay of the moving mean of ``|y_pred|``.
            temperature: base temperature; stored multiplied by the initial
                margin.
            margin_coeff: stored; not read by ``__call__``.
            init_margin: initial margin; defaults to ``min_margin``.
            min_margin: lower bound used by the margin variable's constraint.
            verbose: if true, debug values are printed with ``tf.print``.
        """
        self.min_margin = min_margin
        self.nb_class = nb_class
        self.margin_coeff = margin_coeff
        if init_margin is None:
            init_margin = min_margin
        self.stop_gradient = stop_gradient
        self.init_margin = init_margin
        self.alpha_mean = alpha_mean
        # Pre-scaled by the initial margin; divided by the moving mean at call time.
        self.temperature = temperature * init_margin
        # Scalar margin, declared with a constraint clipping it to [min_margin, 200].
        self.margins = tf.Variable(
            self.init_margin,
            dtype=K.floatx(),
            constraint=lambda z: tf.clip_by_value(z, min_margin, 200),
        )
        # Per-class moving mean of |y_pred|.
        self.moving_mean = tf.Variable(
            [self.init_margin] * nb_class,
            dtype=K.floatx(),
            constraint=lambda z: tf.clip_by_value(z, 0.005, 1000),
        )
        if alpha < 0:
            alpha = 1
            beta = 0
        self.alpha = K.variable(alpha, name='alpha', dtype=K.floatx())
        self.beta = K.variable(beta, name='beta', dtype=K.floatx())
        self.eps = 1e-8
        # Counter incremented on every executed call.
        self.nb = K.variable(0, name='nb')
        self.__name__ = "HKR_multiclass_hinge_auto"
        self.verbose = verbose

    @tf.function
    def _update_mean(self, y_pred):
        """Update the moving mean of ``|y_pred|`` and return it.

        Returns:
            Tuple ``(per_class_mean, global_mean)`` after the update.
        """
        batch_mean = tf.cast(tf.reduce_mean(tf.abs(y_pred), axis=0), self.moving_mean.dtype)
        current_mean = self.alpha_mean * self.moving_mean + (1 - self.alpha_mean) * batch_mean
        self.moving_mean.assign(current_mean)
        total_mean = tf.reduce_mean(current_mean)
        return current_mean, total_mean

    @tf.function
    def __call__(self, y_true, y_pred):
        """Return the scalar loss for one batch.

        Args:
            y_true: label tensor compared element-wise against 0 and 1
                (assumed one-hot, shape (batch, nb_class) -- TODO confirm).
            y_pred: model outputs; cast to the dtype of ``self.margins``.

        Returns:
            Scalar loss tensor (mean over samples of ``alpha * hinge - beta * KR``).

        Side effects:
            Updates ``self.moving_mean`` and increments ``self.nb``. Prints
            debug values when ``verbose``.
        """
        y_pred = tf.cast(y_pred, self.margins.dtype)
        _, total_mean = self._update_mean(y_pred)
        # Keep the divisor away from zero.
        total_mean = tf.clip_by_value(total_mean, self.init_margin, 20000)
        current_temperature = tf.cast(
            tf.stop_gradient(tf.clip_by_value(self.temperature / total_mean, 0.005, 250)),
            y_pred.dtype,
        )

        # Positive cells are pushed to -max so they get (almost) no softmax weight.
        opposite_values = tf.where(y_true == 1, -y_pred.dtype.max, current_temperature * y_pred)
        F_soft_KR = tf.nn.softmax(opposite_values)

        if self.verbose:
            tf.print("temp_hinge :", current_temperature, "soft_max :", tf.reduce_mean(tf.reduce_max(F_soft_KR, axis=1)))
        F_soft_KR = tf.where(y_true == 1, tf.cast(1., F_soft_KR.dtype), F_soft_KR)
        if self.stop_gradient:
            F_soft_KR = tf.stop_gradient(F_soft_KR)

        # Signed prediction: sign flipped wherever the label is 0.
        KR = tf.where(y_true == 0, -y_pred, y_pred)

        hinge_cells = tf.nn.relu(tf.cast(self.margins / 2, F_soft_KR.dtype) - KR) * F_soft_KR
        hinge_row = tf.reduce_sum(hinge_cells, axis=1)
        kr_row = tf.reduce_sum(F_soft_KR * KR, axis=1)

        loss_val = tf.reduce_mean(hinge_row * self.alpha - self.beta * kr_row)

        if self.verbose:
            # Fraction of cells with a non-zero weighted hinge (diagnostic only).
            non_zero_row = tf.cast(tf.math.count_nonzero(hinge_cells), tf.float32) / (self.nb_class * tf.cast(tf.shape(y_true)[0], tf.float32))
            tf.print("non_zero row :", non_zero_row, "hinge row :", tf.reduce_mean(hinge_row), "KR :", tf.reduce_mean(kr_row))

        self.nb.assign(self.nb + 1)
        return loss_val


class HKR_multiclass_hingerow():
    """Experimental row-wise hinge/KR loss with separate softmax temperatures.

    Negative cells are weighted by a row-wise softmax; the hinge and KR terms
    each use their own fixed temperature.
    """

    def __init__(self, nb_class,
                 alpha=1,
                 beta=1,
                 stop_gradient=False,
                 alpha_mean=0.99,
                 temperature=1.,
                 temperature_KR=-1,
                 init_margin=None,
                 min_margin=0.02,
                 verbose=True):
        """Store hyper-parameters and create the loss variables.

        Args:
            nb_class: number of classes.
            alpha: hinge weight. A negative value is replaced by ``alpha=1``
                and forces ``beta=0``.
            beta: KR weight.
            stop_gradient: if true, softmax weights are treated as constants.
            alpha_mean: stored; not read by ``__call__``.
            temperature: softmax temperature of the hinge weights.
            temperature_KR: softmax temperature of the KR weights; ``-1``
                means "same as ``temperature``".
            init_margin: initial margin; defaults to ``min_margin``.
            min_margin: lower bound used by the margin variable's constraint.
            verbose: if true, debug values are printed with ``tf.print``.
        """
        self.min_margin = min_margin
        self.nb_class = nb_class
        if temperature_KR == -1:
            temperature_KR = temperature
        self.temperature_KR = temperature_KR
        if init_margin is None:
            init_margin = min_margin
        self.stop_gradient = stop_gradient
        self.init_margin = init_margin
        self.alpha_mean = alpha_mean
        self.temperature = temperature
        # Scalar margin, declared with a constraint clipping it to [min_margin, 200].
        self.margins = tf.Variable(
            self.init_margin,
            dtype=K.floatx(),
            constraint=lambda z: tf.clip_by_value(z, min_margin, 200),
        )
        if alpha < 0:
            alpha = 1
            beta = 0
        self.alpha = K.variable(alpha, name='alpha', dtype=K.floatx())
        self.beta = K.variable(beta, name='beta', dtype=K.floatx())
        self.eps = 1e-8
        # Counter incremented on every executed call.
        self.nb = K.variable(0, name='nb')
        # NOTE(review): this name differs from the class name; left unchanged
        # because external code may key on it -- confirm before renaming.
        self.__name__ = "HKR_multiclass_row"
        self.verbose = verbose

    @tf.function
    def __call__(self, y_true, y_pred):
        """Return the scalar loss for one batch.

        Args:
            y_true: label tensor compared element-wise against 0 and 1
                (assumed one-hot, shape (batch, nb_class) -- TODO confirm).
            y_pred: model outputs with the same shape as ``y_true``.

        Returns:
            Scalar loss tensor (mean over samples of ``alpha * hinge - beta * KR``).

        Side effects:
            Increments ``self.nb``. Prints debug values when ``verbose``.
        """
        # KR weights: positive cells are pushed to -max before the row softmax.
        opposite_values = tf.where(y_true == 1, -tf.float32.max, self.temperature_KR * y_pred)
        F_soft_KR = tf.nn.softmax(opposite_values)

        if self.verbose:
            tf.print("margin", self.margins, "temp_kr :", self.temperature_KR, "temp_hinge :", self.temperature, "soft_max :", tf.reduce_mean(tf.reduce_max(F_soft_KR, axis=1)))
        F_soft_KR = tf.where(y_true == 1, 1., F_soft_KR)
        if self.stop_gradient:
            F_soft_KR = tf.stop_gradient(F_soft_KR)

        # Signed prediction: sign flipped wherever the label is 0.
        KR = tf.where(y_true == 0, -y_pred, y_pred)

        # Hinge weights: same construction with the hinge temperature.
        opposite_values_hinge = tf.where(y_true == 1, -tf.float32.max, self.temperature * y_pred)
        F_soft_hinge = tf.nn.softmax(opposite_values_hinge)
        F_soft_hinge = tf.where(y_true == 1, 1., F_soft_hinge)
        if self.stop_gradient:
            F_soft_hinge = tf.stop_gradient(F_soft_hinge)

        hinge_cells = tf.nn.relu(self.margins / 2 - KR) * F_soft_hinge
        hinge_row = tf.reduce_sum(hinge_cells, axis=1)
        kr_row = tf.reduce_sum(F_soft_KR * KR, axis=1)

        loss_val = tf.reduce_mean(hinge_row * self.alpha - self.beta * kr_row)

        if self.verbose:
            # Fraction of cells with a non-zero weighted hinge (diagnostic only).
            non_zero_row = tf.cast(tf.math.count_nonzero(hinge_cells), tf.float32) / (self.nb_class * tf.cast(tf.shape(y_true)[0], tf.float32))
            tf.print("non_zero row :", non_zero_row, "hinge row :", tf.reduce_mean(hinge_row), "KR :", tf.reduce_mean(kr_row))

        self.nb.assign(self.nb + 1)
        return loss_val
class HKR_multiclass_row():
    """Experimental row-wise hinge/KR loss with a mean-scaled softmax temperature.

    Negative cells of the hinge term are weighted by a row-wise softmax whose
    temperature is divided by a moving mean of ``|y_pred|``; the KR term uses
    either the same kind of softmax weights or fixed uniform weights.
    """

    def __init__(self, nb_class, alpha=1, beta=1, alpha_mean=0.99, temperature=1.,
                 limit_temp=False, init_margin=None, weight_pos=1., soft_kr=False,
                 multi_margin=True, min_margin=0.02, margin_coeff=0.2, verbose=True):
        """Store hyper-parameters and create the loss variables.

        Args:
            nb_class: number of classes.
            alpha: hinge weight. A negative value is replaced by ``alpha=1``
                and forces ``beta=0``.
            beta: KR weight.
            alpha_mean: decay of the moving mean of ``|y_pred|``.
            temperature: numerator of the softmax temperature.
            limit_temp: if true, the hinge temperature is clipped from below at
                ``temperature / mean(min_margin)``.
            init_margin: initial margin; defaults to ``min_margin``.
            weight_pos: hinge weight of positive cells; negatives are scaled by
                ``2 - weight_pos``.
            soft_kr: if true, the KR term uses softmax weights; otherwise
                negatives weigh ``1 / (nb_class - 1)``.
            multi_margin: one margin per class if true, else a scalar margin.
            min_margin: see ``limit_temp`` and ``init_margin``.
            margin_coeff: stored; not read by ``__call__``.
            verbose: if true, debug values are printed with ``tf.print``.
        """
        self.min_margin = min_margin
        self.nb_class = nb_class
        self.margin_coeff = margin_coeff
        if init_margin is None:
            init_margin = min_margin
        self.soft_kr = soft_kr
        self.limit_temp = limit_temp
        self.init_margin = init_margin
        self.alpha_mean = alpha_mean
        self.weight_pos = weight_pos
        self.temperature = temperature
        # Unlike sibling losses, the margin variable carries no constraint here.
        margin_init = [self.init_margin] * nb_class if multi_margin else self.init_margin
        self.margins = tf.Variable(margin_init, dtype=K.floatx())
        # Per-class moving mean of |y_pred|.
        self.moving_mean = tf.Variable(
            [self.init_margin] * nb_class,
            dtype=K.floatx(),
            constraint=lambda z: tf.clip_by_value(z, 0.005, 1000),
        )
        if alpha < 0:
            alpha = 1
            beta = 0
        self.alpha = K.variable(alpha, name='alpha', dtype=K.floatx())
        self.beta = K.variable(beta, name='beta', dtype=K.floatx())
        self.eps = 1e-8
        # Counter incremented on every executed call.
        self.nb = K.variable(0, name='nb')
        self.__name__ = "HKR_multiclass_row"
        self.verbose = verbose

    @tf.function
    def __call__(self, y_true, y_pred):
        """Return the scalar loss for one batch.

        Args:
            y_true: label tensor compared element-wise against 0 and 1
                (assumed one-hot, shape (batch, nb_class) -- TODO confirm).
            y_pred: model outputs with the same shape as ``y_true``.

        Returns:
            Scalar loss tensor (mean over samples of ``alpha * hinge - beta * KR``).

        Side effects:
            Updates ``self.moving_mean`` and increments ``self.nb``. Prints
            debug values when ``verbose``.
        """
        # Exponential moving average of the per-class mean |y_pred|.
        current_mean = tf.reduce_mean(tf.abs(y_pred), axis=0)
        current_mean = self.alpha_mean * self.moving_mean + (1 - self.alpha_mean) * current_mean
        self.moving_mean.assign(current_mean)
        total_mean = tf.reduce_mean(current_mean)

        curr_temp_kr = tf.clip_by_value(self.temperature / total_mean, 0.005, 250)
        if self.limit_temp:
            curr_temp_hinge = tf.clip_by_value(
                self.temperature / total_mean,
                self.temperature / tf.reduce_mean(self.min_margin),
                250,
            )
        else:
            curr_temp_hinge = curr_temp_kr

        # Positive cells are pushed to -max so they get (almost) no softmax weight.
        opposite_values = tf.where(y_true == 1, -tf.float32.max, curr_temp_hinge * y_pred)

        # Row-wise softmax at the hinge temperature.
        F_soft_KR = tf.nn.softmax(opposite_values)
        if self.verbose:
            tf.print("margin", self.margins, "temp_kr :", curr_temp_kr, "temp_hinge :", curr_temp_hinge, "soft_max :", tf.reduce_mean(tf.reduce_max(F_soft_KR, axis=1)))
        F_soft_KR = tf.where(y_true == 1, 1., F_soft_KR)

        # Signed prediction: sign flipped wherever the label is 0.
        KR = tf.where(y_true == 0, -y_pred, y_pred)

        F_soft_hinge = tf.where(y_true == 1, self.weight_pos, (2. - self.weight_pos) * F_soft_KR)
        hinge_cells = tf.nn.relu(self.margins / 2 - KR) * F_soft_hinge
        hinge_row = tf.reduce_sum(hinge_cells, axis=1)

        if self.soft_kr:
            if self.limit_temp:
                # The hinge temperature was clipped differently: recompute the
                # KR weights at the KR temperature.
                opposite_values = tf.where(y_true == 1, -tf.float32.max, curr_temp_kr * y_pred)
                F_soft_KR = tf.nn.softmax(opposite_values)
                F_soft_KR = tf.where(y_true == 1, 1., F_soft_KR)
            kr_row = tf.reduce_sum(F_soft_KR * KR, axis=1)
        else:
            # Fixed weights: 1 on positives, 1/(nb_class - 1) on negatives.
            F_KR = tf.where(y_true == 1, 1., 1. / (self.nb_class - 1))
            kr_row = tf.reduce_sum(F_KR * KR, axis=1)

        loss_val = tf.reduce_mean(hinge_row * self.alpha - self.beta * kr_row)

        if self.verbose:
            # Fraction of cells with a non-zero weighted hinge (diagnostic only).
            non_zero_row = tf.cast(tf.math.count_nonzero(hinge_cells), tf.float32) / (self.nb_class * tf.cast(tf.shape(y_true)[0], tf.float32))
            tf.print("non_zero row :", non_zero_row, "hinge row :", tf.reduce_mean(hinge_row), "KR :", tf.reduce_mean(kr_row))

        self.nb.assign(self.nb + 1)
        return loss_val



class HKR_multiclass_auto_testv8():
    """Experimental multiclass hinge/KR loss mixing a column-wise and a row-wise term.

    The loss is the sum of three parts:

    * a per-class (column) term: hinge and KR values weighted by a softmax
      taken over the batch axis, averaged over the classes present in the batch;
    * a per-sample (row) hinge term weighted by a softmax taken over the
      class axis, restricted to the non-target classes;
    * a margin regularisation term ``-margin_coeff * alpha * margins``.

    Softmax temperatures are ``temperature / moving_mean`` clipped to
    ``[0.005, 250]``, where ``moving_mean`` is an exponential moving average
    of the per-class mean of ``|y_pred|``.

    Args:
        nb_class: number of classes; sizes ``moving_mean`` and the per-class margins.
        alpha: weight of the hinge terms. A negative value is replaced by
            ``alpha = 1`` and forces ``beta = 0``.
        beta: weight of the KR term.
        alpha_col: weight of the column-wise hinge.
        alpha_row: weight of the row-wise hinge.
        auto_margin: if true, margins are overwritten at every call with
            ``moving_mean * margin_coeff`` (and ``multi_margin`` is forced on).
        temperature: numerator of the softmax temperature.
        alpha_mean: decay factor of the moving average.
        soft_hinge: if true, use the module-level ``soft_hinge`` instead of ``relu``.
        init_margin: initial margin value; defaults to ``min_margin``.
        multi_margin: one margin per class if true, a single scalar otherwise.
        min_margin: lower bound used by the constraint attached to ``margins``.
        margin_coeff: coefficient of the margin regularisation term, also used
            as the scaling factor when ``auto_margin`` is set.
        verbose: stored on the instance; ``__call__`` does not consult it
            (its ``tf.print`` calls are unconditional).
    """

    def __init__(
        self,
        nb_class,
        alpha=1,
        beta=1,
        alpha_col=0.5,
        alpha_row=0.5,
        auto_margin=False,
        temperature=1.,
        alpha_mean=0.99,
        soft_hinge=False,
        init_margin=None,
        multi_margin=True,
        min_margin=0.02,
        margin_coeff=0.2,
        verbose=True,
    ):
        self.min_margin = min_margin
        self.nb_class = nb_class
        self.margin_coeff = margin_coeff
        self.temperature = temperature
        if init_margin is None:
            init_margin = min_margin
        self.init_margin = init_margin
        self.soft_hinge = soft_hinge
        self.alpha_mean = alpha_mean
        self.alpha_row = alpha_row
        self.alpha_col = alpha_col
        self.auto_margin = auto_margin
        # Automatic margins are derived from the per-class moving mean, so
        # they need one margin per class.
        if auto_margin:
            multi_margin = True
        if multi_margin:
            self.margins = tf.Variable(
                [self.init_margin] * nb_class,
                dtype=K.floatx(),
                constraint=lambda z: tf.clip_by_value(z, min_margin, 200),
            )
        else:
            self.margins = tf.Variable(
                self.init_margin,
                dtype=K.floatx(),
                constraint=lambda z: tf.clip_by_value(z, min_margin, 200),
            )
        self.moving_mean = tf.Variable(
            [self.init_margin] * nb_class,
            dtype=K.floatx(),
            constraint=lambda z: tf.clip_by_value(z, 0.005, 1000),
        )
        # A negative alpha selects a hinge-only loss.
        if alpha < 0:
            alpha = 1
            beta = 0
        self.alpha = K.variable(alpha, name='alpha', dtype=K.floatx())
        self.beta = K.variable(beta, name='beta', dtype=K.floatx())
        self.eps = 1e-8
        # Counter incremented once per call.
        self.nb = K.variable(0, name='nb')
        self.__name__ = "HKR_multiclass_auto_testv8"
        self.verbose = verbose

    @tf.function
    def __call__(self, y_true, y_pred):
        """Compute the loss for one batch and update the running statistics.

        Side effects: updates ``moving_mean``, increments ``nb``, overwrites
        ``margins`` when ``auto_margin`` is set, and prints diagnostics.

        Args:
            y_true: target tensor; presumably one-hot of shape
                (batch, nb_class) -- TODO confirm against callers.
            y_pred: model outputs with the same shape as ``y_true``.

        Returns:
            Scalar loss tensor.
        """
        # Number of positive examples per class (column).
        row_sum = tf.reduce_sum(y_true, axis=0)
        # Per-class maximum of y_true: 0 for classes absent from the batch.
        row_max = tf.reduce_max(y_true, axis=0)

        # Update the moving mean of |y_pred| per class.
        current_mean = tf.reduce_mean(tf.abs(y_pred), axis=0)
        current_mean = self.alpha_mean*self.moving_mean + (1-self.alpha_mean)*current_mean
        self.moving_mean.assign(current_mean)
        if self.auto_margin:
            # Fixed: the original read the non-existent attribute
            # `self.marg_coeff`, which raised AttributeError as soon as
            # auto_margin was enabled.
            # NOTE(review): the value is assigned without clamping to
            # min_margin -- confirm this is intended.
            self.margins.assign(current_mean*self.margin_coeff)
        tf.print(" mean", current_mean)

        # Classes absent from the batch get weight 0; the others share a
        # total weight of 1.
        class_coeff = tf.where(row_max == 0, 0., 1.)
        class_coeff = class_coeff/tf.reduce_sum(class_coeff)
        # The margin regularisation uses the same per-class weights.
        marg_coeff = class_coeff

        # Per-class softmax temperature.
        curr_temp = tf.clip_by_value(self.temperature/current_mean, 0.005, 250)

        # Row-wise term: margin, temperature and prediction at the positions
        # where y_true is non-zero, kept as (batch, 1) columns.
        margin_row = tf.reduce_sum(self.margins * y_true, axis=1, keepdims=True)
        temp_row = tf.reduce_sum(curr_temp * y_true, axis=1, keepdims=True)
        y_pred_row = tf.reduce_sum(y_pred * y_true, axis=1, keepdims=True)
        # Mask the target position so the softmax only spreads weight over
        # the other classes.
        y_pred_temp_row = tf.where(y_true == 1, -tf.float32.max, y_pred*temp_row)
        F_soft_row = tf.stop_gradient(tf.nn.softmax(y_pred_temp_row, axis=1))
        if self.soft_hinge:
            hinge_row = (soft_hinge(margin_row/2-y_pred_row, 0.9*self.margins/2.)
                         + soft_hinge(margin_row/2+y_pred, 0.9*self.margins/2.))
        else:
            hinge_row = tf.nn.relu(margin_row/2-y_pred_row) + tf.nn.relu(margin_row/2+y_pred)
        hinge_row = tf.reduce_mean(tf.reduce_sum(hinge_row*F_soft_row, axis=1))

        # Column-wise term: softmax over the batch axis for the entries where
        # y_true != 1; entries where y_true == 1 get the weight 1/row_sum.
        y_pred_temperature = tf.where(y_true == 1, -tf.float32.max, y_pred*curr_temp)
        F_soft = tf.nn.softmax(y_pred_temperature, axis=0)
        F_soft = tf.stop_gradient(tf.where(y_true == 1, 1./row_sum, F_soft))

        # KR value per entry: y_pred where y_true != 0, -y_pred elsewhere.
        KR = tf.where(y_true == 0, -y_pred, y_pred)
        # Hinge argument per entry (the original selected between two
        # identical expressions with tf.where).
        var_x = self.margins/2 - KR
        if self.soft_hinge:
            hinge = soft_hinge(var_x, 0.9*self.margins/2.)
        else:
            hinge = tf.nn.relu(var_x)

        # Loss per class (column).
        loss_val = tf.reduce_sum(F_soft*self.alpha*hinge*self.alpha_col - self.beta*F_soft*KR, axis=0)
        # Margin regularisation.
        regul_margin = -self.margin_coeff*self.alpha*self.margins
        tf.print("margin",self.margins,"hinge row",hinge_row, "hinge col",tf.reduce_sum(tf.reduce_sum(F_soft*self.alpha*hinge,axis = 0)*class_coeff),class_coeff )
        # Weighted sum of the three parts.
        loss_val = tf.reduce_sum(loss_val*class_coeff)+tf.reduce_sum(regul_margin*marg_coeff)+hinge_row*self.alpha*self.alpha_row
        self.nb.assign(self.nb+1)

        return loss_val


class HKR_multiclass_auto_testv9():
    """Experimental multiclass hinge/KR loss built from a single row-wise term.

    For every sample, the hinge and KR values of the target logit against the
    other logits are averaged with softmax weights over the non-target
    classes. The softmax temperature is ``temperature / moving_mean`` clipped
    to ``[0.005, 250]``, ``moving_mean`` being an exponential moving average
    of the per-class mean of ``|y_pred|``.

    Args:
        nb_class: number of classes.
        alpha: weight of the hinge term; a negative value is replaced by
            ``alpha = 1`` and forces ``beta = 0``.
        beta: weight of the KR term.
        temperature: numerator of the softmax temperature.
        alpha_mean: decay factor of the moving average.
        auto_margin: if true, margins are reset at each call to
            ``moving_mean * margin_coeff`` clipped to ``[min_margin, 500]``,
            and the margin regularisation term is left out of the loss.
        soft_hinge: if true, use the module-level ``soft_hinge`` instead of ``relu``.
        weight_pos: weight of the target-side hinge; the other side gets
            ``2 - weight_pos``.
        init_margin: initial margin value; defaults to ``min_margin``.
        multi_margin: one margin per class if true, a single scalar otherwise.
        min_margin: lower bound for the margins.
        margin_coeff: coefficient of the margin regularisation and scaling
            factor of the automatic margins.
        verbose: stored on the instance; not read by ``__call__``.
    """

    def __init__(self, nb_class , alpha = 1, beta = 1,temperature = 1.,alpha_mean = 0.99,auto_margin = False, soft_hinge = False,weight_pos = 1., init_margin = None, multi_margin = True, min_margin = 0.02,margin_coeff=0.2,verbose = True):
        self.__name__ = "HKR_multiclass_auto_testv9"
        self.verbose = verbose
        self.nb_class = nb_class
        self.min_margin = min_margin
        self.margin_coeff = margin_coeff
        self.temperature = temperature
        self.alpha_mean = alpha_mean
        self.weight_pos = weight_pos
        self.soft_hinge = soft_hinge
        self.auto_margin = auto_margin
        self.init_margin = min_margin if init_margin is None else init_margin

        # Automatic margins are per class, so they imply multi_margin.
        per_class_margins = multi_margin or auto_margin
        if per_class_margins:
            margin_start = [self.init_margin] * nb_class
        else:
            margin_start = self.init_margin
        self.margins = tf.Variable(
            margin_start,
            dtype=K.floatx(),
            constraint=lambda z: tf.clip_by_value(z, min_margin, 200),
        )
        self.moving_mean = tf.Variable(
            [self.init_margin] * nb_class,
            dtype=K.floatx(),
            constraint=lambda z: tf.clip_by_value(z, 0.005, 1000),
        )

        # A negative alpha selects a hinge-only loss.
        if alpha < 0:
            alpha, beta = 1, 0
        self.alpha = K.variable(alpha, name='alpha', dtype=K.floatx())
        self.beta = K.variable(beta, name='beta', dtype=K.floatx())
        self.eps = 1e-8
        # Counter incremented once per call.
        self.nb = K.variable(0, name='nb')

    @tf.function
    def __call__(self, y_true, y_pred):
        """Return the scalar loss for one batch and update the running state.

        Side effects: updates ``moving_mean``, increments ``nb``, overwrites
        ``margins`` when ``auto_margin`` is set, and prints diagnostics.
        ``y_true`` is presumably one-hot with the same shape as ``y_pred``
        -- TODO confirm against callers.
        """
        # Exponential moving average of the per-class mean absolute output.
        batch_abs_mean = tf.reduce_mean(tf.abs(y_pred), axis=0)
        running_mean = self.alpha_mean*self.moving_mean + (1-self.alpha_mean)*batch_abs_mean
        self.moving_mean.assign(running_mean)
        if self.auto_margin:
            clipped_margins = tf.clip_by_value(running_mean*self.margin_coeff, self.min_margin, 500)
            self.margins.assign(clipped_margins)
        tf.print(" mean", running_mean)

        # Per-class softmax temperature.
        class_temp = tf.clip_by_value(self.temperature/running_mean, 0.005, 250)

        # Margin, temperature and output at the positions where y_true is
        # non-zero, kept as (batch, 1) columns.
        true_margin = tf.reduce_sum(self.margins * y_true, axis=1, keepdims=True)
        true_temp = tf.reduce_sum(class_temp * y_true, axis=1, keepdims=True)
        true_logit = tf.reduce_sum(y_pred * y_true, axis=1, keepdims=True)

        # Softmax weights over the non-target classes; the target position is
        # masked with the most negative float32.
        masked_logits = tf.where(y_true == 1, -tf.float32.max, y_pred*true_temp)
        other_weights = tf.stop_gradient(tf.nn.softmax(masked_logits, axis=1))

        if self.soft_hinge:
            target_side = soft_hinge(true_margin/2-true_logit, 0.9*self.margins/2.)
            other_side = soft_hinge(true_margin/2+y_pred, 0.9*self.margins/2.)
        else:
            target_side = tf.nn.relu(true_margin/2-true_logit)
            other_side = tf.nn.relu(true_margin/2+y_pred)
        pair_hinge = self.weight_pos*target_side + (2.-self.weight_pos)*other_side
        tf.print("margin :",self.margins)

        hinge_term = tf.reduce_mean(tf.reduce_sum(pair_hinge*other_weights, axis=1))
        logit_gap = true_logit - y_pred
        kr_term = tf.reduce_mean(tf.reduce_sum(logit_gap*other_weights, axis=1))

        loss_val = hinge_term*self.alpha - self.beta*kr_term
        if not self.auto_margin:
            # Margin regularisation, averaged over the classes present in
            # the batch (per-class maximum of y_true different from 0).
            class_seen = tf.where(tf.reduce_max(y_true, axis=0) == 0, 0., 1.)
            class_seen = class_seen/tf.reduce_sum(class_seen)
            margin_reward = -self.margin_coeff*self.alpha*self.margins
            loss_val = loss_val + tf.reduce_sum(margin_reward*class_seen)
        self.nb.assign(self.nb+1)

        return loss_val

class cosin_lip():
    """Loss based on the target-class output divided by the output norm.

    Args:
        nb_class: length of the zero-initialised ``margins`` variable.
        verbose: stored on the instance; not read by ``__call__``.
        alpha: weight of the extra, un-normalised target-output term.
    """

    def __init__(self,nb_class,verbose = False,alpha = 0):
        self.alpha = alpha
        self.verbose = verbose
        self.__name__ = "cosin_lip"
        zero_margins = np.array([0] * nb_class)
        self.margins = tf.Variable(zero_margins, dtype=K.floatx())

    @tf.function
    def __call__(self, y_true, y_pred):
        """Return ``-mean(t / ||y_pred|| + alpha * t)`` with ``t = sum(y_pred * y_true, axis=1)``."""
        # Norm of each output vector along the last axis.
        pred_norms = tf.norm(y_pred, axis=-1)
        # Output at the target position (assuming y_true is one-hot -- TODO confirm).
        target_out = tf.reduce_sum(y_pred * y_true, axis=1)
        per_sample = target_out/pred_norms + self.alpha*target_out
        return -tf.reduce_mean(per_sample)

    def get_config(self):
        """Return an empty configuration dictionary."""
        return {}

class CategoricalCrossentropyLip():
    """Categorical cross-entropy on logits multiplied by a fixed temperature.

    Args:
        verbose: stored on the instance; not read by ``__call__``.
        temperature: factor applied to ``y_pred`` before the cross-entropy.
    """

    def __init__(self,verbose = False,temperature = 1.):
        self.temperature = temperature
        self.verbose = verbose
        self.__name__ = "CategoricalCrossentropyLip"
        # The wrapped Keras loss receives the scaled outputs as logits.
        self.loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
        self.margins = tf.Variable([0], dtype=K.floatx())

    @tf.function
    def __call__(self, y_true, y_pred):
        """Return the wrapped cross-entropy of ``y_true`` and ``y_pred * temperature``."""
        scaled_logits = y_pred*self.temperature
        return self.loss(y_true, scaled_logits)

    def get_config(self):
        """Return an empty configuration dictionary."""
        return {}

class BinaryCrossentropyLip():
    """Binary cross-entropy on logits multiplied by a fixed temperature.

    Args:
        verbose: stored on the instance; not read by ``__call__``.
        temperature: factor applied to ``y_pred`` before the cross-entropy.
    """

    def __init__(self,verbose = False,temperature = 1.):
        self.temperature = temperature
        self.verbose = verbose
        self.__name__ = "BinaryCrossentropyLip"
        # The wrapped Keras loss receives the scaled outputs as logits.
        self.loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        self.margins = tf.Variable([0], dtype=K.floatx())

    @tf.function
    def __call__(self, y_true, y_pred):
        """Return the wrapped cross-entropy of ``y_true`` and ``y_pred * temperature``."""
        scaled_logits = y_pred*self.temperature
        return self.loss(y_true, scaled_logits)

    def get_config(self):
        """Return an empty configuration dictionary."""
        return {}