import numpy as np
from numpy.linalg import norm, eigh


# Utility functions
def dinfty(uhat, ustar):
    """Return the sign-invariant sup-norm distance between two vectors.

    Computes ``min(max|uhat - ustar|, max|uhat + ustar|)`` so that an
    estimate and its negation count as equally close (eigenvectors are only
    defined up to sign).

    Args:
        uhat: estimate; any array-like (lists and N-d arrays are accepted).
        ustar: reference, broadcast-compatible with ``uhat``.

    Returns:
        The smaller of the two max-abs deviations. NaN entries propagate to
        the result rather than being silently skipped.
    """
    uhat = np.asarray(uhat)
    ustar = np.asarray(ustar)
    # np.max reduces in C over all axes and propagates NaN; builtin max would
    # iterate in Python, ignore NaNs after the first element, and fail on 2-D.
    return min(np.max(np.abs(uhat - ustar)), np.max(np.abs(uhat + ustar)))

def pmdiff(uhat, ustar):
    """Return the sign-invariant gap between two scalars.

    This is the smaller of ``|uhat - ustar|`` and ``|uhat + ustar|``, so an
    estimate whose sign is flipped relative to the reference is not penalized.
    """
    gap_minus = np.abs(uhat - ustar)
    gap_plus = np.abs(uhat + ustar)
    # Prefer the "minus" gap on ties (and whenever a comparison is undefined).
    return gap_plus if gap_plus < gap_minus else gap_minus


def sample_spherical(npoints, ndim):
    """Draw ``npoints`` random directions in R^ndim.

    Each column is an i.i.d. standard-normal vector scaled to unit Euclidean
    norm, which makes it uniformly distributed on the unit sphere.

    Returns:
        Array of shape ``(ndim, npoints)``; every column has norm 1.
    """
    gaussian = np.random.randn(ndim, npoints)
    column_lengths = np.linalg.norm(gaussian, axis=0)
    return gaussian / column_lengths

def get_ortho(n):
    """Return a random ``n x n`` orthogonal matrix with dtype float32.

    The matrix is the Q factor of the QR decomposition of a float32 matrix
    whose entries are i.i.d. standard normal.

    NOTE: Q is used as returned, without multiplying its columns by the signs
    of diag(R), so it is not guaranteed to be exactly Haar distributed.
    """
    gaussian = np.random.randn(n * n).reshape(n, n)
    q_factor = np.linalg.qr(gaussian.astype("float32"))[0]
    return q_factor.astype("float32")

# Simulation: one spiked rank-one trial (data generation, estimation, results)
def _spike_vector(n, a, is_haar):
    """Build the ``(n, 1)`` signal vector U and the index ``loc`` holding ``a``.

    The n-1 other entries are ``sqrt(1 - a**2) * xi / ||xi||_2``, where ``xi``
    has random +/-1 entries when ``is_haar == 0`` and is a uniform direction on
    the sphere otherwise, so U has unit norm whenever ``|a| <= 1``.
    """
    if is_haar == 0:
        xi = 2 * np.random.binomial(1, 0.5, size=n - 1) - 1
    else:
        xi = sample_spherical(1, n - 1).reshape(-1)
    tail = np.sqrt(1 - a**2) * xi / norm(xi, 2)
    U = np.concatenate(([a], tail)).reshape(-1, 1)
    loc = np.random.randint(0, n)
    # ``a`` sits at index 0; swap it into the randomly chosen position.
    displaced = U[loc, 0]
    U[loc, 0] = a
    U[0, 0] = displaced
    return U, loc


def _symmetric_noise(n, noise_type):
    """Return an ``(n, n)`` float32 symmetric noise matrix.

    Entries on and above the diagonal are i.i.d. with unit variance: standard
    normal (``noise_type == 1``), Laplace with scale 1/sqrt(2)
    (``noise_type == 2``), or Rademacher +/-1 (any other value). The strict
    lower triangle mirrors the strict upper triangle.
    """
    if noise_type == 1:
        W = np.random.randn(n * n)
    elif noise_type == 2:
        W = np.random.laplace(0, 1 / np.sqrt(2), n * n)
    else:
        W = 2 * np.random.binomial(1, 0.5, size=n * n) - 1
    W = W.reshape(n, n).astype("float32")
    # Upper triangle with diagonal, plus the strict upper triangle mirrored.
    return np.triu(W) + np.triu(W, 1).T


