"""
Implements Masked Autoencoder for Distribution Estimation (MADE), by Germain et al. 2015
Re-implementation by Andrej Karpathy based on https://arxiv.org/abs/1502.03509
"""

import math
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from .. import diffeq_layers

# Public API of this module: the bidirectional network is the only name
# exported by a star-import.
__all__ = ['BidirectionalAutoregressiveNetwork']

# ------------------------------------------------------------------------------

# Zero-argument factories for the supported activations, keyed by name.
# Every call builds a fresh module instance.
nonlinear_modules = dict(
    tanh=nn.Tanh,
    elu=lambda: nn.ELU(inplace=True),
)


class MaskedLinear(nn.Linear):
    """Linear layer whose weights are multiplied element-wise by a configurable mask.

    The layer keeps one weight matrix but two masks (``mask`` / ``rev_mask``) and
    two biases (``bias`` / ``rev_bias``), one pair per ordering.  The boolean
    attribute ``reverse_ordering`` selects which pair ``forward`` uses.
    """

    def __init__(self, in_features, out_features, bias=True):
        """
        in_features: integer; size of each input sample
        out_features: integer; size of each output sample
        bias: if False, the layer has neither a forward nor a reverse bias
        """
        super().__init__(in_features, out_features, bias)
        # Both masks start as all ones, i.e. a dense layer, until set_mask() is called.
        self.register_buffer('mask', torch.ones(out_features, in_features))
        self.register_buffer('rev_mask', torch.ones(out_features, in_features))
        self.reverse_ordering = False

        # Use a different set of biases for the reverse ordering.  It starts as a
        # copy of the forward bias; without a forward bias there is no reverse
        # bias either (registered as None so attribute access still works).
        if self.bias is None:
            self.register_parameter('rev_bias', None)
        else:
            self.rev_bias = nn.Parameter(self.bias.detach().clone())

    def reset_parameters(self):
        """Re-initialise the weight and the biases in place."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5) / 2)
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            # Guard against a zero fan-in (no input features) to avoid dividing by zero.
            bound = 2 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)
            # nn.Linear.__init__ calls this method before rev_bias exists, hence
            # the getattr.  On later resets keep rev_bias in sync with the fresh
            # bias, which matches the state right after construction.
            rev_bias = getattr(self, 'rev_bias', None)
            if rev_bias is not None:
                with torch.no_grad():
                    rev_bias.copy_(self.bias)

    def set_mask(self, mask, rev_mask=None):
        """Overwrite the forward mask and, if given, the reverse mask.

        mask, rev_mask: numpy arrays laid out as (in_features, out_features);
        they are transposed here to match the (out_features, in_features)
        layout of the weight matrix.
        """
        self.mask.data.copy_(torch.from_numpy(mask.astype(np.uint8).T))
        if rev_mask is not None:
            self.rev_mask.data.copy_(torch.from_numpy(rev_mask.astype(np.uint8).T))

    def forward(self, t, input):
        """Apply the masked affine map; the time argument ``t`` is ignored."""
        del t
        mask = self.rev_mask if self.reverse_ordering else self.mask
        bias = self.rev_bias if self.reverse_ordering else self.bias
        return F.linear(input, mask * self.weight, bias)


class MaskedConcatLinear(MaskedLinear):
    """Masked linear layer with an additional additive, time-dependent bias.

    On top of the masked affine map inherited from MaskedLinear, a bias-free
    ``nn.Linear(1, out_features)`` maps the time ``t`` to one offset per output
    unit.  A separate time-bias module is kept for each ordering.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)
        # Time-dependence: one projection of t for each ordering.
        self.hyper_bias = nn.Linear(1, out_features, bias=False)
        self.rev_hyper_bias = nn.Linear(1, out_features, bias=False)

    def forward(self, t, input):
        """Return the masked affine map of ``input`` shifted by the time bias."""
        # Pick the mask / bias / time-bias triple for the active ordering.
        if self.reverse_ordering:
            active_mask = self.rev_mask
            active_bias = self.rev_bias
            time_bias = self.rev_hyper_bias
        else:
            active_mask = self.mask
            active_bias = self.bias
            time_bias = self.hyper_bias

        masked_out = F.linear(input, active_mask * self.weight, active_bias)
        # t is reshaped to (1, 1) so the offset broadcasts over the batch.
        return masked_out + time_bias(t.view(1, 1))


