#include "lmo.impl.h"
#include <scip/scip.h>
#include <scip/scipdefplugins.h>

#include <chrono>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <limits>
#include <random>
#include <string>
#include <tuple>
#include <type_traits>
#include <vector>
#include <omp.h>

using namespace std;
using namespace std::chrono;

/**
 * Approximate equality: true when |x - y| does not exceed the tolerance.
 *
 * @param x    first value
 * @param y    second value
 * @param eps  absolute tolerance (defaults to 1e-6 converted to `real`)
 * @return     whether the absolute difference of x and y is at most eps
 */
template<typename real>
inline bool epseq(const real x, const real y, const real eps=1e-6){
  const auto distance = fabs(x - y);
  return distance <= eps;
}

/**
 * Run a SCIP call and abort the enclosing function if it fails.
 *
 * The argument is evaluated exactly once. When its result differs from
 * SCIP_OKAY the error is reported through SCIPprintError and the enclosing
 * function returns -1, so the macro is only usable inside functions whose
 * return type accepts -1.
 */
#define SCIP_CALL_ERROR(x)                                   \
   do                                                        \
   {                                                         \
      const SCIP_RETCODE scip_call_status_ = (x);            \
      if( scip_call_status_ == SCIP_OKAY )                   \
         break;                                              \
      SCIPprintError(scip_call_status_);                     \
      return -1;                                             \
   }                                                         \
   while( FALSE )

// Make std::vector available under its unqualified name.
using std::vector;
// Unsigned 64-bit integer alias (not referenced in the visible part of this file).
using ix_t = uint64_t;
// Floating-point type used in main() for objective values and direction entries.
using val_t = double;

// Eigen vector alias parameterised on its scalar type.
// Note: the template parameter is also called `val_t` and shadows the
// file-level `val_t` alias inside this declaration, so vec_t<T> stores T
// (e.g. vec_t<SCIP_Longint> in main()), not necessarily double.
template<typename val_t>
using vec_t = Eigen::VectorX<val_t>;

/**
 * Draw one uniformly distributed value between `from` and `to`.
 *
 * Integral `Numeric` types use std::uniform_int_distribution, every other
 * type uses std::uniform_real_distribution. The engine (seeded once from
 * std::random_device) and the distribution object are thread-local and
 * reused across calls; the bounds are supplied per call.
 *
 * @tparam Numeric    arithmetic type of the bounds and of the result
 * @tparam Generator  random engine type (std::mt19937 by default)
 * @param from        lower bound
 * @param to          upper bound
 * @return            a random value of type Numeric
 */
template<typename Numeric, typename Generator = std::mt19937>
Numeric random(Numeric from, Numeric to)
{
    typedef typename std::conditional<
        std::is_integral<Numeric>::value,
        std::uniform_int_distribution<Numeric>,
        std::uniform_real_distribution<Numeric>
    >::type sampler_t;
    typedef typename sampler_t::param_type bounds_t;

    // One engine and one distribution object per thread (and per instantiation).
    thread_local static Generator engine(std::random_device{}());
    thread_local static sampler_t sampler;

    // The range is passed at call time, so the shared distribution needs no reset.
    const bounds_t bounds{from, to};
    return sampler(engine, bounds);
}

int main(int argc, char * argv[])
{
    SCIP* scip = NULL;
    SCIP_CALL_ERROR( SCIPcreate(&scip) );
    
    int n = 8;
    SCIP_Longint capacity;
    vec_t<SCIP_Longint> weights(n);
    vec_t<val_t> direction(n);

    for (int i=0; i<n; i++){
        weights[i] = random<int>(0, 1000);
        direction[i] = random<double>(-1, 1);
        capacity = random<int>(10000, 100000);
    }

    // weights << 988, 891, 818, 489, 435, 289, 233, 144, 120, 94, 77;
    // direction << 0.12, 0.17, -0.04, 0.02, -0.07, -0.01, 0.01, -0.00, 0.01, 0.01, -0.00;
    // capacity = 4000;

    // std::cout << "weights: ";
    // for (auto c: weights) std::cout << c << " ";
    // std::cout << std::endl;

    // std::cout << "direction: ";
    // for (auto c: direction) std::cout << c << " ";
    // std::cout << std::endl;

    // std::cout << "capacity: " << capacity << std::endl;


    std::vector<std::tuple<std::string, oracle_t>> oracles{
                                             {"CPU_ENUMERATION", oracle_t::ENUMERATION_CPU}, 
                                             {"DYNAMIC_PROGRAMMING", oracle_t::REF_ORACLE},
                                             {"ISPC_ENUMERATION", oracle_t::ISPC_ORACLE},
                                             {"HYBRID_DYNPROG_ISPC", oracle_t::CPU_HYBRID}
                                            };

    std::vector<double> results;

    std::cout << "Number of variables: " << n << std::endl;

    for(auto oracle: oracles){
        const auto oracle_name = std::get<0>(oracle);
        const auto oracle_type = std::get<1>(oracle);
        
        val_t obj;
        vec_t<val_t> vertex(n);

        auto start = high_resolution_clock::now();
        call_knapsack_oracle(scip, n, weights, capacity, direction, obj, vertex, 8, oracle_type);
        auto stop = high_resolution_clock::now();
        auto duration = duration_cast<nanoseconds>(stop - start);
        std::cout << oracle_name << ", optimum: "<< obj << ", runtime (nanoseconds): " << duration.count() << std::endl;
        results.push_back(obj);
    }

    const auto first = results[0];
    std::string status = "SUCCESS";
    for (const auto& val: results){
        if(!epseq(first, val)) status = "FAILED";
    }
    std::cout << "TEST " << status << std::endl;



}