import numpy as np
from p_stars import GaussMix
from math import cos,sin

def get_2drot_in_dimen(dimen, ang, nd1, nd2):
    """Return a ``dimen x dimen`` rotation by ``ang`` degrees in one coordinate plane.

    The result is the identity matrix except on axes ``nd1`` and ``nd2``,
    where it holds the 2x2 block::

        [[cos(ang), -sin(ang)],
         [sin(ang),  cos(ang)]]

    so column ``nd1`` is ``cos*e_nd1 + sin*e_nd2``, i.e. a positive angle
    turns axis ``nd1`` towards axis ``nd2`` (a Givens rotation).

    Parameters
    ----------
    dimen : int
        Size of the square matrix.
    ang : float
        Rotation angle in degrees.
    nd1, nd2 : int
        The two (distinct) axes spanning the rotation plane.

    Returns
    -------
    numpy.ndarray
        Orthogonal ``(dimen, dimen)`` matrix with determinant 1.

    Raises
    ------
    ValueError
        If ``nd1 == nd2``; the assignments below would otherwise overwrite a
        single diagonal entry and silently return a non-orthogonal matrix.
    """
    if nd1 == nd2:
        raise ValueError(
            "nd1 and nd2 must be distinct axes, got {} and {}".format(nd1, nd2))

    rad = ang / 180 * np.pi
    c, s = cos(rad), sin(rad)

    retval = np.eye(dimen)
    retval[nd1, nd1] = c
    retval[nd1, nd2] = -s
    retval[nd2, nd1] = s
    retval[nd2, nd2] = c
    return retval
    
def get_std_gauss_4mix(dimen, crnr=1., rotMtx=None, shift=None):
    """Build an equally weighted mixture of four unit-stdev Gaussians at "corners".

    Before rotation and shift, the four component means are::

        0: crnr * (-1, -1, -1, -1, ...)
        1: crnr * (-1, +1, -1, +1, ...)
        2: crnr * (+1, -1, +1, -1, ...)
        3: crnr * (+1, +1, +1, +1, ...)

    In 2-D these are the four corners of the square of half-width ``crnr``.
    (Previously components 1 and 2 were accidentally given the same mean.)
    Every component has standard deviation 1 in every coordinate and
    weight 0.25.

    Parameters
    ----------
    dimen : int
        Dimension of the mixture.
    crnr : float
        Half-width of the corner pattern.
    rotMtx : array_like, optional
        ``(dimen, dimen)`` matrix applied to every mean (``m -> rotMtx @ m``).
    shift : array_like, optional
        Length-``dimen`` vector added to every mean after the rotation.

    Returns
    -------
    GaussMix
        ``GaussMix(probs, means, stdevs)`` with shapes ``(4,)``,
        ``(4, dimen)`` and ``(4, dimen)``.
    """
    gmixprobs = np.full((4,), 0.25)

    # (-1, +1, -1, +1, ...): even coordinates negative, odd positive.
    alternating = np.where(np.arange(dimen) % 2 == 0, -1., 1.)
    ones = np.ones(dimen)
    gmixmeans = crnr * np.vstack([-ones, alternating, -alternating, ones])
    gmixstdevs = np.ones((4, dimen))

    if rotMtx is not None:
        # Row-wise rotMtx @ m for every mean m, done as one matrix product.
        gmixmeans = gmixmeans @ np.asarray(rotMtx).T

    if shift is not None:
        gmixmeans = gmixmeans + np.asarray(shift)

    print("Mean: \n", gmixmeans)
    return GaussMix(gmixprobs, gmixmeans, gmixstdevs)


def get_theta_set(n_theta, dimen, min_eig=1e-2):
    """Draw a random, well-conditioned set of ``n_theta`` vectors in ``dimen`` dims.

    Entries are drawn uniformly from [0, 1) with ``np.random.random``. The
    whole set is redrawn until every eigenvalue of the Gram matrix
    ``thetas @ thetas.T`` has magnitude at least ``min_eig``, i.e. the rows
    are linearly independent with some margin.

    Parameters
    ----------
    n_theta : int
        Number of vectors (rows) to draw.
    dimen : int
        Length of each vector.
    min_eig : float, optional
        Rejection threshold on the Gram-matrix eigenvalues (default ``1e-2``,
        the previously hard-coded value). A very large threshold can make the
        rejection loop run indefinitely.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_theta, dimen)``.

    Raises
    ------
    ValueError
        If ``n_theta > dimen``. The Gram matrix then has rank at most
        ``dimen`` and always a zero eigenvalue, so no draw could ever be
        accepted.
    """
    if n_theta > dimen:
        raise ValueError(
            "cannot draw {} linearly independent vectors in {} dimensions"
            .format(n_theta, dimen))

    while True:
        thetas = np.random.random((n_theta, dimen))
        # The Gram matrix is symmetric, so eigvalsh gives real eigenvalues
        # without computing the unused eigenvectors.
        eigs = np.linalg.eigvalsh(thetas @ thetas.T)
        if np.all(np.abs(eigs) >= min_eig):
            return thetas


