import torch

# Hook that saves outputs of a module. Optionally apply an activation function to the outputs.
class CollectActivationsConv2d:
    """Hook that stores the latest (optionally activated) output of a module.

    The output is split along dim 0 at ``samples_per_prompt``. By default the
    first ``samples_per_prompt`` rows are kept. With ``unconditional=True``,
    the rows from ``samples_per_prompt`` onwards are kept instead. Each call
    overwrites ``outputs``; ``elements`` counts the calls recorded since
    construction or the last ``clear()``.

    Args:
        unconditional: keep the rows after ``samples_per_prompt`` instead of
            the leading ones.
        activation_fkt: optional callable applied to the raw module output
            before it is sliced.
        samples_per_prompt: dim-0 index at which the output is split.
    """

    def __init__(self, unconditional=False, activation_fkt=None, samples_per_prompt=10):
        self.unconditional = unconditional
        self.outputs = None  # latest sliced, detached output; None until recorded
        self.elements = 0  # number of recorded calls since construction / clear()
        self.activation_fkt = activation_fkt
        self.active = True
        self.samples_per_prompt = samples_per_prompt

    def __call__(self, module, module_in, module_out):
        """Record ``module_out`` if the hook is active. Always returns None."""
        if not self.active:
            return
        if self.activation_fkt is not None:
            module_out = self.activation_fkt(module_out)

        split = self.samples_per_prompt
        # The split equals "half the batch" only when the batch holds exactly
        # 2 * samples_per_prompt rows.
        # NOTE(review): CollectActivationDiffsConv2d and CollectActivationsLinear
        # take the *first* half of the batch when unconditional=True, whereas this
        # hook takes the rows after samples_per_prompt. Confirm which batch layout
        # feeds this hook.
        if self.unconditional:
            module_out = module_out[split:]
        else:
            module_out = module_out[:split]

        self.outputs = module_out.detach()
        self.elements += 1

    def activate(self):
        """Enable recording."""
        self.active = True

    def deactivate(self):
        """Disable recording; calls become no-ops."""
        self.active = False

    def clear(self):
        """Drop the stored output and reset the call counter."""
        self.outputs = None
        self.elements = 0

    def max_activations_per_unit(self):
        """Return the maximum over dim 0 of the stored output, or None if nothing was recorded."""
        if self.outputs is None:
            return None
        return self.outputs.max(0).values

    def mean_median_activations(self):
        """Return the stored output divided by the call count, or None if nothing was recorded."""
        if self.outputs is None:
            return None
        # NOTE(review): ``outputs`` holds only the latest call's output (it is
        # overwritten, not accumulated), so dividing by ``elements`` is not a
        # mean over calls. Behavior kept as-is; confirm the intended semantics.
        return self.outputs / self.elements
    
# Hook that saves the difference between the previous and the current output of a module. Optionally apply an activation function to the outputs first.
class CollectActivationDiffsConv2d:
    """Hook that tracks how a module's output changes from one call to the next.

    The (optionally activated) output is halved along dim 0: the first half is
    used when ``unconditional`` is True, otherwise the second half. After the
    second recorded call, ``diff`` holds ``previous - current``.

    Args:
        unconditional: use the first half of the batch instead of the second.
        activation_fkt: optional callable applied to the raw module output
            before it is halved.
    """

    def __init__(self, unconditional=False, activation_fkt=None):
        self.unconditional = unconditional
        self.activation_fkt = activation_fkt
        self.active = True
        self.elements = 0  # number of recorded calls
        self.last_activations = None  # selected half from the most recent call
        self.diff = None  # last_activations(previous) - current; None until 2 calls

    def __call__(self, module, module_in, module_out):
        """Record ``module_out`` and update ``diff`` if the hook is active."""
        if not self.active:
            return

        fkt = self.activation_fkt
        acts = module_out if fkt is None else fkt(module_out)
        midpoint = acts.shape[0] // 2
        selected = acts[:midpoint] if self.unconditional else acts[midpoint:]
        current = selected.detach()

        previous = self.last_activations
        if previous is not None:
            self.diff = previous - current
        # Both the first call and every later call leave the current batch as
        # the reference for the next difference.
        self.last_activations = current
        self.elements += 1

    def activate(self):
        """Enable recording."""
        self.active = True

    def deactivate(self):
        """Disable recording; calls become no-ops."""
        self.active = False

    def clear(self):
        """Forget the stored difference, the reference batch and the call count."""
        self.elements = 0
        self.last_activations = None
        self.diff = None
        

