#include "lmo_type.h"
#include "lmo.h"
#include "gpu/enumeration.cuh"
#include <limits>

#include <scip/cons_knapsack.h>

#ifdef HYBRID_ORACLE
#include <ispc.h>
#endif
#include <iostream>

/** maximizes a linear function over a 0/1 knapsack by enumerating all 2^n item subsets on the CPU
 *
 *  Subset number s selects item i exactly if bit i of s is set. Among all subsets whose total weight does not
 *  exceed the capacity, the one with the largest objective value is chosen; ties are resolved in favor of the
 *  smallest subset number, because the enumeration is ascending and only strict improvements are accepted.
 *
 *  @param n          number of items; must satisfy 0 <= n <= 62 so that 2^n fits into a signed 64-bit index
 *  @param weights    item weights (the first n entries are read)
 *  @param capacity   knapsack capacity
 *  @param direction  objective coefficients (the first n entries are read)
 *  @param vertex     output: the first n entries are set to 1.0 for selected items and 0.0 otherwise
 *
 *  @return objective value of the best feasible subset
 *
 *  @note The running time is proportional to n * 2^n. At least one subset has to be feasible (checked by an
 *        assertion); the empty subset is feasible whenever capacity >= 0.
 */
double inner_call_oracle_enumeration_cpu(
   const int                  n,
   const vec_t<SCIP_Longint>& weights,
   const SCIP_Longint         capacity,
   const vec_t<double>&       direction,
   vec_t<double>&             vertex
   )
{
   // subsets are encoded in a signed 64-bit index, so 2^n itself has to be representable
   assert( n >= 0 && n < 63 );

   // use 64-bit literals for all shifts: "unsigned long" is only 32 bits wide on some platforms
   const long long num_sols = 1LL << n;

   // start below every attainable objective value so that the first feasible subset is always accepted
   double best_obj = std::numeric_limits<double>::lowest();
   long long best_sol_idx = -1;

   for (long long sol_idx = 0; sol_idx < num_sols; sol_idx++)
   {
      double obj = 0.0;
      SCIP_Longint weight = 0;

      // evaluate objective and weight of the subset encoded by sol_idx
      for (int item = 0; item < n; item++)
      {
         const int sol_val = static_cast<int>((sol_idx >> item) & 1LL);
         obj += sol_val * direction[item];
         weight += sol_val * weights[item];
      }

      // strict comparison: among equally good subsets the first one encountered is kept
      if ( weight <= capacity && obj > best_obj )
      {
         best_obj = obj;
         best_sol_idx = sol_idx;
      }
   }
   assert( best_sol_idx >= 0 );

   // reconstruct the best vertex from the bits of its index
   for (int item = 0; item < n; item++)
   {
      const int sol_val = static_cast<int>((best_sol_idx >> item) & 1LL);
      vertex[item] = sol_val;
   }

   return best_obj;
}

// uses 64bit
// double inner_call_oracle_ispc(
//     const int n,
//     const vec_t<SCIP_Longint> weights,
//     const SCIP_Longint capacity,
//     const vec_t<double> direction,
//     vec_t<double>& vertex
// ){
//    // AVX2 -> 4 x 64 bit
//    // Unfortunately, AVX512 does not support 64 bit integers
//    constexpr uint64_t GANG_SIZE = 8;
//    const long long num_vertices = 1ul << n;

//    // call the kernel
//    double max_obj[GANG_SIZE];
//    uint64_t max_ix[GANG_SIZE];

//    #pragma unroll
//    for(int i = 0; i < GANG_SIZE; ++i)
//    {
//       max_obj[i] = std::numeric_limits<double>::min();
//       max_ix[i] = std::numeric_limits<uint64_t>::max();
//    }

//    ispc::enumeration_cpu_kernel(
//       static_cast<uint64_t>(n),
//       static_cast<uint64_t>(num_vertices),
//       reinterpret_cast<const int64_t*>(weights.data()),
//       static_cast<int64_t>(capacity),
//       direction.data(),
//       max_obj,
//       max_ix);

//    // find the best solution and return its vertex
//    double best_obj = -std::numeric_limits<double>::min();

//    #pragma unroll
//    for(int i = 0; i < GANG_SIZE; ++i)
//    {
//       if(max_ix[i] != std::numeric_limits<uint64_t>::max() && max_obj[i] > best_obj)
//       {
//          best_obj = max_obj[i];
//          for(int j = 0; j < n; ++j)
//             vertex[j] = (max_ix[i] & (1ul << j)) > 0;
//       }
//    }

//    return best_obj;
// }

