import numpy as np

import torch
import torch.nn as nn

#from wassdistance.layers import SinkhornDistance
from geomloss import SamplesLoss

import utils

def generate_set_pairs(size_pfx, n, d, std_X, std_rel_dX, n_dX=None, calc_distance = False, dtype=torch.float64, device=None):
    """Generate random pairs of point sets (X1, X2) that differ by a controlled perturbation.

    Simple use: ``X1, X2 = generate_set_pairs(nSamples, n, d, std_X, std_rel_dX, n_dX)``
    returns two tensors of size (nSamples, n, d). Each X1[t] and X2[t] represent
    two corresponding sets of n vectors in R^d (one vector per row) that differ
    in n_dX rows.

    Construction: X1 = U - V and X2 = U + V, hence dX = X2 - X1 = 2*V.
      * U is Gaussian. Its std is std_X on rows where dX is forced to zero, and
        is reduced on the other rows so that every entry of X1 and X2 has
        mean 0 and std std_X.
      * V is zero on rows where dX is forced to zero. Elsewhere it is a standard
        Gaussian times an amplitude, which is either a constant (scalar
        std_rel_dX) or drawn uniformly once per set (std_rel_dX given as a pair).
      * Entries of X1 are independent of each other, and likewise for X2, but
        each entry of X1 is dependent on the corresponding entry of X2.

    Args:
        size_pfx: Prefix of the output shape.
            a. An integer nSamples gives outputs of size (nSamples, n, d).
            b. A tuple/list (torch.Size included) is concatenated with (n, d),
               e.g. (nBatches, batchSize) gives (nBatches, batchSize, n, d).
            c. An empty tuple () gives outputs of size (n, d).
        n: Number of vectors (rows) in each set.
        d: Dimension of each vector.
        std_X: Standard deviation of each entry of X1 and of X2.
        std_rel_dX: Size of the perturbation relative to std_X.
            a. A number r with 0 <= r <= 2: each entry of dX that is allowed to
               be nonzero is Normal[mu=0, std=r*std_X].
            b. A pair (r_min, r_max) with 0 <= r_min <= r_max <= 2: each set
               draws one amplitude ~ Unif[r_min*std_X, r_max*std_X], and its
               dX entries are that amplitude times Normal[mu=0, std=1]. The
               resulting std of each such dX entry is
               sqrt(r_min^2 + r_min*r_max + r_max^2) / sqrt(3) * std_X.
        n_dX: Number of rows in which X1 and X2 may differ. Default: n.
            a. A number in {0, ..., n}: exactly that many rows of each dX may
               contain nonzero entries.
            b. A pair (n_dX_min, n_dX_max): the number of such rows is chosen
               uniformly at random from {n_dX_min, ..., n_dX_max}, separately
               for each set.
        calc_distance: If True, also return a distance between X1 and X2.
        dtype: Floating-point dtype of the generated tensors.
        device: Device on which the tensors are created.

    Returns:
        (X1, X2), each of size size_pfx + (n, d). If calc_distance is True,
        (X1, X2, dist) where dist has size size_pfx and is the minimum of
        (i) n times the geomloss Sinkhorn loss (p=1, blur=1e-6) between the two
        sets and (ii) the sum over rows of the Euclidean norm of the row
        differences. (ii) is the cost of the identity matching, so it should
        upper-bound the transport cost that (i) approximates.

    Raises:
        TypeError: If size_pfx is neither an integer nor a tuple/list.
        AssertionError: If n_dX or std_rel_dX violates the ranges above.
    """
    # Normalize size_pfx to a plain tuple. bool is rejected explicitly because
    # it is a subclass of int but is not a meaningful size.
    if isinstance(size_pfx, (int, np.integer)) and not isinstance(size_pfx, bool):
        size_pfx = (int(size_pfx),)
    elif isinstance(size_pfx, (tuple, list)):
        size_pfx = tuple(size_pfx)
    else:
        raise TypeError('Invalid size_pfx')

    ndims_pfx = len(size_pfx)

    # This will be the size of the outputs X1, X2
    full_size = size_pfx + (n, d)

    if n_dX is None:
        n_dX = (n, n)
    elif not isinstance(n_dX, (list, tuple)):
        n_dX = (n_dX, n_dX)
    n_dX_min, n_dX_max = n_dX[0], n_dX[1]

    # Raised explicitly (rather than via `assert`) so that the checks survive
    # `python -O`; AssertionError is kept so existing callers see the same type.
    if not (0 <= n_dX_min <= n_dX_max <= n):
        raise AssertionError('Invalid n_dX')

    if not isinstance(std_rel_dX, (list, tuple)):
        std_rel_dX = (std_rel_dX, std_rel_dX)
    rel_min, rel_max = std_rel_dX[0], std_rel_dX[1]

    if not (0 <= rel_min <= rel_max <= 2):
        raise AssertionError('Invalid std_rel_dX')

    # Var(X) = Var(U) + Var(V) must equal std_X^2, and on rows that may differ
    # Var(V) = std_X^2 * (rel_min^2 + rel_min*rel_max + rel_max^2) / 12.
    # Clamp at 0 to guard against rounding when std_rel_dX is at its limit 2.
    var_fraction_U = 1 - (rel_min*rel_min + rel_min*rel_max + rel_max*rel_max) / 12
    std_U_eq = std_X
    std_U_neq = std_X * np.sqrt(max(var_fraction_U, 0.0))

    # Range of the per-set amplitude of V (half that of dX, since dX = 2*V)
    a_V = std_X * rel_min / 2
    b_V = std_X * rel_max / 2

    # Generate mask that tells where dX may contain nonzeros
    if n_dX_min == n:
        # Every row may differ: a scalar mask broadcasts over the whole output
        mask = torch.tensor(True, device=device)

    elif n_dX_max == 0:
        # No row may differ
        mask = torch.tensor(False, device=device)

    else:
        # Presumably P holds, for each set, a random permutation of 0..n-1 along
        # its last axis (size size_pfx + (n,)) -- verify against utils.randperm_batch.
        P = utils.randperm_batch(size=size_pfx, nElements=n, device=device)

        if n_dX_max > n_dX_min:
            nDifferent = torch.randint(low=n_dX_min, high=n_dX_max+1, size=size_pfx+(1,), device=device)
        else:
            nDifferent = n_dX_min

        # Row i of a set may differ iff its permutation value is below nDifferent
        mask = P < nDifferent
        mask = torch.unsqueeze(mask, dim=ndims_pfx+1)

    # Generate U. The per-row scale is built directly in `dtype`: multiplying a
    # Python scalar by the boolean mask would produce a float32 factor, which
    # rounds the requested std and can change the dtype of the result.
    U = torch.randn(size=full_size, dtype=dtype, device=device)
    U *= torch.where(mask,
                     torch.as_tensor(std_U_neq, dtype=dtype, device=device),
                     torch.as_tensor(std_U_eq, dtype=dtype, device=device))

    # Generate V (random draws are made in the same order as before: normal
    # entries first, then the per-set uniform amplitude if one is needed)
    V = torch.randn(size=full_size, dtype=dtype, device=device)
    if rel_min == rel_max:
        V *= torch.as_tensor(a_V, dtype=dtype, device=device)
    else:
        V *= ( a_V + (b_V-a_V) * torch.rand(size=size_pfx+(1,1), dtype=dtype, device=device) )
    # Zero out the rows in which X1 and X2 must coincide
    V *= mask

    del mask

    X1 = U-V
    X2 = U+V

    if not calc_distance:
        return X1, X2

    # Sinkhorn loss with p=1 from the `geomloss' package. SamplesLoss takes
    # (N,D) or (B,N,D) inputs, so the prefix dimensions are flattened into a
    # single batch axis and restored afterwards.
    n_sets = int(np.prod(size_pfx, dtype=np.int64))
    calc_dist = SamplesLoss(loss="sinkhorn", p=1, blur=1e-6)
    dist = n*calc_dist(X1.reshape(n_sets, n, d), X2.reshape(n_sets, n, d))
    dist = dist.reshape(size_pfx)

    # Cost of the identity matching: sum over rows of the Euclidean norm of the
    # row difference. Should be approximately equal to the Sinkhorn value above.
    dist_identity = torch.norm(X1-X2, dim=-1).sum(dim=-1)

    dist = torch.min(dist, dist_identity)

    return X1, X2, dist