import numpy as np
from sklearn import metrics
import matplotlib.pyplot as plt
from itertools import permutations
import time
from tqdm import tqdm
import seaborn as sns
import scipy

def sigmaSquareHat(X):
    """Return the U-statistic estimate ``S1 - 2*S2 + S3`` built from ``H = X X^T``.

    With ``H_ij = X_i . X_j`` (rows of ``X``), the three terms are averages
    over index tuples whose entries are all distinct:

    * ``S1`` = mean of ``H_ij**2``        over ordered pairs   (i, j)
    * ``S2`` = mean of ``H_ij * H_jk``    over ordered triples (i, j, k)
    * ``S3`` = mean of ``H_ij * H_kl``    over ordered 4-tuples (i, j, k, l)

    This is the usual unbiased estimator of tr(Sigma^2) for i.i.d. rows with
    covariance Sigma; the callers in this file use it inside
    ``sqrt(2 * sigmaSquareHat(X))``.

    Parameters
    ----------
    X : ndarray of shape (n, p)
        Data matrix, one observation per row.

    Returns
    -------
    numpy.float64
        The estimate ``S1 - 2*S2 + S3``.

    Raises
    ------
    ValueError
        If ``X`` has fewer than 4 rows (the S3 denominator would be zero).
    """
    n = X.shape[0]
    if n < 4:
        raise ValueError("sigmaSquareHat needs at least 4 rows, got %d" % n)

    H = X @ X.T
    diag_H = np.diag(H)
    H_zero_diag = H - np.diag(diag_H)
    off_diag_sq_sum = np.sum(H_zero_diag**2)

    # S1: sum over i != j of H_ij^2.
    S1 = off_diag_sq_sum / (n*(n-1))

    # S2: sum over distinct (i, j, k) of H_ij * H_jk, by inclusion-exclusion
    # on the full sum: remove i == k, i == j and j == k, then add back twice
    # the i == j == k case that was removed three times.
    H_square = H @ H
    S2 = (np.sum(H_square) - np.trace(H_square)
          - 2 * diag_H @ np.sum(H, 0) + 2 * diag_H @ diag_H)
    S2 = S2 / (n*(n-1)*(n-2))

    # S3: sum over distinct (i, j, k, l) of H_ij * H_kl.  Start from
    # (sum_{i != j} H_ij)^2, remove the four single coincidences
    # (i == k, i == l, j == k, j == l), each equal to the sum of squared
    # row sums, and add back the two double coincidences
    # ((i, j) == (k, l) and (i, j) == (l, k)), each equal to sum H_ij^2.
    row_sums = np.sum(H_zero_diag, 0)
    S3 = np.sum(H_zero_diag)**2 - 4 * np.sum(row_sums**2) + 2 * off_diag_sq_sum
    S3 = S3 / (n*(n-1)*(n-2)*(n-3))

    return S1 - 2 * S2 + S3

def sigma_square_hat(Y):
    """Return the unbiased sample variance of the entries of ``Y``.

    Uses the shortcut ``(sum(y^2) - n * mean(y)^2) / (n - 1)`` where ``n`` is
    the length of the first axis of ``Y``.

    Parameters
    ----------
    Y : ndarray
        Response values; the callers in this file pass an (n, 1) column.

    Returns
    -------
    numpy.float64
        The variance estimate with an ``n - 1`` denominator.
    """
    count = np.shape(Y)[0]
    sum_of_squares = np.sum(Y**2)
    mean_value = np.mean(Y)
    centred_sum_of_squares = sum_of_squares - count * mean_value**2
    return centred_sum_of_squares / (count - 1)

