"""Adapted from the StyleGAN.pytorch implementation: https://github.com/huangzh13/StyleGAN.pytorch"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.op import fused_leaky_relu



class EqualizedLinear(nn.Module):
    """Linear layer with equalized learning rate and a custom learning-rate multiplier.

    The stored weight is drawn from ``N(0, init_std**2)`` and multiplied by
    ``w_mul`` on every forward pass, so the effective weight is
    ``weight * w_mul``. With ``use_wscale=True`` the He-init constant is applied
    at runtime through ``w_mul`` instead of being baked into the parameter.
    The bias (if any) is likewise multiplied by ``b_mul`` at runtime.

    Args:
        in_dim: Number of input features.
        out_dim: Number of output features.
        bias: If True, learn an additive bias of size ``out_dim``.
        bias_init: Value used to fill the initial bias.
        activation: ``'lrelu'`` routes the linear output (plus bias) through
            ``fused_leaky_relu``; any other value (e.g. ``None``) returns the
            plain affine output.
        gain: Gain applied to the He standard deviation ``in_dim ** -0.5``.
        use_wscale: If True, apply the He scale at runtime via ``w_mul``;
            otherwise bake it into the initial weights.
        lrmul: Learning-rate multiplier applied to both weight and bias.
    """

    def __init__(self, in_dim, out_dim, bias=True, bias_init=0., activation=None,
                 gain=1., use_wscale=True, lrmul=1.):
        super().__init__()

        # Equalized learning rate and custom learning rate multiplier.
        he_std = gain * in_dim ** (-0.5)  # He init
        if use_wscale:
            init_std = 1.0 / lrmul
            self.w_mul = he_std * lrmul
        else:
            init_std = he_std / lrmul
            self.w_mul = lrmul

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim) * init_std)

        # Always define b_mul so forward() does not hit AttributeError when the
        # layer is constructed with bias=False.
        self.b_mul = lrmul
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
        else:
            self.bias = None

        self.activation = activation

    def forward(self, x):
        """Apply the runtime-scaled affine map, optionally followed by fused leaky ReLU.

        Args:
            x: Input tensor whose last dimension is ``in_dim``.

        Returns:
            For ``activation != 'lrelu'``: ``F.linear`` output whose last
            dimension is ``out_dim``. For ``activation == 'lrelu'``: whatever
            ``fused_leaky_relu`` returns for that linear output and the scaled
            bias.
        """
        weight = self.weight * self.w_mul
        # Scale the bias only when it exists; a bias-free layer must not
        # evaluate ``None * b_mul``.
        bias = None if self.bias is None else self.bias * self.b_mul

        if self.activation == 'lrelu':
            out = F.linear(x, weight)
            if bias is None:
                # NOTE(review): fused_leaky_relu is assumed to require an explicit
                # bias tensor; a zero bias of size out_dim keeps this path usable.
                bias = weight.new_zeros(weight.shape[0])
            return fused_leaky_relu(out, bias)

        return F.linear(x, weight, bias=bias)



def _main():
    """Script entry point: only reports that the module loaded successfully."""
    print('Done.')


if __name__ == '__main__':
    _main()