clc;
clear;

%% Problem sizes
n         = 1000;  % samples drawn per call to the data generator
d         = 1000;  % ambient dimension of each sample
s         = 10;    % number of dimensions turned on (overwritten by d1 below)
noise_var = 1;     % only referenced by the commented-out generate_data call

%% Draw data and eigendecompose the population covariance
% [A, data_cov, X] = generate_data(n,d,s,1,noise_var);
d1 = 20;
d2 = 10;
s  = d1;  % from here on the support size equals d1
[data_cov, X] = generate_data_jing_lei(n, d, d1, d2);

% Eigenpairs of the population covariance, kept real and ordered by
% ascending eigenvalue so that the leading pair sits in the last position.
[U, Lambda]   = eig(data_cov);
[Lambda, ind] = sort(real(diag(Lambda)));
U             = real(U(:, ind));

sample_cov = cov(X);  % not used by the rest of the visible script

% Leading eigenpair, the remaining eigenvectors and the second eigenvalue.
v1      = U(:, d);
vp      = U(:, 1:d-1);
lambda1 = Lambda(d);
lambda2 = Lambda(d-1);

%% Run algorithm
k     = 1;  % number of columns of the iterate Q
gamma = s;  % second argument handed to row_truncation
init  = randn(d,k);
init  = init/norm(init);

