import math
import torch
import torch.nn as nn


class PlanarFlow(nn.Module):
    """Planar normalizing flow layer: f(z) = z + u * tanh(w^T z + b).

    Parameters ``u`` and ``w`` have shape ``(nd,)``; ``b`` has shape ``(1,)``.

    NOTE(review): the map is only guaranteed invertible when ``w . u >= -1``.
    That constraint is NOT enforced here. The log-determinant below uses the
    absolute value, so it stays finite either way.
    """

    def __init__(self, nd=1):
        """Create a planar flow acting on vectors of dimension ``nd``."""
        super().__init__()
        self.nd = nd

        # torch.randn is kept (although reset_parameters overwrites the values)
        # so that a seeded construction consumes the same random numbers as
        # before and stays reproducible.
        self.register_parameter('u', nn.Parameter(torch.randn(self.nd)))
        self.register_parameter('w', nn.Parameter(torch.randn(self.nd)))
        self.register_parameter('b', nn.Parameter(torch.randn(1)))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise u and w from U(-1/sqrt(nd), 1/sqrt(nd)) and set b to 0."""
        stdv = 1. / math.sqrt(self.nd)
        # In-place init must not be recorded by autograd.
        with torch.no_grad():
            self.u.uniform_(-stdv, stdv)
            self.w.uniform_(-stdv, stdv)
            self.b.fill_(0)

    def forward(self, z, logp=None, reverse=False):
        """Compute f(z) and, optionally, the log-density of f(z).

        Args:
            z: tensor whose last dimension is ``nd`` (e.g. ``(batch, nd)``).
            logp: optional log-density of ``z``. It is combined with a
                ``(..., 1)``-shaped log-determinant, so it should broadcast
                against that (e.g. ``(batch, 1)``).
            reverse: must be False; the planar flow has no closed-form inverse.

        Returns:
            ``f(z)`` if ``logp`` is None, otherwise the tuple
            ``(f(z), logp - log|det df/dz|)``.

        Raises:
            AssertionError: if ``reverse`` is True. This is raised explicitly,
                so the check still happens under ``python -O``.
        """
        if reverse:
            raise AssertionError('Planar normalizing flow cannot be reversed.')

        # Pre-activation w^T z + b with a trailing singleton dim, shape (..., 1),
        # so it broadcasts against u of shape (nd,).
        h = torch.tanh(torch.matmul(z, self.w.view(self.nd, 1)) + self.b)
        z1 = z + self.u * h

        if logp is None:
            return z1

        # Matrix determinant lemma: det(I + u psi^T) = 1 + u^T psi, where
        # psi = tanh'(.) w = (1 - h^2) w, hence u^T psi = (1 - h^2) * (w . u).
        detgrad = 1 + (1 - h ** 2) * torch.dot(self.w, self.u)
        # Change of variables needs log|det|. Without abs, a negative
        # determinant (possible when w . u < -1) would yield NaN.
        logpz1 = logp - torch.log(torch.abs(detgrad) + 1e-12)
        return z1, logpz1

    def extra_repr(self):
        return 'nd={}'.format(self.nd)
