#ifndef CSISEARCH_H
#define CSISEARCH_H

#include <unordered_map>
#include <vector>
#include <string>
#include <queue>
#include <chrono>
#include <stack>
#include <Rcpp.h>
#include "ldag.h"
#include "derivation.h"

using namespace std;

// Four-integer descriptor passed around the search as an (a, b, c, d)
// quadruple -- the same four values that set_target(), add_known() and
// valid_rule() take as separate arguments.
// NOTE(review): the fields are presumably bit-encoded variable sets (cf.
// dec_to_text); confirm their exact roles in the implementation file.
struct p {
    int a;
    int b;
    int c;
    int d;
};

// Five-integer record stored in output::ri.
// NOTE(review): the name and fields suggest an independence statement
// (xset independent of yset given zset, with extra context u, v); the same
// names appear as parameters of csisearch::get_ruleinfo -- TODO confirm.
struct rindep {
    int xset;
    int yset;
    int zset;
    int u;
    int v;
};

// Bookkeeping record; csisearch keeps one of these in its `info` member.
// NOTE(review): field roles are not visible from this header -- `to`/`from`
// presumably name the produced and source distributions, `rp` a required
// distribution and `ri` the independence statement used; confirm in the .cpp.
struct output {
    p to;
    p from;
    p rp;
    rindep ri;
    bool valid;
    bool enumerate;
};

// A distribution tracked by the search (stored e.g. in csisearch::L and
// csisearch::target_dist).
struct distr {
    int rule_num;    // presumably the rule that produced it (cf. rule_name) -- confirm
    int index;
    int score;       // ordering key compared by comp_distr
    int pa1;         // presumably index of the first parent distribution -- confirm
    int pa2;         // presumably index of the second parent distribution -- confirm
    bool primitive;  // cf. is_primitive(pa1_primitive, pa2_primitive, ruleid)
    p pp;            // the (a, b, c, d) descriptor of this distribution
};

// Orders distr pointers by ascending `score`.
// std::priority_queue treats its comparator as "less", so a queue using this
// functor keeps the distr with the LARGEST score at top(); ties come out in
// unspecified order. Both pointers are dereferenced unchecked and must be
// non-null.
struct comp_distr {
    // const-qualified so the comparator is usable through const objects
    // (e.g. by std::set's const lookups); it has no state to mutate.
    bool operator()(distr const * d1, distr const * d2) const {
        return d1->score < d2->score;
    }
};

class csisearch {
public:
    csisearch(const int& n_, const int& con, const int& intv, const bool& dd, const bool& da, const bool& dea, const bool& fa, const bool& verb);
    Rcpp::List search_init();
    int n, index, lhs, con_vars, intv_vars;
    bool trivial_id, format_do;
    const bool draw_derivation;
    const bool draw_all;
    const bool derive_all;
    const bool formula;
    const bool verbose;
    p target;
    ldag *g;
    derivation *deriv;
    vector<distr> target_dist;
    vector<string> labels;
    vector<int> z_sets;
    vector<int> rules;
    unordered_map<int, distr> L;
    unordered_map<string, int> ps;
    stack<int> candidates;
    output info;
    vector<double> rule_times;
    virtual void add_distribution(distr& nquery);
    virtual void add_known(const int& a, const int& b, const int& c, const int& d);
    virtual distr& next_distribution(const int& i);
    void search();
    void set_target(const int& a, const int& b, const int& c, const int& d);
    void set_options(const vector<int>& r);
    void set_labels(const Rcpp::StringVector& lab);
    void set_graph(ldag* g_);
    void set_derivation(derivation* d_);
    bool is_primitive(const bool& pa1_primitive, const bool& pa2_primitive, const int& ruleid);
    void draw(const distr& dist, const bool& recursive, derivation& d);
    string derive_formula(distr& dist);
    string dec_to_text(const int& dec, const int& zero, const int& one) const;
    string to_string(const p& pp) const;
    string rule_name(const int& rule_num) const;
    string make_key(const p& pp) const;
    bool equal_p(const p& p1, const p& p2) const;
    bool valid_rule(const int& ruleid, const int& a, const int& b, const int& c, const int& d, const bool& primi) const;
    void apply_rule(const int& ruleid, const int& a, const int& b, const int& c, const int& d, const int& z, const int& z_ind);
    void derive_distribution(const distr& iquery, const distr& required, const int& ruleid, int& remaining, bool& found);
    void get_ruleinfo(const int& ruleid, const int& y, const int& x, const int& u, const int& v, const int& z);
    void get_candidate(distr& required, const int& req);
    void enumerate_candidates();
    virtual ~csisearch();
};

class csisearch_heuristic: public csisearch {
public:
    csisearch_heuristic(const int& n_, const int& con, const int& intv, const bool& dd, const bool& da, const bool& dea, const bool& fa, const bool& verb);
    void add_distribution(distr& nquery);
    void add_known(const int& a, const int& b, const int& c, const int& d);
    distr& next_distribution(const int& i);
    ~csisearch_heuristic();
private:
    int compute_score(const p& pp) const;
    priority_queue<distr*, std::vector<distr*>, comp_distr> Q;
};

#endif	/* CSISEARCH_H */