class CollectActivationsLinear:
    """Hook that accumulates the dim-0 mean of a module's output across calls.

    The (optionally activated) output is halved along dim 0: the first half is
    used when ``unconditional`` is True, otherwise the second half. The mean of
    that half over dim 0 is added to ``outputs`` on every call, so ``outputs`` is
    a running *sum* of per-call means; ``elements`` counts the calls.

    Args:
        unconditional: use the first half of the batch instead of the second.
        activation_fkt: optional callable applied to the raw module output
            before it is halved.
        multiple: accepted for call compatibility; not used by this class.
    """

    def __init__(self, unconditional=False, activation_fkt=None, multiple=False):
        self.unconditional = unconditional
        self.activation_fkt = activation_fkt
        self.active = True
        self.elements = 0  # number of recorded calls
        self.outputs = None  # running sum of per-call dim-0 means

    def __call__(self, module, module_in, module_out):
        """Add the dim-0 mean of the selected half to ``outputs`` if active."""
        if not self.active:
            return

        fkt = self.activation_fkt
        acts = module_out if fkt is None else fkt(module_out)
        midpoint = acts.shape[0] // 2
        selected = acts[:midpoint] if self.unconditional else acts[midpoint:]
        batch_mean = selected.detach().mean(0)

        if self.outputs is not None:
            # In-place accumulation into the existing tensor.
            self.outputs += batch_mean
        else:
            self.outputs = batch_mean
        self.elements += 1

    def activate(self):
        """Enable recording."""
        self.active = True

    def deactivate(self):
        """Disable recording; calls become no-ops."""
        self.active = False

    def clear(self):
        """Drop the accumulated sum and reset the call counter."""
        self.elements = 0
        self.outputs = None

    def median_activations(self):
        """Return the median over dim 0 of ``outputs``, or None if nothing was recorded."""
        return None if self.outputs is None else self.outputs.median(0).values

    def activations(self):
        """Return the accumulated ``outputs`` tensor (or None)."""
        return self.outputs
    
class CollectActivationsLinearNoMean:
    """Hook that stores the full (optionally activated) output of the latest call.

    Unlike the other collectors, the batch is not split: ``unconditional`` is
    stored but does not affect what is recorded.

    Args:
        unconditional: stored on the instance; not used for slicing here.
        activation_fkt: optional callable applied to the raw module output.
    """

    def __init__(self, unconditional=False, activation_fkt=None):
        self.unconditional = unconditional
        self.activation_fkt = activation_fkt
        self.active = True
        self.outputs = None  # detached output of the most recent call

    def __call__(self, module, module_in, module_out):
        """Overwrite ``outputs`` with the detached module output if active."""
        if not self.active:
            return
        fkt = self.activation_fkt
        self.outputs = (module_out if fkt is None else fkt(module_out)).detach()

    def activate(self):
        """Enable recording."""
        self.active = True

    def deactivate(self):
        """Disable recording; calls become no-ops."""
        self.active = False

    def clear(self):
        """Drop the stored output."""
        self.outputs = None

    def median_activations(self):
        """Return the median over dim 0 of ``outputs``, or None if nothing was recorded."""
        return None if self.outputs is None else self.outputs.median(0).values

    def activations(self):
        """Return the stored ``outputs`` tensor (or None)."""
        return self.outputs
    

class CollectActivationsValueOutputLayer:
    """Hook that stores the output of the latest call, optionally halved.

    With ``unconditional=True`` the whole output is stored. Otherwise only the
    second half along dim 0 is stored.

    Args:
        unconditional: store the entire output instead of its second half.
    """

    def __init__(self, unconditional=False):
        self.unconditional = unconditional
        self.active = True
        self.outputs = None  # detached selection from the most recent call

    def __call__(self, module, module_in, module_out):
        """Overwrite ``outputs`` with the selected, detached output if active."""
        if not self.active:
            return
        if self.unconditional:
            # No split: keep every row of the output.
            selected = module_out
        else:
            # Keep the second half of the batch along dim 0.
            selected = module_out[module_out.shape[0] // 2:]
        self.outputs = selected.detach()

    def activate(self):
        """Enable recording."""
        self.active = True

    def deactivate(self):
        """Disable recording; calls become no-ops."""
        self.active = False

    def clear(self):
        """Drop the stored output."""
        self.outputs = None

    def median_activations(self):
        """Return the median over dim 0 of ``outputs``, or None if nothing was recorded."""
        return None if self.outputs is None else self.outputs.median(0).values

    def activations(self):
        """Return the stored ``outputs`` tensor (or None)."""
        return self.outputs