
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BatchNorm1d, Dropout, LeakyReLU, Linear, Module, ReLU, Sequential, functional


def sample_gumbel(params, shape, eps=1e-20):
    """Draw i.i.d. standard Gumbel(0, 1) noise of the given ``shape``.

    Uses the inverse-CDF transform ``-log(-log(U))`` with ``U ~ Uniform[0, 1)``;
    ``eps`` is added inside each logarithm to guard against ``log(0)``.

    Args:
        params: Config object; the noise tensor is created on ``params.device``.
        shape: Output shape (tuple or ``torch.Size``).
        eps: Small constant added inside both logarithms for numerical safety.

    Returns:
        Tensor of shape ``shape`` on ``params.device``.
    """
    # Allocate directly on the target device rather than sampling on the CPU
    # and copying over. Note: on an accelerator this consumes that device's
    # RNG stream instead of the CPU generator.
    uniform = torch.rand(shape, device=params.device)
    return -torch.log(-torch.log(uniform + eps) + eps)


def gumbel_softmax_sample(params, logits, temperature, gumbel_noise):
    """Return a soft Gumbel-Softmax sample ``softmax((logits + g) / temperature)``.

    Args:
        params: Config object whose ``device`` attribute is used to place
            freshly sampled noise (only consulted when ``gumbel_noise`` is None).
        logits: Unnormalized log-probabilities; softmax is taken over the last
            dimension.
        temperature: Softmax temperature dividing the perturbed logits.
        gumbel_noise: Pre-sampled noise broadcastable to ``logits``, or
            ``None`` to draw fresh noise via ``sample_gumbel``.

    Returns:
        Tensor with the broadcast shape of ``logits`` and the noise.
    """
    # Identity check: comparing a tensor with ``==`` is not a None test.
    if gumbel_noise is None:
        gumbel_noise = sample_gumbel(params, logits.size())
    return F.softmax((logits + gumbel_noise) / temperature, dim=-1)


def gumbel_softmax(params, logits, temperature, gumbel_noise=None, hard=False):
    """Draw a Gumbel-Softmax sample, optionally with the straight-through trick.

    Args:
        params: Config object; its ``device`` attribute is used when fresh
            Gumbel noise has to be sampled.
        logits: Unnormalized log-probabilities of shape ``[*, n_class]``.
        temperature: Softmax temperature applied to the perturbed logits.
        gumbel_noise: Optional pre-sampled noise broadcastable to ``logits``;
            drawn internally when ``None``.
        hard: If True, the forward value is a one-hot vector at the arg-max,
            while gradients are those of the soft sample (straight-through).

    Returns:
        Tensor of shape ``[-1, n_class]``, with all leading dims flattened.
    """
    # The class dimension is the last one per the ``[*, n_class]`` contract;
    # ``shape[1]`` was only correct for 2-D inputs.
    n_class = logits.shape[-1]
    y = gumbel_softmax_sample(params, logits, temperature, gumbel_noise)

    if not hard:
        return y.view(-1, n_class)

    # One-hot encoding of the arg-max along the class dimension.
    index = y.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y).scatter_(-1, index, 1.0)
    # Straight-through estimator: forward value equals y_hard, but the
    # gradient flows through the soft sample y.
    y_hard = (y_hard - y).detach() + y
    return y_hard.view(-1, n_class)

# def _gumbel_softmax(logits, tau=1,  dim=-1, hard=False, eps=1e-10):
#     """Deals with the instability of the gumbel_softmax for older versions of torch.
#
#     For more details about the issue:
#     https://drive.google.com/file/d/1AA5wPfZ1kquaRtVruCd6BiYZGcDeNxyP/view?usp=sharing
#
#     Args:
#         logits […, num_features]:
#             Unnormalized log probabilities
#         tau:
#             Non-negative scalar temperature
#         hard (bool):
#             If True, the returned samples will be discretized as one-hot vectors,
#             but will be differentiated as if it is the soft sample in autograd
#         dim (int):
#             A dimension along which softmax will be computed. Default: -1.
#
#     Returns:
#         Sampled tensor of same shape as logits from the Gumbel-Softmax distribution.
#     """
#     if version.parse(torch.__version__) < version.parse('1.2.0'):
#         for i in range(10):
#             transformed = functional.gumbel_softmax(logits, tau=tau, hard=hard,
#                                                     eps=eps, dim=dim)
#             if not torch.isnan(transformed).any():
#                 return transformed
#         raise ValueError('gumbel_softmax returning NaN.')
#
#     return functional.gumbel_softmax(logits, tau=tau, hard=hard, eps=eps, dim=dim)

