#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <random>
#include "SRHT_algorithm.h"
using namespace std;

// constructor
// Prepares the data for the SRHT pipeline:
//  * pads the response y and design matrix Z with zero rows up to the next
//    power of two (the Walsh-Hadamard transform in transform() needs that),
//  * builds the intercept column constant_term (1 for real rows, 0 for padding),
//  * randomly marks which rows will be kept by the subsampling step.
//
// NOTE: input_data is modified in place (y and Z are resized and padded);
// callers that still need the original, unpadded data must keep a copy.
//
// @param input_data  regression data (response y, predictors Z); not owned.
// @param n           number of rows to keep after the transform (all rows
//                    are kept if n >= padded row count).
SRHT_algorithm::SRHT_algorithm(the_Data *input_data, int n){
    this -> data = input_data;
    this -> N  = input_data -> N;
    this -> p  = input_data -> p;
    this -> n  = n;

    // Intercept column: 1 for every real observation.
    constant_term.resize(this -> N, 1.);

    // Smallest power of two >= the number of observations.
    int padded_rows = 1;
    while(padded_rows < N){
        padded_rows *= 2;
    }
    (data->y).resize(padded_rows, 0.);
    // NOTE(review): growing Z.matrix to padded_rows * p and bumping nrow only
    // appends zero *rows* if Z is stored row-major -- confirm with the matrix type.
    (data->Z).matrix.resize(padded_rows * p, 0.);
    (data->Z).nrow = padded_rows;
    constant_term.resize(padded_rows, 0.);   // padding rows get a 0 intercept
    this -> N = padded_rows;

    // Flag the first n rows, then shuffle so that a uniformly random subset
    // of rows is selected.
    position_flag.resize(this -> N);
    for(int i = 0; i < N; ++i){
        position_flag[i] = (i < n);
    }
    // std::random_shuffle was removed in C++17; std::shuffle needs an explicit
    // engine. Seeding it from rand() keeps runs reproducible under srand(),
    // consistent with sign_flipping().
    std::mt19937 engine(static_cast<unsigned>(std::rand()));
    std::shuffle(position_flag.begin(), position_flag.end(), engine);
}

// Runs the whole SRHT regression pipeline on the (padded) data and returns
// the p+1 estimated coefficients; index 0 is the intercept, indices 1..p
// correspond to the columns of Z.
vector<T> SRHT_algorithm::Execute(){
    sign_flipping();   // random +-1 row signs
    transform(0, N);   // Walsh-Hadamard mixing of all N rows
    compute();         // normal equations over the selected rows only
    return solve();    // solve X^T X beta = X^T y
}

// Multiplies every row of [constant_term | y | Z] by a random sign, i.e. applies
// a random +-1 diagonal matrix. One rand() value is drawn per row (including
// zero padding rows), so the outcome is reproducible under srand().
void SRHT_algorithm::sign_flipping(){
    for(int row = 0; row < N; ++row){
        const bool flip = (rand() % 2) != 0;
        if(!flip) continue;

        constant_term[row] *= -1;
        data->y[row] *= -1;
        for(int col = 0; col < p; ++col)
            data->Z(row, col) *= -1;
    }
}