// uses 32bit
/** maximizes a linear function over a 0/1 knapsack by enumeration with the ISPC kernel
 *
 *  The data is narrowed to 32 bits before it is handed to the kernel: weights and capacity are cast to
 *  int32_t and the objective coefficients to float. Values outside the respective range are not detected
 *  here and change the problem that is actually solved.
 *
 *  @param n          number of items; must satisfy 0 <= n <= 31
 *  @param weights    item weights (the first n entries are read)
 *  @param capacity   knapsack capacity
 *  @param direction  objective coefficients (the first n entries are read)
 *  @param vertex     output: the first n entries are set to the 0/1 values of the best subset; left untouched
 *                    if the kernel reports no solution
 *
 *  @return objective value of the best subset, computed in single precision; if the kernel reports no
 *          solution, the lowest finite float value is returned
 *
 *  @throws std::runtime_error if n is out of range or if the code was compiled without HYBRID_ORACLE
 */
double inner_call_oracle_ispc(
   const int                  n,
   const vec_t<SCIP_Longint>& weights,
   const SCIP_Longint         capacity,
   const vec_t<double>&       direction,
   vec_t<double>&             vertex
   )
{
#ifdef HYBRID_ORACLE
   // AVX512 -> 16 x 32 bit
   // Unfortunately, AVX512 does not support 64 bit integers
   constexpr uint64_t GANG_SIZE = static_cast<uint64_t>(GANGSIZE);

   // The number of subsets 2^n is passed to the kernel as uint32_t and the largest uint32_t value marks
   // "no solution", so at most 31 items can be handled. Validate n before it is used as a shift count.
   constexpr int MAX_ITEMS = 31;
   constexpr int BUF_SIZE = 32;
   if ( n < 0 || n > MAX_ITEMS )
      throw std::runtime_error("n must be between 0 and 31.");

   const uint32_t num_vertices = static_cast<uint32_t>(1) << n;

   // result buffers of the kernel (presumably one entry per program instance of the gang)
   float max_obj[GANG_SIZE];
   uint32_t max_ix[GANG_SIZE];

   #pragma unroll
   for (int i = 0; i < (int) GANG_SIZE; ++i)
   {
      max_obj[i] = -std::numeric_limits<float>::max();
      max_ix[i] = std::numeric_limits<uint32_t>::max();
   }

   // type conversions
   int32_t _weights[BUF_SIZE];
   float _direction[BUF_SIZE];

   #pragma unroll
   for (int i = 0; i < n; ++i)
   {
      _weights[i] = static_cast<int32_t>(weights[i]);
      _direction[i] = static_cast<float>(direction[i]);
   }

   // call the kernel
   ispc::enumeration_cpu_kernel(
      static_cast<uint32_t>(n),
      num_vertices,
      _weights,
      static_cast<int32_t>(capacity),
      _direction,
      max_obj,
      max_ix);

   // find the best solution and return its vertex
   float best_obj = -std::numeric_limits<float>::max();

   #pragma unroll
   for (int i = 0; i < (int) GANG_SIZE; ++i)
   {
      // entries that still carry the marker value did not find any solution
      if ( max_ix[i] != std::numeric_limits<uint32_t>::max() && max_obj[i] > best_obj )
      {
         best_obj = max_obj[i];
         for (int j = 0; j < n; ++j)
            vertex[j] = ((max_ix[i] >> j) & 1u) != 0;
      }
   }

   return static_cast<double>(best_obj);
#else
   // do not return a fake objective value with a stale vertex if the ISPC kernel is unavailable
   throw std::runtime_error("ISPC not compiled.");
#endif
}

// reference implementation for maximizing a linear function given by direction over 0/1 knapsack defined by weights and capacity
/** solves the knapsack problem with SCIPsolveKnapsackExactly()
 *
 *  @param scip       SCIP data structure (must not be NULL); used for buffer memory and the solver call
 *  @param dim        number of items
 *  @param weights    item weights
 *  @param capacity   knapsack capacity
 *  @param direction  objective coefficients, passed to SCIP as item profits
 *  @param objval     output: solution value as reported by SCIP
 *  @param vertex     output: if SCIP reports success, the first dim entries are set to 1.0 for items in the
 *                    solution and 0.0 for all others; otherwise the vector is left untouched
 *
 *  @return SCIP_OKAY if all SCIP calls succeeded, otherwise the return code of the failing call
 *
 *  @note SCIP_OKAY is also returned if the solver sets its success flag to FALSE; in that case vertex is not
 *        updated.
 */