class Residual(Module):
    """Residual layer for the CTGANSynthesizer.

    Computes ``ReLU(BatchNorm1d(Linear(x)))`` and joins the result with the
    untouched input along dimension 1, so the output width is ``o + i``.
    """

    def __init__(self, i, o):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.fc = Linear(i, o)
        self.bn = BatchNorm1d(o)
        self.relu = ReLU()

    def forward(self, input_):
        """Return ``cat([relu(bn(fc(input_))), input_], dim=1)``."""
        activated = self.relu(self.bn(self.fc(input_)))
        return torch.cat((activated, input_), dim=1)



class ControllerGenerator(torch.nn.Module):
    """Residual MLP generator whose output is split into Gumbel-Softmax groups.

    ``noise`` (optionally concatenated with conditioning labels) is passed
    through a stack of ``Residual`` layers and a final ``Linear`` layer that
    emits ``sum(output_dim_list)`` logits. Each consecutive slice of
    ``output_dim_list[i]`` logits is turned into a soft Gumbel-Softmax sample,
    and the samples are concatenated back together.
    """

    def __init__(self,  input_dim, hid_dims, output_dim_list):
        """Build the network.

        Args:
            input_dim: Width of the network input. Because ``forward``
                concatenates the labels onto ``noise`` before the first layer,
                this must equal the noise width plus the total label width.
            hid_dims: Iterable of hidden sizes, one ``Residual`` layer each.
            output_dim_list: Sizes of the consecutive output groups.
        """
        super(ControllerGenerator, self).__init__()
        self.hidden_dims = hid_dims
        self.input_dim = input_dim
        self.output_dim_list = output_dim_list
        self.output_dim = sum(self.output_dim_list)

        print(f'Causal Generator init: indim {self.input_dim}  outdim {self.output_dim}')

        # Each Residual layer concatenates its output with its input, so the
        # width seen by the next layer grows by that layer's hidden size.
        dim = self.input_dim
        layers = []
        for hidden in self.hidden_dims:
            layers.append(Residual(dim, hidden))
            dim += hidden
        layers.append(Linear(dim, self.output_dim))
        self.seq = Sequential(*layers)

    def forward(self, params, noise, gen_labels, **kwargs):
        """Generate concatenated soft Gumbel-Softmax samples.

        Args:
            params: Config object; ``params.Temperature`` is the softmax
                temperature and ``params.device`` is where fresh Gumbel noise
                is sampled.
            noise: Input noise tensor, concatenated with the labels along
                dim 1 (presumably ``(batch, noise_dim)`` -- confirm).
            gen_labels: Sequence of conditioning tensors joined to ``noise``
                along dim 1; may be empty.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            Tensor with ``sum(output_dim_list)`` columns.
        """
        net_in = noise
        if len(gen_labels) > 0:
            net_in = torch.cat([net_in, torch.cat(gen_labels, 1)], 1)

        logits = self.seq(net_in)

        samples = []
        start = 0
        for width in self.output_dim_list:
            end = start + width
            # No extra .to(device): adding the noise to the logits already
            # requires both to share a device, so the sample lives there.
            samples.append(
                gumbel_softmax(params, logits[:, start:end], params.Temperature)
            )
            start = end

        return torch.cat(samples, dim=1)





class ControllerDiscriminator(nn.Module):
    """MLP critic that scores packs of ``pac`` samples at a time.

    ``pac`` consecutive rows of the input batch are reshaped into one row of
    width ``input_dim * pac``. That row is scored by a stack of
    ``Linear -> LeakyReLU(0.2) -> Dropout(0.5)`` blocks ending in a single
    linear output. No sigmoid is applied, so the score is unbounded.
    """

    def __init__(self, input_dim, hid_dims, pac=10, **kwargs):
        """Build the critic.

        Args:
            input_dim: Width of a single (unpacked) sample.
            hid_dims: Iterable of hidden layer sizes.
            pac: Number of samples packed into one critic input row.
            **kwargs: Accepted for interface compatibility; ignored.
        """
        super(ControllerDiscriminator, self).__init__()

        input_dim = input_dim * pac
        self.pac = pac
        self.pacdim = input_dim

        output_dim = 1

        print(f'Critic init: indim {input_dim}  outdim {output_dim}')

        # Layer order is kept fixed: it determines the state_dict keys.
        dim = input_dim
        layers = []
        for hidden in hid_dims:
            layers += [Linear(dim, hidden), LeakyReLU(0.2), Dropout(0.5)]
            dim = hidden
        layers.append(Linear(dim, output_dim))
        self.seq = Sequential(*layers)

    def forward(self, generated_data):
        """Score packed samples.

        Args:
            generated_data: Tensor whose first dimension (batch size) must be
                divisible by ``pac``.

        Returns:
            Tensor of shape ``(batch // pac, 1)``.

        Raises:
            ValueError: If the batch size is not a multiple of ``pac``.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        batch = generated_data.size(0)
        if batch % self.pac != 0:
            raise ValueError(
                f'batch size {batch} is not divisible by pac={self.pac}'
            )
        packed = generated_data.view(-1, self.pacdim)
        return self.seq(packed)
