import torch
import torch.nn as nn

class GatedBatchNorm2d(nn.Module):
    """Wrap an affine batch-norm layer with a learnable per-channel gate ``g``.

    On construction the wrapped layer's scale is moved into ``g``: the BN
    weight is set to one and frozen, and the BN bias is divided by the old
    weight, so that ``bn(x) * g`` reproduces the original layer's output.

    Attributes registered on the module:
        g: learnable gate parameter of shape ``(1, C, 1, 1)``.
        score: buffer of shape ``(1, C, 1, 1)`` accumulated by ``cal_score``.
        bn_mask: buffer of shape ``(1, C, 1, 1)``, initialised to ones, that
            multiplies the output channel-wise.
    """

    def __init__(self, bn, hw=0.0, eta=0.0):
        """
        Args:
            bn: batch-norm module exposing ``weight`` and ``bias`` tensors
                (i.e. an affine layer). It is modified in place.
            hw: h * w, the product of the input feature map's height and width.
            eta: factor that adjusts the impact of the FLOPs regularization
                term subtracted in ``get_score``.
        """
        super().__init__()
        self.bn = bn
        self.hw = hw
        self.eta = eta

        self.channel_size = bn.weight.shape[0]
        self.device = bn.weight.device

        # Allocate directly on the BN weight's device/dtype instead of
        # building on the default device and copying afterwards.
        shape = (1, self.channel_size, 1, 1)
        factory = {'device': self.device, 'dtype': bn.weight.dtype}
        self.g = nn.Parameter(torch.ones(shape, **factory), requires_grad=True)
        self.register_buffer('score', torch.zeros(shape, **factory))
        self.register_buffer('bn_mask', torch.ones(shape, **factory))

        self.extract_from_bn()

    def extract_from_bn(self):
        """Move the BN scale into ``g`` and freeze the BN weight at one.

        NOTE: a channel whose BN weight is exactly zero yields a non-finite
        bias here because of the division; this mirrors the original
        formulation and is not special-cased.
        """
        # In-place updates under no_grad() keep the Parameter objects (and any
        # references held elsewhere) intact. Swapping storage through
        # `param.data.set_(...)` is rejected by current PyTorch for tensors
        # obtained from `.data`.
        with torch.no_grad():
            gamma = self.bn.weight
            self.bn.bias.div_(gamma)
            self.g.mul_(gamma.view(1, -1, 1, 1))
            gamma.fill_(1.0)
        # freeze bn weight
        self.bn.weight.requires_grad = False

    def reset_score(self):
        """Zero the accumulated score buffer in place."""
        self.score.zero_()

    def cal_score(self, grad):
        """Accumulate ``|grad * g|`` into ``score``.

        Presumably registered by the caller as a gradient hook on ``g``
        (TODO confirm); it returns ``None``, so the gradient is not altered.
        """
        # Detach both operands so the buffer never joins the autograd graph,
        # even when this is called with grad mode enabled.
        self.score.add_((grad.detach() * self.g.detach()).abs())

    def get_score(self):
        """Return the flattened per-channel score minus the FLOPs penalty.

        Channels whose mask entry is zero get a score of zero.
        """
        # use self.bn_mask.sum() to count the channels that are still active
        flops_reg = self.eta * self.hw * self.bn_mask.sum()
        return ((self.score - flops_reg) * self.bn_mask).view(-1)

    def forward(self, x):
        """Apply batch norm, the gate, and (when present) the channel mask."""
        x = self.bn(x) * self.g

        # bn_mask is a buffer and may have been replaced by None externally.
        if self.bn_mask is not None:
            return x * self.bn_mask
        return x

class ChannelMask(nn.Module):
    """
        Batch-norm wrapper with a per-channel multiplicative mask, after
        "Learning Efficient Convolutional Networks through Network Slimming",
        In ICCV 2017.

        Registers two buffers of shape (1, C, 1, 1): `score` (zeros) and
        `bn_mask` (ones), where C is the length of the wrapped layer's weight.
    """
    def __init__(self, bn):
        super().__init__()
        self.bn = bn

        gamma = bn.weight
        self.channel_size = gamma.shape[0]
        self.device = gamma.device

        buffer_shape = (1, self.channel_size, 1, 1)
        self.register_buffer('score', torch.zeros(buffer_shape, device=self.device))
        self.register_buffer('bn_mask', torch.ones(buffer_shape, device=self.device))

    def get_score(self):
        """Return the flattened BN weights multiplied by the flattened mask.

        The values are signed; NOTE(review): callers presumably take the
        magnitude themselves if they need it -- confirm.
        """
        flat_gamma = self.bn.weight.view(-1)
        flat_mask = self.bn_mask.view(-1)
        return flat_gamma * flat_mask

    def forward(self, x):
        """Run batch norm, then scale channel-wise by `bn_mask` if it is set."""
        normalized = self.bn(x)

        if self.bn_mask is None:
            return normalized
        return normalized * self.bn_mask