fprintf("Inner product of Initialisation and True Eigenvector = %.5f\n", ...
        init'*v1);

% Traditional Oja with Truncation
fprintf("Traditional Oja's Algorithm with truncation\n");
% Entry 1 stores the error of the initial iterate and entry t+1 the error
% after step t, so n+1 slots are required; preallocating only n made the
% array grow inside the loop.
error_tradoja_trunc = zeros(n+1,1);
Q     = init;
Q     = Q/norm(Q);
[Q,~] = qr(Q,"econ");
% Squared Frobenius norm of the projection of Q onto all eigenvectors except
% the top k.
error_tradoja_trunc(1) = norm(U(:,1:d-k)'*Q,"fro")^2;
lr_init         = 0.25*log(n)/n;
% lr_init         = 0.2;
max_indices     = zeros(n,1);   % not used by the rest of the visible script
iteration_found = zeros(n,s);   % not used by the rest of the visible script
Bn_u0           = init;

% Bn_u0 was just set to init, so the former all(Bn_u0 == init) guard was
% always true; the overlap with the leading eigenvector is assigned directly.
factor = abs(v1'*init);

% Coordinates tracked over time: one inside the first s coordinates, one just
% past them.
in_support_index = 1;
outside_support_index = s+1;

average_in_support = zeros(n,1);
average_out_support = zeros(n,1);

runs = 10;

% One row per run, one column per time step.
all_entry_in_support = zeros(runs, n);
all_entry_outside_support = zeros(runs, n);

% Population-level reference curves. They depend only on init and on the
% eigendecomposition, not on the sampled data, so they are built once here
% instead of being recomputed identically in every run.
theoretical_entry_in_support = ones(n,1);
theoretical_entry_outside_support = ones(n,1);
% abs() keeps the curve positive: v1(in_support_index) has arbitrary sign and
% the curve is later passed through log().
theoretical_entry_in_support(1) = abs(factor*v1(in_support_index));
% theoretical_entry_outside_support(1) = factor*16*lr_init*sqrt(lambda1-data_cov(outside_support_index,outside_support_index));
% e_i' * (vp*vp') * init, evaluated without forming the d-by-d projector.
init_prime = vp(outside_support_index,:)*(vp'*init);
theoretical_entry_outside_support(1) = abs(init_prime);
for t = 2:n
    theoretical_entry_in_support(t) = theoretical_entry_in_support(t-1)*(1+lr_init*lambda1);
    theoretical_entry_outside_support(t) = theoretical_entry_outside_support(t-1)*(1+lr_init*lambda2);
end

for r = 1:runs
    % [A, data_cov, X] = generate_data(n,d,s,1,noise_var);
    % NOTE(review): U, v1, lambda1 and lambda2 come from the data_cov drawn
    % before this loop; if generate_data_jing_lei randomises the covariance
    % they are stale for this run -- confirm against the generator.
    [data_cov, X]  = generate_data_jing_lei(n,d,d1,d2);

    entry_in_support = zeros(n,1);
    entry_outside_support = zeros(n,1);

    % Restart both iterates from the shared initialisation. Previously only
    % Bn_u0 was reset, so Q was warm-started from the preceding run and the
    % error trace no longer matched error_tradoja_trunc(1).
    Bn_u0 = init;
    [Q,~] = qr(init/norm(init),"econ");

    for t=1:n
        x_t = X(t,:)';
        % lr = lr_init/(t + 1);
        lr = lr_init;
        % Rank-one form of (I + lr*x_t*x_t')*Q: avoids building a d-by-d
        % matrix at every step.
        Q = Q + lr*(x_t*(x_t'*Q));
        Bn_u0 = Bn_u0 + lr*(x_t*(x_t'*Bn_u0));
        entry_in_support(t) = Bn_u0(in_support_index);
        entry_outside_support(t) = Bn_u0(outside_support_index);

        % Re-orthonormalise, truncate rows, then orthonormalise the truncated
        % iterate before measuring its error.
        [Q,~] = qr(Q,"econ");
        Q_prime = row_truncation(Q, gamma);
        [Q_prime,~] = qr(Q_prime,"econ");
        error_tradoja_trunc(t+1) = norm(U(:,1:d-k)'*Q_prime,"fro")^2;
        if(mod(t,100) == 0)
            fprintf("Error after %d iterations = %.5f\n", t, error_tradoja_trunc(t+1));
        end
    end
    % Running mean over runs plus the raw per-run traces.
    average_in_support = average_in_support + (entry_in_support/runs);
    average_out_support = average_out_support + (entry_outside_support/runs);
    all_entry_in_support(r, :) = entry_in_support';
    all_entry_outside_support(r, :) = entry_outside_support';
end
% Truncated iterate from the final step of the last run.
Q_traditional_trunc = Q_prime;

fprintf("Final error Traditional Oja's Algorithm with truncation : %.5f\n", error_tradoja_trunc(end));

fprintf("============================\n");

theoretical_threshold_1 = theoretical_entry_in_support - 0.1*sqrt(d)*(theoretical_entry_outside_support);
theoretical_threshold_2 = theoretical_entry_outside_support + 0.1*sqrt(d)*(theoretical_entry_outside_support);


% Plotting: log-magnitude of the tracked coordinates (averaged over runs)
% against the population-level reference curves.
figure;
hold on;
steps = 1:n;
plot(steps, log(abs(average_in_support)), 'LineWidth', 2);
plot(steps, log(abs(average_out_support)), 'LineWidth', 2);
% abs() mirrors the sample curves and keeps log() real when a reference
% curve starts from a negative value.
plot(steps, log(abs(theoretical_entry_in_support)), 'LineWidth', 2);
plot(steps, log(abs(theoretical_entry_outside_support)), 'LineWidth', 2);

% Labels are set after plotting and describe what is actually drawn; the
% previous 'Error' / 'sin-squared error' text did not match these curves.
xlabel('Timesteps', 'FontSize', 20);
ylabel('log |entry|', 'FontSize', 20);
title('Log-magnitude of in-support and out-of-support entries with timesteps', 'FontSize', 20);

% % Error bars with reduced frequency for visibility
% error_bar_frequency = 10;
% indices = 1:error_bar_frequency:n;
% std_in_support = std(log(all_entry_in_support), 0, 1);
% std_out_support = std(log(all_entry_outside_support), 0, 1);
% 
% % Reduced opacity for error bars
% alpha_value = 0.1; % Adjust this value for desired transparency
% 
% errorbar(indices, log(abs(average_in_support(indices))), std_in_support(indices), 'LineWidth', 2, 'CapSize', 10, 'Color', [0 0 1 alpha_value]);
% errorbar(indices, log(abs(average_out_support(indices))), std_out_support(indices), 'LineWidth', 2, 'CapSize', 10, 'Color', [1 0 0 alpha_value]);

legend('Entry in support (sample)', 'Entry outside support (sample)', 'Entry in support (population)', 'Entry outside support (population)', 'FontSize', 20);

grid on;
grid minor;
set(gca, 'GridLineStyle', '-', 'GridAlpha', 0.5, 'LineWidth', 0.5);

% threshold = 0.1*(1+lr*lambda1)^n;
% fraction_in_support = sum(Bn_u0(1:s) > threshold)/s;
% fraction_outside_support = sum(Bn_u0(s+1:d) > threshold)/(d-s);
% 
% figure;
% xlabel('Timesteps');
% ylabel('Error');
% title("sin-squared error with timesteps");
% plot(1:n,log(abs(average_in_support))); hold on;
% plot(1:n,log(abs(average_out_support))); hold on;
% plot(1:n,log(theoretical_entry_in_support)); hold on;
% plot(1:n,log(theoretical_entry_outside_support));
% % plot(1:n,log(theoretical_entry_outside_support)); hold on;
% % plot(1:n,log(theoretical_threshold_1));
% % plot(1:n,log(theoretical_threshold_2));
% legend("Entry in support (sample)", ...
%        "Entry outside support (sample)", ...
%        "Entry in support (population)", ...
%        "Entry outside support (population)");
% % legend("Entry in support (sample)", ...
% %        "Entry outside support (sample)", ...
% %        "||e_{i}^Tv_{1}v_{1}^Tu_{0}||(1+\eta\lambda_{1})^n", ...
% %        "||e_{i}^TV_{\perp}V_{\perp}^Tu_{0}||(1+\eta\lambda_{2})^n");
% 
% % figure;
% % xlabel('Timesteps');
% % ylabel('Ratio');
% % ratio = entry_in_support./entry_outside_support;
% % plot(1:n,ratio);
% % legend("Ratio vs time");