def simulator(n, a, is_haar, noise_type):
    """Run one rank-one spiked-matrix trial and compare two eigenvector estimates.

    The observation is ``Y = lambda * U U^T + W`` with ``lambda = sqrt(n log n)``.
    U comes from ``_spike_vector`` and W from ``_symmetric_noise``. Y is
    conjugated by a random orthogonal H (``get_ortho``) before the
    eigendecomposition, and estimates are rotated back with ``H^T``.

    Two estimates of U are compared:
      * ``Spec``: the top eigenvector of Y.
      * ``Ours``: sign-aligned row sums of the rotated matrix, restricted to
        the coordinates where the top eigenvector exceeds its median
        magnitude, then scaled by a corrected eigenvalue. Any entry whose
        spectral value is at or below a noise threshold falls back to the
        spectral value.

    Args:
        n: matrix dimension.
        a: value placed at the random index ``loc`` of U (requires |a| <= 1).
        is_haar: 0 for a Rademacher tail of U, 1 for a spherical tail.
        noise_type: 1 Gaussian, 2 Laplacian, 3 Rademacher.

    Returns:
        ``{"results1": ..., "results2": ...}``. ``results1`` holds the
        sign-invariant sup-norm errors of both estimates. ``results2`` holds
        the estimates and sign-invariant errors at entry ``loc``. Every value
        is a one-element list. Errors and estimates are ``"{:.4f}"`` strings,
        except ``U.largest``, which is the raw entry of U.

    Raises:
        KeyError: if ``noise_type`` or ``is_haar`` is not one of the labelled
            options. This is checked before any random draws.
    """
    noise_dict = {1: 'Gaussian', 2: 'Laplacian', 3: 'Rademacher'}
    haar_dict = {0: "No", 1: 'Yes'}
    # Fail fast on unsupported options instead of after the O(n^3) eigh.
    noise_label = noise_dict[noise_type]
    haar_label = haar_dict[is_haar]

    r = 1  # rank of the signal
    H = get_ortho(n)
    # Draw order (H, spike, noise) is kept fixed so seeded runs are reproducible.
    U, loc = _spike_vector(n, a, is_haar)
    W_symm = _symmetric_noise(n, noise_type)

    lambdas = [np.sqrt(n * np.log(n))]
    M = U @ np.diag(lambdas) @ U.T
    Y = (M + W_symm).astype("float32")
    del M, W_symm

    # ---- Spectral estimate (computed in a randomly rotated basis) ----
    Yhat = H @ Y @ H.T
    eigvals, eigvecs = eigh(Yhat)  # eigenvalues in ascending order
    # Use the top (largest algebraic) eigenpair. The spike lambda is positive,
    # and the correction for eigvals_c below takes the "+" root, which is only
    # meaningful for a positive eigenvalue.
    k = n - 1
    top_vec = eigvecs[:, k]
    uhat_specs = H.T @ top_vec

    # Noise level: RMS over the upper triangle (diagonal included) of the
    # residual left after removing the rank-one spectral fit.
    resid_sq = (Y - eigvals[k] * np.outer(uhat_specs, uhat_specs)) ** 2
    sigma_hat = np.sqrt(np.sum(np.triu(resid_sq)) / (n * (n + 1) / 2))
    del Y, resid_sq

    # Larger root of x**2 - eigvals[k]*x + n*sigma_hat**2 = 0.
    # NOTE: this is nan (with a RuntimeWarning) when the discriminant is negative.
    eigvals_c = (eigvals[k] + np.sqrt(eigvals[k] ** 2 - 4 * n * sigma_hat ** 2)) / 2
    uspec_dinfty = ["{:.4f}".format(dinfty(uhat_specs, U[:, 0]))]

    # ---- Corrected estimate (rotated basis, signs aligned to top_vec) ----
    qs = np.sign(top_vec)
    Yhats = qs.reshape(-1, 1) * Yhat * qs
    del Yhat
    # Coordinates where the top eigenvector exceeds its median magnitude.
    alpha0 = np.quantile(np.abs(top_vec), 0.5)
    I = np.where(np.abs(top_vec) > alpha0)[0]

    u_sum_root_c = np.sqrt(np.sum(Yhats[np.ix_(I, I)]) / np.abs(eigvals_c))
    # Row sums over the selected columns; division and rotation are done in float64.
    row_sums = Yhats[:, I].sum(axis=1).astype(np.float64)
    uhat_c = H.T @ (qs * (row_sums / (eigvals_c * u_sum_root_c)))

    # Where the spectral entry is at or below the noise threshold, keep the
    # spectral value. A nan threshold keeps the corrected values everywhere.
    threshold = sigma_hat * np.abs(np.log(n) / eigvals_c)
    uhat_c = np.where(np.abs(uhat_specs) <= threshold, uhat_specs, uhat_c)
    uhatc_dinfty = ["{:.4f}".format(dinfty(uhat_c, U[:, 0]))]

    # ---- Results ----
    results2 = {
        'n': [n],
        'r': [r],
        'a': [a],
        'noise': [noise_label],
        'haar': [haar_label],
        'largest.loc': [loc],
        'U.largest': [U[loc, 0]],
        'Ours.Est': ["{:.4f}".format(uhat_c[loc])],
        'Spec.Est': ["{:.4f}".format(uhat_specs[loc])],
        'Ours.Error': ["{:.4f}".format(pmdiff(uhat_c[loc], U[loc, 0]))],
        'Spec.Error': ["{:.4f}".format(pmdiff(uhat_specs[loc], U[loc, 0]))],
    }
    results1 = {
        "n": [n],
        "r": [r],
        "a": [a],
        "noise": [noise_label],
        "haar": [haar_label],
        "Ours.Error": uhatc_dinfty,
        "Spec.Error": uspec_dinfty,
    }
    return {"results1": results1, "results2": results2}
