function [best_thetas,best_w,best_yh,psses,sses,mses,spars] = elastic_basis_pursuit(thetas,y,X,cache,f,regtype,lambda,tau,vfunc,nits,verbose)
% [best_thetas,best_w,best_yh,psses,sses,mses,spars] = elastic_basis_pursuit(thetas,y,X,cache,f,regtype,lambda,tau,vfunc,nits,verbose)
%  Fits a mixture model to y using elastic basis pursuit (EBP).
%  I.e. solves
%   minimize psse = sum_i (y_i - sum_j w_j f(x_i,theta_j))^2 + lambda * P(w)
%  subject to w >= 0 (weights are refit with lsqnonneg in every iteration).
%
%  Each iteration appends the parameter returned by the oracle tau to the
%  active set, refits the non-negative weights, drops parameters whose
%  weight is zero, and evaluates the validation error. The model with the
%  smallest validation error seen so far (including the initial model) is
%  returned.
%
% Inputs:
%  - thetas, D x K0 matrix: initial estimate of parameters; you can set this to []
%  - y, n x 1 signal
%  - X, p x n design matrix X = (x_1,...,x_n)
%  - cache: [], or a struct with the following fields. This function does
%     not inspect the cache; it only passes it to tau and keeps the
%     updated copy that tau returns.
%   * thetas, D x K1 matrix: previously cached parameter points
%   * inds, 1 x K1 matrix: the theta_space cell index for cached thetas
%   * F, n x K1 matrix: previously cached signal matrix
%   * nms, 1 x K1 matrix: norms of cached_F
%  - f, function handle of the form f(x,theta): specifies kernel family
%  - regtype, 'none', 'L1', or 'unweighted': regularization type
%  - lambda, scalar: penalization constant
%  - tau, function handle: oracle called as
%       [theta,corr,cache] = tau(r,cache)
%     which returns a D x 1 parameter theta such that f_theta is most
%     correlated with the residual r, the associated correlation, and the
%     updated cache. See generic_oracle.m
%  - vfunc, function handle: returns the prediction error, called as
%       mse = vfunc(thetas,w)
%     see validation_function/
%  - nits, number of iterations
%  - verbose, boolean: if true, print [iter, psse, sse, mse, K, corr, toc] in
%     each iteration, where corr is the correlation reported by the most
%     recent call to tau and toc is the elapsed time of the iteration
%
% Outputs:
%  - best_thetas, D x K matrix of estimated parameters which minimize validation error
%  - best_w, K x 1 vector of estimated weights which minimize validation error
%     (0 x 1 when the empty model is the best one)
%  - best_yh, predicted signal which minimizes validation error. It has the
%     same number of rows as the internally augmented y: n rows for
%     regtype 'none', n+1 rows for 'L1' and 'unweighted' (the last row
%     belongs to the penalty term); use best_yh(1:n) for the fitted signal
%  - psses, 1 x nits vector: penalized sum of squares error in each iteration
%  - sses, 1 x nits vector: sum of squares error in each iteration
%  - mses, 1 x nits vector: prediction errors in each iteration
%  - spars, 1 x nits vector: size of active set in each iteration
%
% Example usage:
%  theta_space = .5 .* (fullfact([5 5])-3)'; % grid on [-1,1]^2
%  delta = .5;
%  thetas = sample_parameters(theta_space,delta,10); % will sample points from [-1.25,1.25]^2
%  w = rand(10,1);
%  f = @(x,theta) exp(-norm(x-theta)^2); % gaussian kernel
%  X = randn(2,10);
%  X_val = randn(2,10);
%  noise_type = 'gaussian'; noise=0.1; % specifies gaussian noise N(0, 0.01)
%  [y,y0] = generate_signal(X, thetas, w, f, noise_type, noise);
%  [y_val,y0_val] = generate_signal(X_val, thetas, w, f, noise_type, noise);
%  regtype = 'L1'; lambda = 0.1;
%
%  % create oracle tau
%  grad_eps = [1e-5 0.9]; % uses step size of 1e-5 for numerical derivative, multiplier of 0.9 for line search
%  Ngradsteps = [3 10]; % does 3 gradient steps, max 10 line search steps
%  Nrestarts = 100; % number of random restarts
%  % create the oracle tau as an anonymous function
%  % NOTE: this function calls tau(r,cache) and expects three outputs
%  % [theta,corr,cache]; adapt the wrapper below to the actual signature of
%  % generic_oracle.m so that it accepts and returns the cache.
%  tau = @(r) generic_oracle(r,X,f,theta_space,delta,regtype,lambda,grad_eps,Ngradsteps,Nrestarts);
%
%  % create function handle for vfunc
%  vfunc = @(thetas, w) validation_function(thetas, w, X_val, f, y_val);
%
%  % fit an NNLS
%  [thetas_n, w_n, yh_n, psse_n, sse_n] = nnls_fit(y,X,theta_space,f,regtype,lambda);
%
%  % fit EBP, using NNLS fit as initialization and no cache
%  nits = 100;
%  [best_thetas,best_w,best_yh,psses,sses,mses,spars] = elastic_basis_pursuit(thetas_n,y,X,[],f,regtype,lambda,tau,vfunc,nits,true)