def cgz_t(X, Y):
    """Return the statistic T(n, p) built from pairwise Delta terms.

    For each pair (i, j) the function forms

        DeltaX_ij = (X_i - Xbar) . (X_j - Xbar) + ||X_i - X_j||^2 / (2n)
        DeltaY_ij = (Y_i - Ybar) * (Y_j - Ybar) + (Y_i - Y_j)^2   / (2n)

    and returns ``n / ((n-1) * (n-2)^2) * sum_{i != j} DeltaX_ij * DeltaY_ij``.

    The sample size is taken from ``X`` itself rather than from a
    module-level ``n``, so the function works for any input size.

    Parameters
    ----------
    X : ndarray of shape (n, p)
        Design matrix, one observation per row.
    Y : ndarray of shape (n, 1)
        Response column vector.

    Returns
    -------
    numpy.float64
        The value of T(n, p).

    Raises
    ------
    ValueError
        If ``X`` has fewer than 3 rows (the normalising constant is undefined).
    """
    n = X.shape[0]
    if n < 3:
        raise ValueError("cgz_t needs at least 3 rows, got %d" % n)

    # DeltaX: centred Gram matrix plus scaled squared pairwise distances.
    X_scale = X - np.mean(X, 0)
    H_scale = X_scale @ X_scale.T
    H = X @ X.T
    sq_norms = np.diag(H)
    # ||X_i - X_j||^2 = H_ii + H_jj - 2 * H_ij
    sq_dist_X = sq_norms[:, None] + sq_norms[None, :] - 2 * H
    DeltaX = H_scale + sq_dist_X / (2*n)

    # DeltaY: same construction for the scalar response.
    Y_scale = Y - np.mean(Y)
    G_scale = Y_scale @ Y_scale.T
    # (n, 1) - (1, n) broadcasts to the full matrix of differences Y_i - Y_j.
    sq_dist_Y = (Y - Y.T)**2
    DeltaY = G_scale + sq_dist_Y / (2*n)

    # T(n, p): only off-diagonal (i != j) products contribute.
    DeltaXY = DeltaX * DeltaY
    off_diag_sum = np.sum(DeltaXY) - np.trace(DeltaXY)
    return n / ((n-1)*(n-2)**2) * off_diag_sum

def cgz(X, Y):
    """Return the z-score ``n * cgz_t(X, Y) / (sigma_Y^2 * sqrt(2 * trSigma2))``.

    ``sigma_Y^2`` is ``sigma_square_hat(Y)`` and ``trSigma2`` is
    ``sigmaSquareHat(X)``.  The sample size ``n`` is read from ``X`` so the
    result does not depend on a module-level variable of the same name.

    Parameters
    ----------
    X : ndarray of shape (n, p)
        Design matrix.
    Y : ndarray of shape (n, 1)
        Response column vector.

    Returns
    -------
    numpy.float64
        The standardised statistic.
    """
    n = X.shape[0]
    scale = sigma_square_hat(Y) * np.sqrt(2 * sigmaSquareHat(X))
    return n * cgz_t(X, Y) / scale

