% gibbs_sample_affine_model:

function [a, b, f_obs, kernel_mcmc, llik_mcmc] = gibbs_sample_affine_model...
            (X_exp, Y_exp, X_space, do_params_X_space, prior_info, hyper_a, X_obs, Z_obs, initial_random, M, locked, verbose)
% GIBBS_SAMPLE_AFFINE_MODEL  Gibbs sampler for the per-experiment model
%   Y_exp{s} ~ A{s} * (a{s} .* f{s} + b{s}) + noise,
% where A{s} is the design matrix returned by MULTI_EXP_MAPPING and a, b, f
% are vectors indexed by the points of X_space.
%
% Input:
%   X_exp             cell (length S); X_exp{s} is forwarded to
%                     multi_exp_mapping.
%   Y_exp             cell (length S); Y_exp{s} is the response vector of
%                     experiment s (its length is n(s)).
%   X_space           vector of grid points (length num_X_space).
%                     NOTE(review): the pairwise-distance code assumes a
%                     column vector -- confirm against callers.
%   do_params_X_space cell (length S) of structs with fields
%                       .mu_do  prior mean of f on X_space
%                       .K      prior covariance of f on X_space
%   prior_info        cell (length S); forwarded to multi_exp_mapping.
%   hyper_a           struct with fields .sf2.mu, .sf2.var, .ell.mu,
%                     .ell.var; passed to hyper_seiso_logpdf for the kernel
%                     hyperparameters of BOTH a and b.
%   X_obs, Z_obs      cells (length S); forwarded to multi_exp_mapping.
%   initial_random    [] or a struct providing the state at iteration 1:
%                     .a{s}, .b{s}, .f{s}, .log_sf2_a(s), .log_ell_a(s),
%                     .log_sf2_b(s), .log_ell_b(s), .llikv(s), and optionally
%                     .locked_hyper. If empty, a = 1, b = 0 and everything
%                     else starts at 0.
%   M                 number of MCMC iterations (columns of the outputs).
%   locked            if true, a, b and the kernel hyperparameters are not
%                     resampled (only f and, unless locked_hyper, the noise).
%   verbose           if true, print the iteration number at each step.
%
% Output (all S-by-1 cells, one entry per experiment):
%   a, b, f_obs       num_X_space-by-M matrices of samples.
%   kernel_mcmc       4-by-M; rows are log sf2 (a), log ell (a),
%                     log sf2 (b), log ell (b).
%   llik_mcmc         1-by-M; log of likv, the scalar that divides the
%                     squared-error terms in the conditionals.

%% Preliminaries

S = length(X_exp);
num_X_space = length(X_space);

A = cell(S, 1);
seen = cell(S, 1);
unseen = cell(S, 1);
num_seen = zeros(S, 1);
n = zeros(S, 1);

for s = 1:S
  [A{s}, seen_mask] = multi_exp_mapping(X_space, X_exp{s}, X_obs{s}, Z_obs{s}, prior_info{s});
  % Both index sets must be derived from the mask itself. Computing
  % find(seen{s} == 0) after seen{s} has been replaced by indices always
  % yields an empty set, because indices are never zero.
  seen{s} = find(seen_mask);
  unseen{s} = find(seen_mask == 0);
  num_seen(s) = length(seen{s});
  n(s) = length(Y_exp{s});
end

%% Prepare prior information

mu_a = ones(num_X_space, 1);   % prior mean of the slope a
mu_b = zeros(num_X_space, 1);  % prior mean of the offset b
mu_f = cell(S, 1);
f_mean_unseen = cell(S, 1);
f_mean_seen = cell(S, 1);
f_seen_map = cell(S, 1);
inv_prior_f1 = cell(S, 1);
prior_meancov_f1 = cell(S, 1);

for s = 1:S
  mu_f{s} = do_params_X_space{s}.mu_do;
  K_f = do_params_X_space{s}.K;
  % Add jitter when the prior covariance is numerically singular.
  if min(eig(K_f)) < 1.e-10
    K_f = K_f + get_noise_matrix(length(K_f));
  end
  f_mean_unseen{s} = mu_f{s}(unseen{s});
  f_mean_seen{s} = mu_f{s}(seen{s});
  % Linear map giving the conditional mean of unseen f given seen f.
  f_seen_map{s} = K_f(unseen{s}, seen{s}) / K_f(seen{s}, seen{s});
  inv_prior_f1{s} = inv(K_f(seen{s}, seen{s}));
  prior_meancov_f = K_f \ mu_f{s};
  prior_meancov_f1{s} = prior_meancov_f(seen{s});