// In-place, unnormalized fast Walsh-Hadamard transform over rows [left, right)
// of the intercept column (constant_term), the response (data->y) and every
// predictor column of data->Z. Each recursion level pairs row i with row
// i + half and applies the butterfly (a, b) -> (a + b, a - b).
// The 1/sqrt(N) factor is omitted: X and y are all scaled by the same
// constant, which cancels in the normal equations solved afterwards.
//
// Precondition: right - left is a power of two (the constructor pads N).
void SRHT_algorithm::transform(int left, int right){
    // Ranges of at most one row need no work; `<= 1` also stops an empty
    // range from recursing forever.
    if(right - left <= 1) return;
    const int middle = left + (right - left) / 2;
    transform(left, middle);
    transform(middle, right);

    const int half = middle - left;
    for(int i = left; i < middle; ++i){
        const int k = i + half;
        // Save both operands and compute a - b directly. The former
        // (a + b) - 2b form rounds twice and loses precision when |a + b|
        // is large relative to |a - b|.
        const auto c_a = constant_term[i];
        const auto c_b = constant_term[k];
        constant_term[i] = c_a + c_b;
        constant_term[k] = c_a - c_b;

        const auto y_a = data->y[i];
        const auto y_b = data->y[k];
        data->y[i] = y_a + y_b;
        data->y[k] = y_a - y_b;

        for(int j = 0; j < p; ++j){
            const auto z_a = data->Z(i, j);
            const auto z_b = data->Z(k, j);
            data->Z(i, j) = z_a + z_b;
            data->Z(k, j) = z_a - z_b;
        }
    }
}

// Accumulates the normal equations over the rows flagged in position_flag,
// with design matrix X = [constant_term | Z] (p+1 columns, column 0 = intercept):
//   XXT <- X^T X   (lower triangle only: entries (a, b) with b <= a)
//   XY  <- X^T y
// The upper triangle of XXT is mirrored in later by solve().
// NOTE(review): assumes XXT.init() zero-fills the matrix -- confirm.
void SRHT_algorithm::compute(){
    const int dim = p + 1;
    XXT.init(dim, dim);
    XY.resize(dim);
    for(int k = 0; k < dim; ++k)
        XY[k] = 0;

    for(int row = 0; row < N; ++row){
        if(!position_flag[row]) continue;

        const auto c  = constant_term[row];
        const auto yr = data->y[row];

        // Intercept column (index 0).
        XXT(0, 0) += c * c;
        XY[0]     += yr * c;

        // Predictor columns: X column a is Z column a-1.
        for(int a = 1; a < dim; ++a){
            const auto za = data->Z(row, a - 1);
            XXT(a, 0) += za * c;
            for(int b = 1; b <= a; ++b)
                XXT(a, b) += za * data->Z(row, b - 1);
            XY[a] += za * yr;
        }
    }
}

// Gaussian elimination
// Solves XXT * beta = XY for the p+1 coefficients built by compute().
// compute() fills only the lower triangle, so XXT is first mirrored into a
// full symmetric matrix. Elimination uses no pivoting (the Gram matrix is
// symmetric positive semi-definite). A singular system -- e.g. fewer selected
// rows than p+1 or collinear predictors -- divides by a zero pivot
// (presumably inf/NaN results, assuming T is a floating-point type).
// XXT and XY are overwritten.
vector<T> SRHT_algorithm::solve(){
    const int dim = p + 1;
    vector<T> estimated_beta(dim);

    // Mirror lower triangle into upper: XXT(r, c) = XXT(c, r) for c >= r.
    for(int r = 0; r < dim; ++r)
        for(int c = r; c < dim; ++c)
            XXT(r, c) = XXT(c, r);

    // Forward elimination. The loops address entry (row, col) as XXT(col, row);
    // because XXT is symmetric after mirroring this is ordinary row elimination.
    for(int pivot = 0; pivot < dim - 1; ++pivot){
        const T pivot_value = XXT(pivot, pivot);   // unchanged inside this loop
        for(int r = pivot + 1; r < dim; ++r){
            const T factor = XXT(pivot, r) / pivot_value;
            for(int c = pivot; c < dim; ++c)
                XXT(c, r) -= factor * XXT(c, pivot);
            XY[r] -= factor * XY[pivot];
        }
    }

    // Back substitution on the resulting upper-triangular system.
    for(int r = p; r >= 0; --r){
        T acc = XY[r];
        for(int c = r + 1; c < dim; ++c)
            acc -= XXT(c, r) * estimated_beta[c];
        estimated_beta[r] = acc / XXT(r, r);
    }
    return estimated_beta;
}
