% Permuted MNIST (even vs. odd digits): gradient descent vs. theory as a function of sigma

function [] =mnist_gd_sigmadependence(i,seed)
%MNIST_GD_SIGMADEPENDENCE Gradient descent vs. saddle-point theory for one sigma.
%
%   mnist_gd_sigmadependence(i, seed)
%
%   Builds a balanced even/odd MNIST subset (labels +1 for even, -1 for odd),
%   applies n random pixel permutations, and computes fixed binary gates
%   g = (v*x > th) from a random projection v. It then
%     1. computes theoretical test/train errors from the matrix H returned
%        by numeric_saddlepoint_nonlinear (external; H enters as g'*H*g, so
%        it is M-by-M), and
%     2. trains the gated network by full-batch gradient descent until its
%        training error falls below the theoretical training error gt, or
%        until `iteration` steps have run.
%   Results are saved under ../global_dependence_sigma_gd/.
%
%   i    - index into sigma_all selecting the weight scale sigma.
%   seed - seed for the 'twister' RNG that draws the initial GD weights.
%
%   NOTE(review): mnist.mat is assumed to hold trainX/testX as
%   (#images x 784) arrays and trainY/testY as digit-label vectors -- confirm.

S = load('mnist.mat', 'trainX', 'trainY', 'testX', 'testY');

iteration = 1e6;   % maximum number of gradient-descent steps
M = 100;           % number of gating units (rows of v and g)
beta = 1e4;        % 1/beta is added as a ridge term to K2
eta = 0.002;       % learning rate
N = 200;           % hidden width (columns of w)
sigma_all = [0.1:0.05:0.7,0.8,0.9,1];
n = 1;             % number of pixel permutations
N_0 = 784;         % input dimension
sigma = sigma_all(i);
P = 300;           % training-set size per permutation
N_s = 1000;        % test-set size per permutation
th = 0;            % gating threshold

% Labels as row vectors, so find()/concatenation below always yield rows
% regardless of how mnist.mat stores them.
trainY = reshape(S.trainY, 1, []);
testY  = reshape(S.testY, 1, []);

% Pixels x images; standardise every image across its own pixels.
trainX = double(S.trainX)';
testX  = double(S.testX)';
trainX = (trainX - mean(trainX)) ./ std(trainX);
testX  = (testX - mean(testX)) ./ std(testX);

% Balanced subsets: first P/2 (N_s/2) even digits followed by the first
% P/2 (N_s/2) odd digits.
evenTrain = find(mod(trainY,2)==0);
oddTrain  = find(mod(trainY,2)==1);
evenTest  = find(mod(testY,2)==0);
oddTest   = find(mod(testY,2)==1);
index   = [evenTrain(1:P/2), oddTrain(1:P/2)];
index_t = [evenTest(1:N_s/2), oddTest(1:N_s/2)];

x_00 = trainX(:, index);
x_t0 = testX(:, index_t);
y   = double(mod(trainY(index),2)==0);
y_t = double(mod(testY(index_t),2)==0);

% Random gating projection and pixel permutations (fixed seed).
rng(3,'twister');
v = normrnd(0,1,M,N_0);
perm = zeros(n, N_0);
x_0 = zeros(N_0, P*n);
x_t = zeros(N_0, N_s*n);
for k = 1:n
    perm(k,:) = randperm(N_0);
    x_0(:, (k-1)*P+1 : k*P)     = x_00(perm(k,:), :);
    x_t(:, (k-1)*N_s+1 : k*N_s) = x_t0(perm(k,:), :);
end

% Binary gates; any input with no active gate gets its maximal unit(s) on.
field_train = v*x_0;
field_test  = v*x_t;
g  = activate_silent_columns(field_train > th, field_train);
gp = activate_silent_columns(field_test > th, field_test);

% Repeat labels for each permutation and map {0,1} -> {-1,+1}.
y   = 2*repmat(y, 1, n) - 1;
y_t = 2*repmat(y_t, 1, n) - 1;

% ---- Theory ------------------------------------------------------------
[~,H] = numeric_saddlepoint_nonlinear(N_0,P*n,M,g,y,x_0,N,sigma,beta);

G_train = g'*H*g;                       % gate overlap, train x train
K_1  = sigma^2/N_0*(x_0'*x_0);
K_1p = sigma^2/N_0*(x_t'*x_0);
k_1p = sigma^2/N_0*(x_t'*x_t);

K2  = G_train.*K_1 + 1/beta*eye(P*n);
K2p = (gp'*H*g).*K_1p;
k2p = (gp'*H*gp).*k_1p;
K2inv = pinv(K2);                       % computed once, reused below

fp = K2p*K2inv*y';                      % theoretical test mean prediction
ft = (G_train.*K_1)*K2inv*y';           % theoretical train mean prediction
fvar_th = diag(k2p - K2p*K2inv*K2p');   % theoretical test variance
% Round-off can make tiny variances negative, which would turn the sqrt
% below (and hence gb) complex; clamp them at zero.
fvar_th = max(fvar_th, 0);

ge = sum((fp-y_t').^2)/N_s/n + mean(fvar_th);   % theoretical test MSE
gt = sum((ft-y').^2);                           % theoretical train SSE
% Theoretical probability of a wrong sign under a Gaussian predictor.
gb = mean((y_t'+1)/2 + (-y_t').*erfc(-fp./sqrt(fvar_th*2))/2);

f_var = zeros(1,N_s*n);

% ---- Gradient descent ----------------------------------------------------
rng(seed,'twister');
w = normrnd(0, sigma, N_0+M, N);
w_1 = w(1:N_0, :);          % N_0 x N input weights
w_2 = w(N_0+1:end, :)';     % N x M gate-specific readout weights
gd = double(g);

for j0 = 1:iteration
    [f, x_1] = gated_output(w_1, w_2, gd, x_0, N_0, N, M);
    r = y - f;
    E = sum(r.^2);
    if mod(j0,1e1) == 1
        disp(['train',num2str(E)]);
    end
    % Stop once GD matches the theoretical training error; f_t is then
    % evaluated at exactly these weights.
    if E < gt
        break
    end
    % Simultaneous step: both gradients (descent directions for E/2) are
    % taken at the current weights, before either is updated.
    step_2 = ((gd .* r) * x_1')' / sqrt(N*M);
    step_1 = (((w_2*gd) .* r) * x_0')' / sqrt(N*M*N_0);
    w_2 = w_2 + eta*step_2;
    w_1 = w_1 + eta*step_1;
end

% Test-set output of the trained network (only needed once).
f_t = gated_output(w_1, w_2, double(gp), x_t, N_0, N, M);
% The saved fvar_th is the mean theoretical test variance.
fvar_th = mean(fvar_th);

outdir = '../global_dependence_sigma_gd';
if ~exist(outdir, 'dir')
    mkdir(outdir);
end
save([outdir,'/mnist_generalization_context_n',num2str(n),'N',num2str(N),'M',num2str(M), ...
    'P',num2str(P),'sig',num2str(sigma),'seed',num2str(seed),'.mat'], ...
    'ge','gb','fp','y_t','fvar_th','H','g','x_0','x_t','f_t','f_var')
end

function g = activate_silent_columns(g, field)
% For each column of the logical gate matrix g with no active unit, switch
% on the unit(s) where the matching column of field attains its maximum.
for col = find(~any(g, 1))
    g(field(:, col) == max(field(:, col)), col) = true;
end
end

function [f, h] = gated_output(w_1, w_2, gates, x, N_0, N, M)
% Hidden activity h = w_1'*x/sqrt(N_0) and output row vector
% f = diag(gates'*w_2'*h)'/sqrt(N*M), computed without forming the
% (#inputs x #inputs) product.
h = w_1'*x/sqrt(N_0);
f = sum(gates .* (w_2'*h), 1)/sqrt(N*M);
end