from functools import partial

import torch
import torch.nn.functional as F


# Tanh approximation of GELU (same formula as F.gelu(x, approximate="tanh")).
# The exact GELU would instead be:
#   x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def gelu_fwd(x):
    """Return the tanh-approximated GELU of ``x``, cast back to ``x.dtype``."""
    sqrt_two_over_pi = 0.79788456
    cubic_coeff = 0.044715
    tanh_arg = sqrt_two_over_pi * x * (1 + cubic_coeff * x * x)
    gate = 1.0 + torch.tanh(tanh_arg)
    return (x * 0.5 * gate).to(dtype=x.dtype)


# Gradient of the tanh approximation of GELU (the backward of gelu_fwd).
# For reference, the gradient of the exact GELU is:
#   0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def gelu_bwd(g, x):
    """Multiply the upstream gradient ``g`` by d/dx of tanh-GELU at ``x``.

    With u = sqrt(2/pi) * (x + 0.044715 * x^3) and t = tanh(u):
        d/dx [0.5 * x * (1 + t)] = 0.5 * x * (1 - t^2) * du/dx + 0.5 * (1 + t)
    The result is cast back to ``x.dtype``.
    """
    t = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    sech_sq = 1 - t * t
    # du/dx = sqrt(2/pi) + sqrt(2/pi) * 3 * 0.044715 * x^2, and the latter
    # coefficient is 0.1070322243.
    du_dx = 0.79788456 + 0.1070322243 * x * x
    local_grad = 0.5 * x * (sech_sq * du_dx) + 0.5 * (1 + t)
    return (local_grad * g).to(dtype=x.dtype)


class FastGeLUFunction(torch.autograd.Function):
    """Tanh-approximated GELU as a custom autograd Function.

    ``forward`` saves its input and returns ``gelu_fwd(input)``; ``backward``
    recomputes the local derivative from the saved input via ``gelu_bwd``.
    The Function takes a single tensor argument (there is no bias).
    """

    @staticmethod
    def forward(ctx, input):
        # Only the raw input is saved; gelu_bwd recomputes tanh(...) from it.
        ctx.save_for_backward(input)
        return gelu_fwd(input)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors
        return gelu_bwd(grad_output, saved_input)


# Functional alias: fast_gelu_impl(x) applies FastGeLUFunction (custom forward/backward).
fast_gelu_impl = FastGeLUFunction.apply


@torch.jit.script
def relu_fwd(x):
    """Return ``F.relu(x)`` cast back to ``x.dtype``."""
    return F.relu(x).to(dtype=x.dtype)


@torch.jit.script
def relu_bwd(g, x):
    """Backward of ReLU: pass the upstream gradient ``g`` where ``x > 0``, else 0.

    The mask is strict (``x > 0``), so the gradient at exactly ``x == 0`` is 0.
    This matches PyTorch's built-in ReLU backward. A non-strict mask
    (``x >= 0``) would leak gradient through inputs that are exactly zero,
    e.g. values that already went through a ReLU. The result is cast back
    to ``x.dtype``.
    """
    return torch.where(x > 0, g, 0.0).to(dtype=x.dtype)


class FastReLUFunction(torch.autograd.Function):
    """ReLU as a custom autograd Function.

    ``forward`` saves its input and returns ``relu_fwd(input)``; ``backward``
    builds the gradient from the saved input via ``relu_bwd``.
    The Function takes a single tensor argument (there is no bias).
    """

    @staticmethod
    def forward(ctx, input):
        # Only the input is needed to rebuild the gradient mask in backward.
        ctx.save_for_backward(input)
        return relu_fwd(input)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors
        return relu_bwd(grad_output, saved_input)


@torch.jit.script
def sqrelu_fwd(x):
    """Squared ReLU: ``relu(x) * relu(x)``, cast back to ``x.dtype``."""
    positive = F.relu(x)
    squared = positive * positive
    return squared.to(dtype=x.dtype)


@torch.jit.script
def sqrelu_bwd(g, x):
    """Backward of squared ReLU: ``g * d/dx relu(x)^2 = g * 2 * relu(x)``."""
    doubled_grad = 2.0 * g
    return (doubled_grad * F.relu(x)).to(dtype=x.dtype)