def zc_t(X, Y):
    """Return the fourth-order U-statistic of ``X`` and ``Y`` in O(n^2) time.

    The statistic is the average, over all ordered 4-tuples of distinct
    indices (i1, i2, i3, i4), of

        (H[i1,i3] + H[i2,i4] - H[i1,i4] - H[i2,i3])
            * (Y[i1] - Y[i2]) * (Y[i3] - Y[i4]) / 4,      H = X X^T.

    Relabelling indices shows that the four H terms contribute equally, so
    the sum equals ``sum H[a,c] * (Y[a]-Y[b]) * (Y[c]-Y[d])`` over distinct
    (a, b, c, d).  For a fixed pair a != c, the inner sum over the remaining
    distinct b, d has the closed form

        (n-2)(n-3) Ya Yc - (n-3) s (Ya + Yc) + s^2 - q,

    where ``s`` and ``q`` are the sum and the sum of squares of the Y values
    excluding a and c.  This replaces an explicit loop over all n(n-1)(n-2)(n-3)
    permutations and agrees with it up to floating-point rounding.

    Parameters
    ----------
    X : ndarray of shape (n, p)
        Design matrix, one observation per row.
    Y : ndarray of shape (n, q)
        Responses; only column 0 is used.

    Returns
    -------
    numpy.float64
        The value of the U-statistic.

    Raises
    ------
    ValueError
        If ``X`` has fewer than 4 rows (no distinct 4-tuples exist).
    """
    n = X.shape[0]
    if n < 4:
        raise ValueError("zc_t needs at least 4 rows, got %d" % n)

    H = X @ X.T

    # Only differences of Y enter the statistic, so centring leaves it
    # unchanged while reducing cancellation error in the closed form below.
    y = Y[:, 0] - np.mean(Y[:, 0])
    y_sq = y**2

    pair_sum = y[:, None] + y[None, :]                        # Ya + Yc
    rest_sum = np.sum(y) - pair_sum                           # s
    rest_sq = np.sum(y_sq) - y_sq[:, None] - y_sq[None, :]    # q

    weights = ((n-2)*(n-3) * np.outer(y, y)
               - (n-3) * rest_sum * pair_sum
               + rest_sum**2 - rest_sq)
    # The closed form is only valid for a != c; drop the diagonal.
    np.fill_diagonal(weights, 0.0)

    return np.sum(H * weights) / (n*(n-1)*(n-2)*(n-3))

def zc(X, Y):
    """Return the z-score ``n * zc_t(X, Y) / (sigma_Y^2 * sqrt(2 * trSigma2))``.

    ``sigma_Y^2`` is ``sigma_square_hat(Y)`` and ``trSigma2`` is
    ``sigmaSquareHat(X)``.  The sample size ``n`` is read from ``X`` so the
    result does not depend on a module-level variable of the same name.

    Parameters
    ----------
    X : ndarray of shape (n, p)
        Design matrix.
    Y : ndarray of shape (n, 1)
        Response column vector.

    Returns
    -------
    numpy.float64
        The standardised statistic.
    """
    n = X.shape[0]
    scale = sigma_square_hat(Y) * np.sqrt(2 * sigmaSquareHat(X))
    return n * zc_t(X, Y) / scale

def sketching(X, Y, k):
    """Return the F-type statistic of ``Y`` regressed on a random k-column sketch of ``X``.

    A Gaussian matrix ``S`` of shape (p, k), scaled by ``1/p``, is drawn from
    NumPy's global random state and ``X`` is compressed to ``XS = X @ S``.
    With ``P`` the orthogonal projector onto the column space of ``XS``, the
    function returns

        (Y' P Y / k) / ((Y'Y - Y' P Y) / (n - k)).

    ``Y' P Y`` is evaluated as ``c' (XS' XS)^{-1} c`` with ``c = XS' Y`` by
    solving the k x k Gram system, which avoids forming an explicit inverse
    and the dense n x n projection matrix.

    Parameters
    ----------
    X : ndarray of shape (n, p)
        Design matrix.
    Y : ndarray of shape (n, 1)
        Response column vector.
    k : int
        Number of sketch columns.

    Returns
    -------
    numpy.float64
        The F-type statistic.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the Gram matrix ``XS' XS`` is singular.
    """
    n, p = X.shape
    S = np.random.normal(size=[p, k]) / p
    XS = X @ S

    gram = XS.T @ XS     # k x k
    cross = XS.T @ Y     # k x 1
    num = (cross.T @ np.linalg.solve(gram, cross))[0, 0]
    den = np.sum(Y**2) - num
    return (num * (n-k)) / (den * k)