SCIP_RETCODE inner_call_oracle_ref(
   SCIP*                      scip,
   const int                  dim,
   const vec_t<SCIP_Longint>& weights,
   const SCIP_Longint         capacity,
   const vec_t<double>&       direction,
   double&                    objval,
   vec_t<double>&             vertex
   )
{
   SCIP_RETCODE retcode;
   SCIP_Bool success;
   int* solitems;
   int* nonsolitems;
   int* items;
   int nsolitems;
   int nnonsolitems;
   int i;

   assert( scip != NULL );

   SCIP_CALL( SCIPallocBufferArray(scip, &items, dim) );
   for (i = 0; i < dim; ++i)
      items[i] = i;
   SCIP_CALL( SCIPallocBufferArray(scip, &solitems, dim) );
   SCIP_CALL( SCIPallocBufferArray(scip, &nonsolitems, dim) );

   // Do not wrap the solver call into SCIP_CALL(): an early return would skip the release of the buffers.
   // The const_casts are required because the SCIP interface takes non-const pointers.
   retcode = SCIPsolveKnapsackExactly(scip, dim, const_cast<SCIP_Longint*>(weights.data()), const_cast<double*>(direction.data()), capacity,
         items, solitems, nonsolitems, &nsolitems, &nnonsolitems, &objval, &success);

   // success is only meaningful if the call itself did not fail
   if ( retcode == SCIP_OKAY && success )
   {
      for (i = 0; i < dim; ++i)
         vertex[i] = 0.0;
      for (i = 0; i < nsolitems; ++i)
         vertex[solitems[i]] = 1.0;
   }

   // buffer arrays are released in reverse order of their allocation
   SCIPfreeBufferArray(scip, &nonsolitems);
   SCIPfreeBufferArray(scip, &solitems);
   SCIPfreeBufferArray(scip, &items);

   // report a solver error in the usual SCIP way now that the buffers are released
   SCIP_CALL( retcode );

   return SCIP_OKAY;
}

// maximize linear function given by direction over 0/1 knapsack defined by weights and capacity
/** dispatches the knapsack maximization to the oracle selected by use_oracle
 *
 *  @param scip        SCIP data structure, only used by the reference oracle
 *  @param dim         number of items
 *  @param weights     item weights
 *  @param capacity    knapsack capacity
 *  @param direction   objective coefficients
 *  @param objval      output: objective value returned by the selected oracle
 *  @param vertex      output: solution vector written by the selected oracle
 *  @param threshold   for CPU_HYBRID: largest dim that is still handled by the ISPC oracle
 *  @param use_oracle  oracle to use; any value not handled below leaves objval and vertex untouched
 *
 *  @throws std::runtime_error if CUDA_ORACLE is requested but the code was compiled without GPU_ORACLE
 */
void call_knapsack_oracle(
   SCIP*                      scip,
   const int                  dim,
   const vec_t<SCIP_Longint>& weights,
   const SCIP_Longint         capacity,
   const vec_t<double>&       direction,
   double&                    objval,
   vec_t<double>&             vertex,
   const int                  threshold,
   oracle_t                   use_oracle
   )
{
   // dynamic programming from SCIP; the return code is only checked in debug builds
   const auto solve_with_reference = [&]()
   {
      const SCIP_RETCODE ref_status = inner_call_oracle_ref(scip, dim, weights, capacity, direction, objval, vertex);
      assert( ref_status == SCIP_OKAY );
      (void) ref_status;
   };

   // ISPC CPU
   const auto solve_with_ispc = [&]()
   {
      objval = inner_call_oracle_ispc(dim, weights, capacity, direction, vertex);
   };

   if ( use_oracle == REF_ORACLE )
   {
      solve_with_reference();
   }
   else if ( use_oracle == ISPC_ORACLE )
   {
      solve_with_ispc();
   }
   else if ( use_oracle == CUDA_ORACLE )
   {
      // GPU
#ifdef GPU_ORACLE
      inner_call_oracle_gpu(dim, weights.data(), int(capacity), direction.data(), &objval, vertex.data());
#else
      throw std::runtime_error("CUDA not compiled.");
#endif
   }
   else if ( use_oracle == ENUMERATION_CPU )
   {
      // CPU enumeration
      objval = inner_call_oracle_enumeration_cpu(dim, weights, capacity, direction, vertex);
   }
   else if ( use_oracle == CPU_HYBRID )
   {
      // Hybrid ISPC-DYNAMIC PROGRAMMING with a threshold: small instances are enumerated, larger ones
      // are handed to the reference oracle
      if ( dim > threshold )
         solve_with_reference();
      else
         solve_with_ispc();
   }
}