end

% Shared kernel ingredients for a and b: SD2 is the average outer product of
% normalised prior standard deviations; core is the average squared distance
% between grid points in (normalised X, normalised prior mean) coordinates.
col_idx = ones(num_X_space, 1);
X1 = X_space - min(X_space); X1 = X1 / max(X1);   % does not depend on s
dX2 = (X1(:, col_idx) - X1(:, col_idx)').^2;

core = zeros(num_X_space);
SD2 = zeros(num_X_space);
for s = 1:S
  sd = sqrt(diag(do_params_X_space{s}.K)); sd = sd / max(sd);
  SD2 = SD2 + sd * sd';
  Y1 = do_params_X_space{s}.mu_do - min(do_params_X_space{s}.mu_do); Y1 = Y1 / max(Y1);
  core = core + dX2 + (Y1(:, col_idx) - Y1(:, col_idx)').^2;
end
SD2 = SD2 / S;
core = core / S;

noise_matrix = get_noise_matrix(num_X_space);

%% Main sampling

a = cell(S, 1);
b = cell(S, 1);
kernel_mcmc = cell(S, 1);
f_obs = cell(S, 1);
llik_mcmc = cell(S, 1);
for s = 1:S
  a{s} = ones(num_X_space, M);
  b{s} = zeros(num_X_space, M);
  kernel_mcmc{s} = zeros(4, M);
  f_obs{s} = zeros(num_X_space, M);
  llik_mcmc{s} = zeros(1, M);
end

% Optional user-supplied state for iteration 1.
locked_hyper = false;
if ~isempty(initial_random)
  for s = 1:S
    a{s}(:, 1) = initial_random.a{s};
    b{s}(:, 1) = initial_random.b{s};
    kernel_mcmc{s}(1, 1) = initial_random.log_sf2_a(s);
    kernel_mcmc{s}(2, 1) = initial_random.log_ell_a(s);
    kernel_mcmc{s}(3, 1) = initial_random.log_sf2_b(s);
    kernel_mcmc{s}(4, 1) = initial_random.log_ell_b(s);
    f_obs{s}(:, 1) = initial_random.f{s};
    llik_mcmc{s}(1) = initial_random.llikv(s);
  end
  if isfield(initial_random, 'locked_hyper')
    locked_hyper = initial_random.locked_hyper;
  end
end

likv = zeros(S, 1);
inv_prior_a = cell(S, 1);
inv_prior_b = cell(S, 1);
prior_meancov_a = cell(S, 1);
prior_meancov_b = cell(S, 1);

for m = 2:M

   if verbose, fprintf('Sampling iteration %d\n', m); end

   % Carry the previous state forward and rebuild the priors of a and b
   % from the current kernel hyperparameters.

   for s = 1:S

     a{s}(:, m) = a{s}(:, m - 1);
     b{s}(:, m) = b{s}(:, m - 1);
     kernel_mcmc{s}(:, m) = kernel_mcmc{s}(:, m - 1);
     f_obs{s}(:, m) = f_obs{s}(:, m - 1);
     llik_mcmc{s}(m) = llik_mcmc{s}(m - 1);
     likv(s) = exp(llik_mcmc{s}(m));

     K_a = exp(kernel_mcmc{s}(1, m)) * SD2 .* exp(-0.5 * core / exp(kernel_mcmc{s}(2, m))) + noise_matrix;
     inv_prior_a{s} = inv(K_a);
     prior_meancov_a{s} = K_a \ mu_a;
     K_b = exp(kernel_mcmc{s}(3, m)) * SD2 .* exp(-0.5 * core / exp(kernel_mcmc{s}(4, m))) + noise_matrix;
     inv_prior_b{s} = inv(K_b);
     prior_meancov_b{s} = K_b \ mu_b;

   end

   % Sample f given a and b

   for s = 1:S
     Aa = A{s} .* repmat(a{s}(seen{s}, m)', n(s), 1);
     y1 = Y_exp{s} - A{s} * b{s}(seen{s}, m);
     inv_cov_f = (Aa' * Aa) / likv(s) + inv_prior_f1{s};
     mean_f = inv_cov_f \ (Aa' * y1 / likv(s) + prior_meancov_f1{s});
     % With R = chol(P), R \ z has covariance inv(P).
     f_obs{s}(seen{s}, m) = mean_f + chol(inv_cov_f) \ randn(num_seen(s), 1);
     % Unseen entries are set to their conditional mean given the seen ones.
     f_obs{s}(unseen{s}, m) = f_mean_unseen{s} + f_seen_map{s} * (f_obs{s}(seen{s}, m) - f_mean_seen{s});
   end

   if ~locked

     % Sample a given f and b

     for s = 1:S
       inv_cov_a = inv_prior_a{s};
       pre_mean_a = prior_meancov_a{s};

       y1 = Y_exp{s} - A{s} * b{s}(seen{s}, m);
       Aa = A{s} .* repmat(f_obs{s}(seen{s}, m)', n(s), 1);
       % The likelihood only informs the seen coordinates.
       inv_cov_a(seen{s}, seen{s}) = inv_cov_a(seen{s}, seen{s}) + (Aa' * Aa) / likv(s);
       pre_mean_a(seen{s}) = pre_mean_a(seen{s}) + Aa' * y1 / likv(s);

       mean_a = inv_cov_a \ pre_mean_a;
       a{s}(:, m) = mean_a + chol(inv_cov_a) \ randn(num_X_space, 1);
     end

     % Sample b given f and a

     for s = 1:S
       inv_cov_b = inv_prior_b{s};
       pre_mean_b = prior_meancov_b{s};

       y1 = Y_exp{s} - A{s} * (a{s}(seen{s}, m) .* f_obs{s}(seen{s}, m));
       inv_cov_b(seen{s}, seen{s}) = inv_cov_b(seen{s}, seen{s}) + (A{s}' * A{s}) / likv(s);
       pre_mean_b(seen{s}) = pre_mean_b(seen{s}) + A{s}' * y1 / likv(s);

       mean_b = inv_cov_b \ pre_mean_b;
       b{s}(:, m) = mean_b + chol(inv_cov_b) \ randn(num_X_space, 1);
     end

     if ~locked_hyper

       for s = 1:S

         % Sample hyperparameters for a (one slice-sampling step each)

         kernel_mcmc{s}(1, m) = slicesample(kernel_mcmc{s}(1, m), 1, 'logpdf', ...
                            @(x)hyper_seiso_logpdf(x, 1, exp(kernel_mcmc{s}(1:2, m)), a{s}(:, m), mu_a, ...
                                                         hyper_a.sf2.mu, hyper_a.sf2.var, SD2, core, noise_matrix));
         kernel_mcmc{s}(2, m) = slicesample(kernel_mcmc{s}(2, m), 1, 'logpdf', ...
                            @(x)hyper_seiso_logpdf(x, 2, exp(kernel_mcmc{s}(1:2, m)), a{s}(:, m), mu_a, ...
                                                       hyper_a.ell.mu, hyper_a.ell.var, SD2, core, noise_matrix));

         % Sample hyperparameters for b (same hyperprior as for a)

         kernel_mcmc{s}(3, m) = slicesample(kernel_mcmc{s}(3, m), 1, 'logpdf', ...
                            @(x)hyper_seiso_logpdf(x, 1, exp(kernel_mcmc{s}(3:4, m)), b{s}(:, m), mu_b, ...
                                                       hyper_a.sf2.mu, hyper_a.sf2.var, SD2, core, noise_matrix));
         kernel_mcmc{s}(4, m) = slicesample(kernel_mcmc{s}(4, m), 1, 'logpdf', ...
                            @(x)hyper_seiso_logpdf(x, 2, exp(kernel_mcmc{s}(3:4, m)), b{s}(:, m), mu_b, ...
                                                       hyper_a.ell.mu, hyper_a.ell.var, SD2, core, noise_matrix));

       end

     end

   end

   % Sample likv (on the log scale); gated by locked_hyper, not by locked

   if ~locked_hyper
     for s = 1:S
       exp2 = sum((Y_exp{s} - A{s} * (a{s}(seen{s}, m) .* f_obs{s}(seen{s}, m) + b{s}(seen{s}, m))).^2);
       llik_mcmc{s}(m) = slicesample(llik_mcmc{s}(m), 1, 'logpdf', @(x)likv_logpdf(x, n(s), exp2));
     end
   end

end