# ---------------------------------------------------------------------------
# Experiment 1: Student-t (df=2) design, eigenvalue profile
# 1 / (i^(1/3) * log(i+1)).  Results are saved to stat_mat_*.npy.
# ---------------------------------------------------------------------------
# n: rows (observations) of X per trial; p: columns of X.
n = 50
p = 500
# S and bb are not referenced anywhere in the visible code.
S = 1
# N: trials per regime; every scale in `ss` runs 3*N trials.
N = 500
# Sketch dimension handed to sketching().
k = 3 * int(np.log(p))
# Target norms for Sigma_sqrt, one per outer-loop iteration.
ss = [50, 100, 300]
bb = [1, 5]
# Labels aligned with the trial index: 0 for the first N trials, 1 for the
# remaining 2*N.  Not used in the visible code -- presumably consumed by
# later evaluation code; verify.
ground_truth = np.concatenate((np.zeros(N), np.ones(N), np.ones(N)))
# Per-trial scratch buffers; fully overwritten on every outer iteration.
z_score_abs = np.zeros(3*N)
zc_score_abs = np.zeros(3*N)
f_score = np.zeros(3*N)
# Result matrices: row j holds the 3*N statistics obtained with scale ss[j].
stat_mat_f = np.zeros([3, 3*N])
stat_mat_cgz = np.zeros([3, 3*N])
stat_mat_zc = np.zeros([3, 3*N])

for j in range(3):
    s = ss[j]
    for i in tqdm(range(3*N)):
        # Design matrix: i.i.d. Student-t entries with 2 degrees of freedom.
        # NOTE(review): only `import scipy` is visible in this file; this
        # line needs `scipy.stats` to be reachable as an attribute -- confirm.
        X = scipy.stats.t.rvs(df=2, size=[n, p])
        # The comprehension's `i` is local to the comprehension (Python 3),
        # so it does not overwrite the trial index `i` of the enclosing loop.
        Sigma_sqrt_diag = np.array([1/(i**(1/3)*np.log(i+1)) for i in range(1,p+1)])
        # U: left singular vectors of a fresh Gaussian p x p matrix, used as
        # a random orthogonal basis.
        U, _, _ = np.linalg.svd(np.random.normal(size=[p, p]))
        Sigma_sqrt = U @ np.diag(Sigma_sqrt_diag) @ U.T
        # np.linalg.norm of a 2-D array defaults to the Frobenius norm, so
        # Sigma_sqrt is rescaled to have Frobenius norm s.
        Sigma_sqrt = Sigma_sqrt / np.linalg.norm(Sigma_sqrt) * s
        X = X @ Sigma_sqrt
        
        # Standard normal noise vector.
        epsilon = np.random.normal(size=[n, 1])
        # Trials 0..N-1 use beta = 0; later trials draw
        # Binomial(3, 0.3) + 0.3 * N(0, 1) entries and rescale beta to unit
        # Euclidean norm.
        if i <= N-1:
            beta = np.zeros([p, 1])
        else:
            beta = np.random.binomial(n=3,p=0.3,size=[p, 1]) + 0.3 * np.random.normal(size=[p, 1])
            beta = beta / (np.sqrt(np.sum(beta**2)))
        # Trials 2N..3N-1 multiply beta by 5 (Euclidean norm 5).
        if i >= 2*N:
            beta = beta * 5
            
        Y = X @ beta + epsilon
    
        # Evaluate the three statistics on this trial; the two z-scores are
        # stored in absolute value.
        f_score[i] = sketching(X, Y, k)
        z_score_abs[i] = abs(cgz(X, Y))
        zc_score_abs[i] = abs(zc(X, Y))
    
        # print(i, end=' ')
    # Slice assignment copies the buffers into row j, so reusing the buffers
    # on the next outer iteration does not alter the stored rows.
    stat_mat_f[j,:] = f_score
    stat_mat_cgz[j, :] = z_score_abs
    stat_mat_zc[j, :] = zc_score_abs

# Persist the 3 x (3*N) result matrices in the current working directory.
np.save('stat_mat_f.npy', stat_mat_f)
np.save('stat_mat_cgz.npy', stat_mat_cgz)
np.save('stat_mat_zc.npy', stat_mat_zc)