%% Transform y for the regularized problem
% The penalty is folded into the least-squares problem by appending one
% extra row to y; predictor_matrix is expected to append the matching row
% to F, since lsqnonneg(F,y) requires equal row counts.
n = size(y,1);
if strcmp(regtype,'L1')
  y = [y; 0];
end
if strcmp(regtype,'unweighted')
  y = [y; sqrt(lambda)];
end

%% Initialization
r = y;
has_init = size(thetas,2) > 0;
if has_init
  % warm start: fit non-negative weights on the supplied parameters
  F = predictor_matrix(X, thetas, f, regtype, lambda);
  w = lsqnonneg(F,y);
  yh = F*w;
  psse = norm(y-yh)^2;
  sse = norm(y(1:n) - yh(1:n))^2;
  r = y - yh;
  mse = vfunc(thetas,w);
end

% first oracle call; also tells us the parameter dimension D
[theta_new,corrr,cache] = tau(r,cache);
D = size(theta_new,1);

if ~has_init
  % cold start: empty model, prediction is identically zero
  F = zeros(size(y,1),0);
  thetas = zeros(D,0);
  % 0 x 1 (not 0 x 0) so that the weights are a K x 1 column for K = 0 and
  % F*w evaluates to a zero column vector
  w = zeros(0,1);
  yh = 0.*y;
  psse = norm(y)^2;
  sse = norm(y(1:n))^2;
  % validation error of the zero prediction (any theta with weight 0)
  mse = vfunc(theta_new, 0);
end

psses = zeros(1,nits);
sses = zeros(1,nits);
mses = zeros(1,nits);
spars = zeros(1,nits);

% the initial model is the incumbent best
best_mse = mse;
best_yh = yh;
best_thetas = thetas;
best_w = w;

K = size(thetas,2);

if verbose
  {'iter','psse','sse','mse','K','corr','toc';0,psse,sse,mse,K,NaN,NaN}
end

%% Main loop
for iter = 1:nits
  % use a timer handle so that nested tic calls (in tau, vfunc, ...) and
  % the caller's own tic/toc are not disturbed
  iter_timer = tic;

  % add the new theta
  thetas = [thetas, theta_new];
  F = [F, predictor_matrix(X, theta_new, f, regtype, lambda)];

  % refit weights
  w = lsqnonneg(F, y);
  yh = F * w;
  r = y - yh;

  % prune params with zero weight (yh is unaffected by the pruning)
  active = w > 0;
  thetas = thetas(:, active);
  F = F(:, active);
  w = w(active);
  K = size(thetas,2);

  % validation error; keep the model if it strictly improves on the best
  mse = vfunc(thetas, w);
  if mse < best_mse
    best_mse = mse;
    best_thetas = thetas;
    best_w = w;
    best_yh = yh;
  end

  % various statistics
  psses(iter) = norm(y-yh)^2;
  sses(iter) = norm(y(1:n) - yh(1:n))^2;
  mses(iter) = mse;
  spars(iter) = K;

  % get new theta (not needed after the final iteration)
  if iter < nits
    [theta_new,corrr,cache] = tau(r,cache);
  end
  toc_i = toc(iter_timer);

  if verbose
      {'iter','psse','sse','mse','K','corr','toc';iter,psses(iter),sses(iter),mses(iter),spars(iter),corrr,toc_i}
  end

end

end