class FastSqreLUFunction(torch.autograd.Function):
    """Squared ReLU as a custom autograd Function.

    ``forward`` saves its input and returns ``sqrelu_fwd(input)``; ``backward``
    computes the gradient from the saved input via ``sqrelu_bwd``.
    The Function takes a single tensor argument (there is no bias).
    """

    @staticmethod
    def forward(ctx, input):
        # The raw input is enough to recompute 2 * relu(x) in backward.
        ctx.save_for_backward(input)
        return sqrelu_fwd(input)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors
        return sqrelu_bwd(grad_output, saved_input)


@torch.jit.script
def golu_fwd(x):
    """GoLU forward: ``x * exp(-exp(-x))``, cast back to ``x.dtype``.

    The factor ``exp(-exp(-x))`` is the Gompertz curve acting as a soft gate.
    """
    inner = torch.exp(-1 * x)
    gate = torch.exp(-1 * inner)
    return (x * gate).to(dtype=x.dtype)


@torch.jit.script
def golu_bwd(g, x):
    """Multiply the upstream gradient ``g`` by d/dx of GoLU at ``x``.

    With e = exp(-x):
        d/dx [x * exp(-e)] = exp(-e) + x * exp(-x) * exp(-e)
                           = exp(-e) + x * exp(-x - e)

    The second term is evaluated in the fused form ``x * exp(-x - e)``. The
    naive ``x * exp(-x) * exp(-e)`` overflows for sufficiently negative x
    (roughly x < -84 in float32, x < -9 in float16): ``x * exp(-x)`` becomes
    -inf, and -inf times ``exp(-e) == 0`` gives NaN gradients. In the fused
    form, ``exp(-x - e)`` simply underflows to 0, which is the correct limit.
    The result is cast back to ``x.dtype``.
    """
    e = torch.exp(-x)
    gate = torch.exp(-e)
    z = g * (gate + x * torch.exp(-x - e))
    return z.to(dtype=x.dtype)


class FastGoLUFunction(torch.autograd.Function):
    """GoLU as a custom autograd Function.

    ``forward`` saves its input and returns ``golu_fwd(input)``; ``backward``
    computes the gradient from that saved input via ``golu_bwd``.
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return golu_fwd(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Exactly one tensor was saved in forward.
        saved_input = ctx.saved_tensors[0]
        return golu_bwd(grad_output, saved_input)


if __name__ == "__main__":
    # Smoke test for the custom activations: run each forward, print how far
    # the fast GELU is from the other activations, then check every
    # hand-written backward against autograd on a reference forward.
    torch.manual_seed(0)
    latent = torch.randn([64, 100, 512])

    sqrelu = FastSqreLUFunction.apply
    gelu = FastGeLUFunction.apply
    golu = FastGoLUFunction.apply
    gelu_pytorch = partial(F.gelu, approximate="tanh")

    sout = sqrelu(latent)
    gout = gelu(latent)
    gout2 = gelu_pytorch(latent)
    g3out = golu(latent)
    for out in (sout, gout, gout2, g3out):
        print(out.shape)

    # gelu vs. PyTorch's tanh GELU should be ~0; the others show how much
    # the activations differ from each other.
    print(torch.dist(gout, sout))
    print(torch.dist(gout, gout2))
    print(torch.dist(gout, g3out))

    # Gradient check: custom backward vs. autograd through a reference forward.
    references = {
        "gelu": (gelu, gelu_pytorch),
        "relu": (FastReLUFunction.apply, F.relu),
        "sqrelu": (sqrelu, lambda t: F.relu(t) * F.relu(t)),
        "golu": (golu, lambda t: t * torch.exp(-torch.exp(-t))),
    }
    upstream = torch.randn_like(latent)
    for name, (fast_fn, ref_fn) in references.items():
        x_fast = latent.clone().requires_grad_(True)
        x_ref = latent.clone().requires_grad_(True)
        fast_fn(x_fast).backward(upstream)
        ref_fn(x_ref).backward(upstream)
        max_err = (x_fast.grad - x_ref.grad).abs().max().item()
        print(f"{name}: max |grad error| vs autograd = {max_err:.3e}")
        assert torch.allclose(x_fast.grad, x_ref.grad, rtol=1e-4, atol=1e-4), (
            f"{name}: manual backward does not match autograd"
        )