class MaskedConcatSquashLinear(MaskedLinear):
    """Masked linear layer with a time-dependent sigmoid gate and additive bias.

    The output of the masked affine map inherited from MaskedLinear is scaled
    by ``sigmoid(hyper_gate(t))`` and then shifted by ``hyper_bias(t)``.  Both
    time-dependent modules exist once per ordering.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)
        # Time-dependence: additive offsets (no bias term of their own) ...
        self.hyper_bias = nn.Linear(1, out_features, bias=False)
        self.rev_hyper_bias = nn.Linear(1, out_features, bias=False)
        # ... and multiplicative gates, each for the forward and reverse ordering.
        self.hyper_gate = nn.Linear(1, out_features)
        self.rev_hyper_gate = nn.Linear(1, out_features)

    def forward(self, t, input):
        """Return ``masked_linear(input) * sigmoid(gate(t)) + bias(t)``."""
        # Pick the parameter set that belongs to the active ordering.
        if self.reverse_ordering:
            active_mask = self.rev_mask
            active_bias = self.rev_bias
            shift_net = self.rev_hyper_bias
            gate_net = self.rev_hyper_gate
        else:
            active_mask = self.mask
            active_bias = self.bias
            shift_net = self.hyper_bias
            gate_net = self.hyper_gate

        # t is reshaped to (1, 1) so gate and shift broadcast over the batch.
        t_col = t.view(1, 1)
        masked_out = F.linear(input, active_mask * self.weight, active_bias)
        gate = torch.sigmoid(gate_net(t_col))
        shift = shift_net(t_col)
        return masked_out * gate + shift


class MADE(nn.Module):
    """MLP of masked, time-conditioned linear layers with forward and reverse masks.

    Every masked layer carries two masks built from the same sampled degrees:
    one for the sampled input ordering and one for its reverse.  ``forward``
    chooses between them through its ``reverse_ordering`` flag.
    """

    def __init__(
        self, nin, hidden_sizes, nout, num_masks=1, natural_ordering=False, nonlinearity='elu',
        layer_type='concatsquash'
    ):
        """
        nin: integer; number of inputs
        hidden sizes: a list of integers; number of units in hidden layers
        nout: integer; number of outputs, must be an integer multiple of nin
        num_masks: when exactly 1, masks are sampled once and then kept;
            otherwise every update_masks() call samples new ones
        natural_ordering: force natural ordering of dimensions, don't use random permutations
        nonlinearity: key into nonlinear_modules ('tanh' or 'elu')
        layer_type: 'ignore', 'concat' or 'concatsquash'; selects the masked layer class
        """

        super().__init__()
        self.nin = nin
        self.nout = nout
        self.hidden_sizes = hidden_sizes
        self.nonlinearity = nonlinearity
        self.layer_type = layer_type
        assert self.nout % self.nin == 0, "nout must be integer multiple of nin"

        layer_cls = {
            'ignore': MaskedLinear,
            'concat': MaskedConcatLinear,
            'concatsquash': MaskedConcatSquashLinear,
        }[self.layer_type]

        # Build a plain MLP: every masked layer is followed by an activation,
        # and the activation trailing the output layer is dropped afterwards.
        widths = [nin, *hidden_sizes, nout]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(layer_cls(fan_in, fan_out))
            stack.append(diffeq_layers.diffeq_wrapper(nonlinear_modules[self.nonlinearity]()))
        self.net = diffeq_layers.SequentialDiffEq(*stack[:-1])

        # settings for orders/connectivities of the model ensemble
        self.natural_ordering = natural_ordering
        self.num_masks = num_masks

        # self.m maps layer index -> degree of each unit; key -1 holds the input ordering.
        self.m = {}
        self.update_masks()  # builds the initial self.m connectivity
        # note, we could also precompute the masks and cache them, but this
        # could get memory expensive for large number of masks.

    def update_masks(self):
        """Sample unit degrees and install the resulting masks in every masked layer.

        Does nothing when masks already exist and ``num_masks == 1``.
        """
        if self.m and self.num_masks == 1:
            return  # only a single mask set, skip for efficiency

        depth = len(self.hidden_sizes)
        degrees = self.m

        # Sample the order of the inputs ...
        if self.natural_ordering:
            degrees[-1] = np.arange(self.nin)
        else:
            degrees[-1] = np.random.permutation(self.nin)
        # ... and the connectivity of all hidden units; each degree is drawn
        # from [smallest degree of the previous layer, nin).
        for idx, width in enumerate(self.hidden_sizes):
            degrees[idx] = np.random.randint(degrees[idx - 1].min(), self.nin, size=width)

        # Hidden layers: the forward mask connects a source to a target of equal
        # or larger degree, the reverse mask to a target of equal or smaller degree.
        masks = []
        rev_masks = []
        for idx in range(depth):
            src = degrees[idx - 1][:, None]
            dst = degrees[idx][None, :]
            masks.append(src <= dst)
            rev_masks.append(src >= dst)

        # Output layer: strict inequalities, compared against the input ordering.
        last_src = degrees[depth - 1][:, None]
        out_dst = degrees[-1][None, :]
        masks.append(last_src < out_dst)
        rev_masks.append(last_src > out_dst)

        # handle the case where nout = nin * k, for integer k > 1
        if self.nout > self.nin:
            reps = self.nout // self.nin
            # Repeat every output column `reps` times in place, so the outputs
            # that belong to one input dimension sit next to each other.
            masks[-1] = np.repeat(masks[-1], reps, axis=1)
            rev_masks[-1] = np.repeat(rev_masks[-1], reps, axis=1)

        # set the masks in all MaskedLinear layers, in network order
        masked_layers = (mod for mod in self.net.modules() if isinstance(mod, MaskedLinear))
        for layer, fwd_mask, rev_mask in zip(masked_layers, masks, rev_masks):
            layer.set_mask(fwd_mask, rev_mask)

    def forward(self, t, x, reverse_ordering=False):
        """Run the network at time ``t`` on ``x`` with the selected ordering."""
        # Switch every masked layer to the requested mask/bias set before the pass.
        for mod in self.net.modules():
            if isinstance(mod, MaskedLinear):
                mod.reverse_ordering = reverse_ordering
        return self.net(t, x)


class BidirectionalAutoregressiveNetwork(nn.Module):
    """Combine a MADE evaluated in its forward and in its reverse ordering.

    One shared MADE is run twice per call, once with ``reverse_ordering=False``
    and once with ``reverse_ordering=True``.  The two results are then either
    summed (``mix='additive'``) or passed through a small mixing MLP
    (``mix='mlp'``).
    """

    _MIX_MODES = ('additive', 'mlp')

    def __init__(
        self,
        nin,
        hidden_sizes,
        nreps,
        natural_ordering=False,
        nonlinearity='elu',
        layer_type='concatsquash',
        mix='additive',
    ):
        """
        nin: integer; number of inputs
        hidden_sizes: a list of integers; number of units in the MADE hidden layers
        nreps: integer; the MADE is built with nin * nreps outputs
        natural_ordering: force natural ordering of dimensions, don't use random permutations
        nonlinearity: key into nonlinear_modules ('tanh' or 'elu')
        layer_type: 'ignore', 'concat' or 'concatsquash'
        mix: 'additive' or 'mlp'; how the two orderings are combined

        Raises ValueError if mix is not one of the supported modes.
        """
        super().__init__()
        # Fail early: an unknown mode would otherwise make forward() return None.
        if mix not in self._MIX_MODES:
            raise ValueError("mix must be one of %s, got %r" % (self._MIX_MODES, mix))
        self.nreps = nreps
        self.mix = mix
        self.made = MADE(nin, hidden_sizes, nin * self.nreps, 1, natural_ordering, nonlinearity, layer_type)

        # TODO: Should we have separate output sizes for the autoregressive and the mixing network?

        if mix == 'mlp':
            base_layer = {
                'ignore': diffeq_layers.IgnoreLinear,
                'concat': diffeq_layers.ConcatLinear,
                'concatsquash': diffeq_layers.ConcatSquashLinear,
            }[layer_type]

            # Maps the 2 * nreps values of one group (nreps from each ordering)
            # back to nreps values.
            self.mixer = diffeq_layers.SequentialDiffEq(
                base_layer(self.nreps * 2, self.nreps * 4),
                diffeq_layers.diffeq_wrapper(nonlinear_modules[nonlinearity]()),
                base_layer(self.nreps * 4, self.nreps * 4),
                diffeq_layers.diffeq_wrapper(nonlinear_modules[nonlinearity]()),
                base_layer(self.nreps * 4, self.nreps),
            )

    def forward(self, t, x):
        """Evaluate both orderings at time ``t`` on ``x`` and mix them.

        Returns a tensor with the shape of a single MADE output.
        Raises ValueError if ``self.mix`` is not a supported mode.
        """
        preorder = self.made(t, x, reverse_ordering=False)
        postorder = self.made(t, x, reverse_ordering=True)

        if self.mix == 'additive':
            return preorder + postorder
        if self.mix == 'mlp':
            # Regroup into rows of nreps consecutive outputs, pair the rows of
            # both orderings, mix each pair and restore the original shape.
            premix = torch.cat([
                preorder.view(-1, self.nreps),
                postorder.view(-1, self.nreps),
            ], 1)
            return self.mixer(t, premix).view(*preorder.shape)
        # Only reachable if self.mix was changed after construction.
        raise ValueError("mix must be one of %s, got %r" % (self._MIX_MODES, self.mix))


# ------------------------------------------------------------------------------

if __name__ == '__main__':

    # Run a quick and dirty test for the autoregressive property: each output
    # is back-propagated to the input, and the check passes when the output
    # has zero gradient with respect to its own input dimension.
    D = 5
    x = np.random.randn(1, D).astype(np.float32)
    # The network is time-conditioned; the time-dependent layers reshape t
    # with t.view(1, 1), so a scalar tensor is sufficient here.
    t = torch.tensor(0.5)

    configs = [
        (D, [], D, False),  # test various hidden sizes
        (D, [200], D, False),
        (D, [200, 220], D, False),
        (D, [200, 220, 230], D, False),
        (D, [200, 220], D, True),  # natural ordering test
        (D, [10], 2 * D, True),  # test nout > nin
        (D, [200, 220], 3 * D, False),  # test nout > nin
    ]

    for nin, hiddens, nout, natural_ordering in configs:

        print("checking nin %d, hiddens %s, nout %d, natural %s" % (nin, hiddens, nout, natural_ordering))
        # The network takes the number of outputs per input (nreps), not the
        # total number of outputs.
        nreps = nout // nin
        model = BidirectionalAutoregressiveNetwork(nin, hiddens, nreps, natural_ordering=natural_ordering)

        # run backpropagation for each dimension to compute what other
        # dimensions it depends on.
        res = []
        for k in range(nout):
            xtr = torch.from_numpy(x).requires_grad_(True)
            xtrhat = model(t, xtr)
            loss = xtrhat[0, k]
            loss.backward()

            depends = (xtr.grad[0].numpy() != 0).astype(np.uint8)
            depends_ix = list(np.where(depends)[0])
            # Outputs are grouped per input: output k belongs to input k // nreps.
            isok = k // nreps not in depends_ix

            print(depends)
            res.append((len(depends_ix), k // nreps, k % nreps, depends_ix, isok))

        # pretty print the dependencies
        res.sort()
        for _, k, r, ix, isok in res:
            print("output %2d.%2d depends on inputs: %30s : %s" % (k, r, ix, "OK" if isok else "NOTOK"))