def sample_rotation_mtx(dimen, n_rots=None):
    """Sample a random ``dimen x dimen`` rotation as a product of plane rotations.

    ``n_rots`` rotations from ``get_2drot_in_dimen`` are composed by right
    multiplication, starting from the identity. For each one:

    * the angle is drawn uniformly from [-180, 180) degrees;
    * the first axis ``fro`` is drawn from ``[0, dimen)``;
    * the second axis is ``(fro + delta) % dimen``, with ``delta`` drawn from
      ``[1, dimen - 1]``, so the two axes are always distinct.

    All angles are drawn first, then all ``fro`` values, then all ``delta``
    values, each with the global ``np.random`` state.

    Parameters
    ----------
    dimen : int
        Matrix size; must be at least 2.
    n_rots : int, optional
        Number of plane rotations to compose (default ``2 * dimen``).

    Returns
    -------
    numpy.ndarray
        Orthogonal ``(dimen, dimen)`` matrix with determinant 1.

    Raises
    ------
    ValueError
        If ``dimen < 2``, since no pair of distinct axes exists.
    """
    if dimen < 2:
        raise ValueError(
            "need dimen >= 2 to pick a rotation plane, got {}".format(dimen))
    if n_rots is None:
        n_rots = dimen * 2

    angs = np.random.uniform(-180., 180., size=(n_rots,))
    from_ndx = np.random.randint(0, dimen, size=(n_rots,))
    to_delta = np.random.randint(1, dimen, size=(n_rots,))

    rotm = np.eye(dimen)
    for ang, fro, delta in zip(angs, from_ndx, to_delta):
        to = (fro + delta) % dimen
        rotm = rotm @ get_2drot_in_dimen(dimen, ang, fro, to)

    return rotm


def _print_proj_stats(proj):
    """Print the shape and per-column min/max of a projected sample array."""
    print("ssiz: {}, Min {}, max {}".format(
        proj.shape, np.min(proj, axis=0), np.max(proj, axis=0)))


def main(plot_shifted=False):
    """Compare 1-D projections of a Gaussian 4-mixture before and after rotation.

    With a fixed seed, this draws unit-length projection directions, a random
    rotation and a random shift. It builds the axis-aligned mixture and its
    rotated copy and prints the ranges from ``get_range_of_lincomb``. It then
    overlays density histograms of the projected samples, one subplot per
    direction: blue for the original and red for the rotated mixture.

    Parameters
    ----------
    plot_shifted : bool, optional
        Also build the rotated-and-shifted mixture, print its ranges and
        overlay its projections in green.
    """
    dimen, n_theta = 20, 3
    n_samples, num_bins = 82500, 200

    np.random.seed(1548663359)

    # Projection directions, each scaled to unit length.
    thetas = get_theta_set(n_theta, dimen)
    thetas /= np.linalg.norm(thetas, axis=1, keepdims=True)
    print("thetas is {}".format(thetas))

    rotm = sample_rotation_mtx(dimen)
    print("Rotation: \n", rotm)

    # Drawn even when plot_shifted is False so the global random stream, and
    # therefore every later sample, is identical in both modes.
    shift = np.random.uniform(0, 2., size=(dimen,))
    print("Shift: \n", shift)

    gmix = get_std_gauss_4mix(dimen, 1.)
    print("orig rngset is {}".format(gmix.get_range_of_lincomb(thetas)))

    gmixrot = get_std_gauss_4mix(dimen, 1., rotm)
    print("rotd rngset is {}".format(gmixrot.get_range_of_lincomb(thetas)))

    gmixrot_shift = None
    if plot_shifted:
        gmixrot_shift = get_std_gauss_4mix(dimen, 1., rotm, shift)
        print("r+sf rngset is {}".format(
            gmixrot_shift.get_range_of_lincomb(thetas)))

    import matplotlib.pyplot as plt
    _, ax = plt.subplots(nrows=n_theta, ncols=1,
                         figsize=(12, 4. * n_theta),
                         squeeze=False)

    oproj = gmix.next_sample(n_samples) @ thetas.T
    _print_proj_stats(oproj)
    roproj = gmixrot.next_sample(n_samples) @ thetas.T
    _print_proj_stats(roproj)
    rsoproj = None
    if gmixrot_shift is not None:
        rsoproj = gmixrot_shift.next_sample(n_samples) @ thetas.T
        _print_proj_stats(rsoproj)

    for n, axn in enumerate(ax[:, 0]):
        axn.hist(oproj[:, n], num_bins, facecolor='blue', alpha=0.5, density=True)
        axn.hist(roproj[:, n], num_bins, facecolor='red', alpha=0.5, density=True)
        if rsoproj is not None:
            axn.hist(rsoproj[:, n], num_bins, facecolor='green', alpha=0.5,
                     density=True)
        axn.get_yaxis().set_visible(False)
        axn.set_xticks([0.0])

    plt.show()


if __name__ == '__main__':
    main()
