import numpy as np
import matplotlib.pyplot as plt
import pickle
from predict_spectrum import CK, NTK
from simulate_NN import histeig, cross, compute_NTK

# Hard-coded numerical constants for the sigmoid activation. The factor
# sqrt(2*pi) appears in all three expressions, so it is named once.
_SQRT_2PI = np.sqrt(2*np.pi)
# c evaluates to ~0.208276, the same value that is written out literally in
# the denominator of b (presumably the activation's normalisation scale --
# TODO confirm against the derivation of these integrals).
c = np.sqrt(0.21747/(2*_SQRT_2PI))
b = 2*(0.258961)/_SQRT_2PI/0.208276  # E[sigma'(xi)], for sigmoid
a = 2*0.0561939/(_SQRT_2PI*(c**2))  # E[sigma'(xi)^2], for sigmoid

# Load the saved training artefacts. The files are opened in `with` blocks so
# the handles are closed as soon as unpickling finishes.
# NOTE: unpickling can execute arbitrary code; only point this script at
# pickle files you produced yourself.
with open('training/X.pkl','rb') as fh:
  Xs = pickle.load(fh)
with open('training/W.pkl','rb') as fh:
  Ws = pickle.load(fh)

# X0 is the transpose of Xs[0]; its shape is unpacked as (d0, n)
# (presumably input dimension by number of samples -- confirm against the
# script that wrote X.pkl).
X0 = np.transpose(Xs[0])
(d0,n) = X0.shape
L = 3  # number of layers analysed below
d = [d0] * (L+1)  # every layer is assigned the same width d0 as the input
gamma = n/np.array(d[1:])  # ratio n/d_l for layers 1..L

# Shared histogram settings for every spectrum plot.
ylim = [0,40]
bins = 50

print('Computing CK predictions')
# The eigenvalues of cross(X0) do not depend on the layer index, so they are
# computed once here instead of once per iteration. `spec` stays defined after
# the loop because the NTK prediction further down reads it as well.
spec = np.linalg.eigvalsh(cross(X0))
for l in range(1,L+1):
  # Kernel built from the last entry of Xs[l] (presumably the final training
  # snapshot of layer l -- confirm against the training script).
  # KCK deliberately outlives the loop: the label-correlation analysis at the
  # end of the script uses the value left by the final iteration (layer L).
  KCK = cross(np.transpose(Xs[l][-1]))
  eigs = np.linalg.eigvalsh(KCK)  # eigvalsh returns eigenvalues in ascending order
  top = eigs[-1]  # largest eigenvalue
  xgrid = np.linspace(top*(-0.05),top*1.2,num=1000)
  # NOTE(review): the prediction is evaluated with the total depth L on every
  # pass, so dgrid depends on the loop variable l only through xgrid. Confirm
  # against predict_spectrum.CK whether the first argument should be l.
  dgrid = CK(L,gamma,b,xgrid,spec=spec)
  f, ax = plt.subplots(1,1)
  histeig(eigs,ax,xgrid=xgrid,dgrid=dgrid,ylim=ylim,bins=bins,title='Trained CK spectrum, layer %d' % l)
  # Arrow width is 1% of the plotted x-range.
  width = (xgrid.max()-xgrid.min())*0.01
  # Mark the two largest eigenvalues with downward arrows.
  for ev in (eigs[-1],eigs[-2]):
    ax.arrow(ev,8,0,-3,width=width,color='b',head_width=5*width,head_length=1)
  plt.savefig('trained_X%d.png' % l)

print('Computing NTK prediction')
# Matrices handed to compute_NTK: the transposed input first, followed by the
# transposed last entry of Xs[1], ..., Xs[L].
Xfin = [np.transpose(Xs[0])] + [np.transpose(Xs[k][-1]) for k in range(1,L+1)]
KNTK = compute_NTK(Ws[-1],Xfin,d)
eigs = np.linalg.eigvalsh(KNTK)

# Evaluation grid runs from -5% to 120% of the largest eigenvalue.
top_eig = max(eigs)
xgrid = np.linspace(top_eig*(-0.05),top_eig*1.2,num=1000)
# `spec` is the input spectrum left behind by the CK section above.
dgrid = NTK(L,gamma,a,b,xgrid,spec=spec)

f, ax = plt.subplots(1,1)
histeig(eigs,ax,xgrid=xgrid,dgrid=dgrid,ylim=ylim,bins=bins,title='Trained NTK spectrum')

# Downward arrows over the last two entries of eigs, sized relative to the
# plotted x-range.
width = (max(xgrid)-min(xgrid))*0.01
arrow_kwargs = dict(width=width,color='b',head_width=5*width,head_length=1)
for eig in (eigs[-1],eigs[-2]):
  ax.arrow(eig,8,0,-3,**arrow_kwargs)
plt.savefig('trained_NTK.png')

# Compare the training labels with their projection onto the leading
# eigenvectors of KCK. KCK is the value left behind by the CK loop above,
# i.e. the kernel of its final iteration.
with open('training/data.pkl','rb') as fh:
  data = pickle.load(fh)
y = data['train_labels'].reshape(-1)

# One eigendecomposition serves both analyses. eigh returns eigenvalues in
# ascending order, so column -1 belongs to the largest eigenvalue and column
# -2 to the second largest.
KCK_vecs = np.linalg.eigh(KCK)[1]
# Total sum of squares of the labels; identical for both projections.
SStot = np.sum((y-np.mean(y))**2)

for cols, xlabel, fname in (
    ((-1,), 'Projection of y onto top PC of CK', 'trained_CK_correlation1.png'),
    ((-1,-2), 'Projection of y onto top 2 PCs of CK', 'trained_CK_correlation.png'),
):
  eigvecs = KCK_vecs[:,list(cols)]
  # Least-squares projection V (V^T V)^{-1} V^T y onto the selected columns.
  # eigh returns orthonormal eigenvectors, so V^T V is numerically the
  # identity; the general formula is kept unchanged.
  gram_inv = np.linalg.inv(np.dot(np.transpose(eigvecs),eigvecs))
  proj = np.dot(np.dot(eigvecs,gram_inv),np.dot(np.transpose(eigvecs),y))

  f, ax = plt.subplots(1,1)
  ax.plot(proj,y,'k.')
  ax.set_xlabel(xlabel)
  ax.set_ylabel('Training label y')
  ax.set_title('CK eigenvectors vs. training labels')
  plt.tight_layout()
  plt.savefig(fname)

  # Coefficient of determination of the projection as a predictor of y.
  SSerr = np.sum((y-proj)**2)
  Rsq = 1-SSerr/SStot
  print('CK projection R^2: %s' % str(Rsq))