# ---------------------------------------------------------------------------
# Experiment 2: Gaussian design, eigenvalue profile 1 / log(i+1)^3, fixed
# sketch dimension.  Results are saved to stat_mat_*_slow.npy.
# ---------------------------------------------------------------------------
# n: rows (observations) of X per trial; p: columns of X.
n = 50
p = 500
# S and bb are not referenced anywhere in the visible code.
S = 1
# N: trials per regime; every scale in `ss` runs 3*N trials.
N = 500
# Sketch dimension handed to sketching(); a fixed constant in this experiment.
k = 25
# Target norms for Sigma_sqrt, one per outer-loop iteration.
ss = [50, 100, 300]
bb = [1, 5]
# Labels aligned with the trial index: 0 for the first N trials, 1 for the
# remaining 2*N.  Not used in the visible code -- presumably consumed by
# later evaluation code; verify.
ground_truth = np.concatenate((np.zeros(N), np.ones(N), np.ones(N)))
# Per-trial scratch buffers; fully overwritten on every outer iteration.
z_score_abs = np.zeros(3*N)
zc_score_abs = np.zeros(3*N)
f_score = np.zeros(3*N)
# Result matrices: row j holds the 3*N statistics obtained with scale ss[j].
# These rebind the names used above, after the earlier matrices were saved.
stat_mat_f = np.zeros([3, 3*N])
stat_mat_cgz = np.zeros([3, 3*N])
stat_mat_zc = np.zeros([3, 3*N])

for j in range(3):
    s = ss[j]
    for i in tqdm(range(3*N)):
        # Design matrix: i.i.d. standard normal entries.
        X = np.random.normal(size=[n, p])
        # The comprehension's `i` is local to the comprehension (Python 3),
        # so it does not overwrite the trial index `i` of the enclosing loop.
        Sigma_sqrt_diag = np.array([1/np.log(i+1)**3 for i in range(1,p+1)])
        # U: left singular vectors of a fresh Gaussian p x p matrix, used as
        # a random orthogonal basis.
        U, _, _ = np.linalg.svd(np.random.normal(size=[p, p]))
        Sigma_sqrt = U @ np.diag(Sigma_sqrt_diag) @ U.T
        # np.linalg.norm of a 2-D array defaults to the Frobenius norm, so
        # Sigma_sqrt is rescaled to have Frobenius norm s.
        Sigma_sqrt = Sigma_sqrt / np.linalg.norm(Sigma_sqrt) * s
        X = X @ Sigma_sqrt
        
        # Standard normal noise vector.
        epsilon = np.random.normal(size=[n, 1])
        # Trials 0..N-1 use beta = 0; later trials draw
        # Binomial(3, 0.3) + 0.3 * N(0, 1) entries and rescale beta to unit
        # Euclidean norm.
        if i <= N-1:
            beta = np.zeros([p, 1])
        else:
            beta = np.random.binomial(n=3,p=0.3,size=[p, 1]) + 0.3 * np.random.normal(size=[p, 1])
            beta = beta / (np.sqrt(np.sum(beta**2)))
        # Trials 2N..3N-1 multiply beta by 5 (Euclidean norm 5).
        if i >= 2*N:
            beta = beta * 5
            
        Y = X @ beta + epsilon
    
        # Evaluate the three statistics on this trial; the two z-scores are
        # stored in absolute value.
        f_score[i] = sketching(X, Y, k)
        z_score_abs[i] = abs(cgz(X, Y))
        zc_score_abs[i] = abs(zc(X, Y))
    
        # print(i, end=' ')
    # Slice assignment copies the buffers into row j, so reusing the buffers
    # on the next outer iteration does not alter the stored rows.
    stat_mat_f[j,:] = f_score
    stat_mat_cgz[j, :] = z_score_abs
    stat_mat_zc[j, :] = zc_score_abs
    
# Persist the 3 x (3*N) result matrices in the current working directory.
np.save('stat_mat_f_slow.npy', stat_mat_f)
np.save('stat_mat_cgz_slow.npy', stat_mat_cgz)
np.save('stat_mat_zc_slow.npy', stat_mat_zc)
