#include "contextsearch.h"

#include <stdexcept>

using namespace std;

// Build a searcher over n variables.
//   n_   - number of variables (numbered 1..n).
//   con  - bitmask stored as con_vars.
//   intv - bitmask of intervention variables (stored as intv_vars).
//   verb - print progress messages to Rcout when true.
// The search counters themselves are reset in set_options().
contextsearch::contextsearch(const int& n_, const int& con, const int& intv, const bool& verb)
    : n(n_),
      con_vars(con),
      intv_vars(intv),
      verbose(verb) {
}

// Nothing to release explicitly: the graph pointer handed to set_graph() is not owned.
contextsearch::~contextsearch() = default;

void contextsearch::set_target(const int& a, const int& b) {
    int con = (a | b) & ~intv_vars;
    int intv = b & intv_vars;
    int i_set;
    vector<int> elems;
    int e = 0;
    for ( int i = 1; i <= n; i++ ) {
        i_set = unary(i);
        if ( (i_set & con) == i_set ) {
            elems.push_back(i_set);
            e++;
        }
    }
    for ( int i = 0; i <= full_set(e); i++ ) {
        p target;
        target.a = a;
        target.b = b;
        target.c = 0;
        target.d = intv;
        for ( int j = 1; j <= e; j++ ) {
            if ( (unary(j) & i) > 0) target.d += elems[j-1];
            else target.c += elems[j-1];
        }
        if ( verbose ) Rcpp::Rcout << "Adding target: " << to_string(target) << endl;
        found_targets[to_string(target)] = 0;
        targets.push_back(target);
    }
}

void contextsearch::set_options() {
    // Fresh search state: no distribution stored, no left-hand-side variable known yet.
    index = 0;
    lhs = 0;
    trivial_id = false;
    format_do = true;

    // Order in which rules are tried for every expanded distribution, and the
    // matching accumulated time per rule in milliseconds (one slot per entry of `rules`).
    rules = {0, 5, -5, -4, 1, -1, 2, -2, 6, 7, -7, 3, -3};
    rule_times = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
}

void contextsearch::set_labels(const Rcpp::StringVector& lab) {
    // labels[k] is the display name of variable k + 1 (variables are numbered from 1).
    labels.assign(n, string());
    for ( int k = 0; k < n; ++k ) {
        labels[k] = lab(k);
    }
}

// Remember the graph used for the CSI checks in search(); the pointer is only
// stored, never deleted by this class.
void contextsearch::set_graph(ldag* g_) {
    this->g = g_;
}

/*
void contextsearch::set_derivation(derivation* d_) {
    deriv = d_;
}
*/

void contextsearch::add_known(const int& a, const int& b) {
    int con = (a | b) & ~intv_vars;
    int i_set;
    vector<int> elems;
    int e = 0;
    for ( int i = 1; i <= n; i++ ) {
        i_set = unary(i);
        if ( (i_set & con) == i_set ) {
            elems.push_back(i_set);
            e++;
        }
    }
    for ( int i = 0; i <= full_set(e); i++ ) {
        index++;
        p pp;
        pp.a = a; pp.b = b; pp.c = 0; pp.d = 0;
        for ( int j = 1; j <= e; j++ ) {
            if ( (unary(j) & i) > 0) pp.d += elems[j-1];
            else pp.c += elems[j-1];
        }
        distr iquery;
        iquery.rule_num = 0;
        iquery.pp = pp;
        iquery.pa1 = 0;
        iquery.pa2 = 0;
        iquery.index = index;
        iquery.primitive = true;
        iquery.score = 0;
        L[index] = iquery;
        ps[make_key(pp)] = index;
        if ( verbose ) Rcpp::Rcout << "Adding known distribution: " << to_string(pp) << endl;
        any_target(pp, index);
        if ( all_targets() ) {
            trivial_id = true;
        }
        lhs = lhs | a;
    }
}

// Record `ix` as the derivation index of every target equal to `pp`.
// Targets are bound by const reference to avoid copying each one per call.
void contextsearch::any_target(const p& pp, const int& ix) {
    for ( const auto& target : targets ) {
        if ( equal_p(pp, target) ) {
            found_targets[to_string(target)] = ix;
        }
    }
}

// True when every registered target has been matched, i.e. no entry of
// found_targets still holds the "not derived yet" marker 0.
// Entries are bound by const reference: the previous by-value loop copied the
// key string of every entry, and this runs once per derived distribution.
bool contextsearch::all_targets() {
    for ( const auto& entry : found_targets ) {
        if ( entry.second == 0 ) return false;
    }
    return true;
}

// Run the search for the registered targets and report the outcome to R.
// Returns a list with
//   identifiable - whether every target was derived,
//   time         - search time in milliseconds (always a double),
//   rule_times   - accumulated time per rule.
Rcpp::List contextsearch::search_init() {

    // Reset the scratch record describing the most recent rule application.
    info.to.a   = 0; info.to.b   = 0; info.to.c   = 0; info.to.d   = 0;
    info.from.a = 0; info.from.b = 0; info.from.c = 0; info.from.d = 0;
    info.rp.a   = 0; info.rp.b   = 0; info.rp.c   = 0; info.rp.d   = 0;
    info.ri.xset = 0; info.ri.yset = 0; info.ri.zset = 0; info.ri.u = 0; info.ri.v = 0;
    info.valid = false; info.enumerate = false;

//    unsigned ntarget = targets.size();
//    string formula;
//    string derivation;
//    string full_derivation = "";

    // Quick rejection. All targets share the same left-hand side (set_target only varies
    // the value assignment), and the rules in apply_rule only move variables from `lhs`
    // (the union of the known left-hand sides) onto a left-hand side. If `lhs` does not
    // cover the target's left-hand side, or no target was registered at all (so targets[0]
    // must not be read), the query is reported as not identifiable.
    const bool lhs_covers_target = !targets.empty() && (lhs & targets[0].a) == targets[0].a;
    if ( !lhs_covers_target ) {
        return Rcpp::List::create(
            Rcpp::Named("identifiable") = false,
//          Rcpp::Named("formula") = formula,
//          Rcpp::Named("derivation") = derivation,
//          Rcpp::Named("full_derivation") = full_derivation,
            Rcpp::Named("time") = 0.0,
            Rcpp::Named("rule_times") = rule_times
        );
    }

    z_sets = get_subsets(n);
    std::chrono::duration<double, std::milli> total_time;

    auto t1 = std::chrono::high_resolution_clock::now();
    auto t2 = std::chrono::high_resolution_clock::now();

    // When the known distributions already contain every target, no search is needed.
    if ( !trivial_id ) {
        t1 = std::chrono::high_resolution_clock::now();
        search();
        t2 = std::chrono::high_resolution_clock::now();
    }
    total_time = t2 - t1;

/*
    for ( unsigned int i = 0; i < ntarget; i++ ) {
        if ( formula ) {
            formulas[i] = derive_formula(L[found_targets[i]]);
        }
        if ( draw_derivation && !draw_all ) {
            derivation temp_deriv;
            temp_deriv.init();
            draw(target_dist[i], TRUE, temp_deriv);
            temp_deriv.finish();
            derivations[i] = temp_deriv.get();
        }
    }

    if ( draw_derivation && draw_all ) {
        deriv->init();
        for ( const auto &d : L ) {
            draw(d.second, FALSE, *deriv);
        }
        for ( auto &target : target_dist ) {
            draw(target, FALSE, *deriv);
        }
        deriv->finish();
        full_derivation = deriv->get();
    }
*/

    return Rcpp::List::create(
        Rcpp::Named("identifiable") = all_targets(),
//        Rcpp::Named("formula") = formulas,
//        Rcpp::Named("derivation") = derivations,
//        Rcpp::Named("full_derivation") = full_derivation,
        Rcpp::Named("time") = total_time.count(),
        Rcpp::Named("rule_times") = rule_times
    );

}

// Base search order: expand stored distributions by index (1, 2, 3, ...).
// at() is used instead of operator[] so that an out-of-range index throws
// std::out_of_range rather than silently inserting an uninitialised entry.
distr& contextsearch::next_distribution(const int& i) {
    return L.at(i);
}


void contextsearch::search() {

    distr required;
    string con_key;
    bool found = false;
    bool primi = true;
    unsigned int i = 1;
    unsigned int z_lim;
    unsigned int z_size = z_sets.size();
    int a, b, c, d, z, u, v, iv, cd, req, ruleid, exist;
    int remaining = L.size();
    chrono::high_resolution_clock::time_point t1, t2, t3;
    chrono::high_resolution_clock::time_point start;
    chrono::duration<double, std::milli> ms;
    chrono::duration<double, std::ratio<3600>> total;

    start = chrono::high_resolution_clock::now();

    while ( remaining > 0 && !found ) {

        distr& iquery = next_distribution(i);
        remaining--;

        a = iquery.pp.a;
        b = iquery.pp.b;
        c = iquery.pp.c;
        d = iquery.pp.d;
        primi = iquery.primitive;

        for ( unsigned int r = 0; r < rules.size(); r++ ) {

            t1 = chrono::high_resolution_clock::now();

            ruleid = rules[r];

            if ( !valid_rule(ruleid, a, b, c, d, primi) ) continue;

            if ( ruleid == 0 || (ruleid * ruleid) > 25 ) z_lim = n;
            else z_lim = z_size;

            for ( unsigned int z_ind = 0; z_ind < z_lim; z_ind++ ) {

                t3 = chrono::high_resolution_clock::now();
                total = t3 - start;
                if ( total.count() > 0.5 ) return;

                required.primitive = true;
                z = z_sets[z_ind];

                apply_rule(ruleid, a, b, c, d, z, z_ind);
                if ( (info.to.c & (info.to.a | info.to.b)) != info.to.c )  Rcpp::Rcout << "Invalid distribution derived by rule: " << ruleid << endl;
                if ( (info.to.d & (info.to.a | info.to.b)) != info.to.d )  Rcpp::Rcout << "Invalid distribution derived by rule: " << ruleid << endl;

                if ( !info.valid ) continue;

                if ( info.enumerate ) {
                    enumerate_candidates();
                    cd = candidates.size();
                    while ( cd > 0 && !found ) {
                        cd--;
                        get_candidate(required, candidates.top());
                        candidates.pop();
                        info.to.c = required.pp.c;
                        info.to.d = required.pp.d;
                        exist = ps[make_key(info.to)];
                        if ( exist == 0 ) derive_distribution(iquery, required, ruleid, remaining, found);
                    }
                } else {
                    exist = ps[make_key(info.to)];
                    if ( exist > 0 ) continue;
                    if ( info.ri.xset > 0 ) {
                        u = info.ri.u & con_vars;
                        v = info.ri.v & con_vars;
                        iv = info.ri.v & intv_vars;
                        if ( !g->csi_criterion(info.ri.xset, info.ri.yset, info.ri.zset, u, v, iv, u | v) ) continue;
                    }
                    if ( info.rp.a > 0 ) {
                        req = ps[make_key(info.rp)];
                        if ( req > 0 ) {
                            get_candidate(required, req);
                            derive_distribution(iquery, required, ruleid, remaining, found);
                        }
                    } else {
                        derive_distribution(iquery, required, ruleid, remaining, found);
                    }
                }

                if ( found ) break;

            } // for z

            t2 = chrono::high_resolution_clock::now();
            ms = t2 - t1;
            rule_times[r] += ms.count();

            if ( found ) break;

        } // for ruleid

        i++;

        /*
        if (i > imax) {
            found = false;
            Rcpp::Rcout << "Breaking the infinite loop!" << endl;
            break;
        } */

    } // while

}

// Store info.to as a newly derived distribution under a fresh index.
//   iquery    - the distribution the rule was applied to (first parent).
//   required  - the second premise; only used when info.rp.a > 0.
//   ruleid    - the rule that produced info.to.
//   remaining - incremented when the new distribution still has to be expanded.
//   found     - set to true once every target has been derived.
void contextsearch::derive_distribution(const distr& iquery, const distr& required, const int& ruleid, int& remaining, bool& found) {
    index++;
    distr nquery;
    nquery.pp = info.to;
    nquery.primitive = is_primitive(iquery.primitive, required.primitive, ruleid);
    nquery.pa1 = iquery.index;
    // Second parent only for rules that name a required distribution.
    nquery.pa2 = ( info.rp.a > 0 ) ? required.index : 0;
    nquery.rule_num = ruleid;
    nquery.index = index;
    // Give the score a defined value (it was previously left uninitialised and later
    // copied by get_candidate); the heuristic subclass overwrites it in add_distribution().
    nquery.score = 0;

    any_target(nquery.pp, index);
    add_distribution(nquery);

    if ( all_targets() ) {
        if ( verbose ) {
            Rcpp::Rcout << "!!!! Managed to hit the target !!!!" << endl;
            Rcpp::Rcout << "index = " << index << endl;
        }
        found = true;
    } else {
        if ( verbose ) {
            if ( info.rp.a > 0 ) Rcpp::Rcout << "Derived: " << to_string(info.to) << " from " << to_string(info.from) << " and " << to_string(required.pp) << " using rule: " << std::to_string(ruleid) << endl;
            else Rcpp::Rcout << "Derived: " << to_string(info.to) << " from " << to_string(info.from) << " using rule: " << std::to_string(ruleid) << endl;
        }
        remaining++;
    }
}

// Store `nquery` under the current index and make it findable by its key.
void contextsearch::add_distribution(distr& nquery) {
    const string key = make_key(nquery.pp);
    L[index] = nquery;
    ps[key] = index;
}

void contextsearch::enumerate_candidates() {
    int acon = info.rp.a - (info.rp.a & info.from.a);
    int exist = ps[make_key(info.rp)];
    if ( exist > 0 ) {
        candidates.push(exist);
    }
    if ( acon > 0 ) {
        int u_inc, v_inc, i_set;
        p rq;
        rq.a = info.rp.a;
        rq.b = info.rp.b;
        vector<int> elems;
        vector<int> total;
        int e = 0;
        for ( int i = 1; i <= n; i++ ) {
            i_set = unary(i);
            if ( (i_set & acon) == i_set ) {
                elems.push_back(i_set);
                e++;
            }
        }
        for ( int i = 0; i <= full_set(e); i++ ) {
            v_inc = 0;
            u_inc = 0;
            for ( int j = 1; j <= e; j++ ) {
                if ( (unary(j) & i) > 0) v_inc += elems[j-1];
                else u_inc += elems[j-1];
            }
            rq.c = info.rp.c + u_inc;
            rq.d = info.rp.d + v_inc;
            exist = ps[make_key(rq)];
            if ( exist > 0 ) {
                candidates.push(exist);
            }
        }
    }
}

// Copy the stored distribution with index `req` into `required`, field by field,
// recording `req` as its index.
void contextsearch::get_candidate(distr& required, const int& req) {
    const distr& source = L[req];
    required.index = req;
    required.pp = source.pp;
    required.rule_num = source.rule_num;
    required.pa1 = source.pa1;
    required.pa2 = source.pa2;
    required.primitive = source.primitive;
    required.score = source.score;
}

// Whether a derived distribution counts as primitive. Currently it never does,
// regardless of its parents or rule. A disabled variant treated it as primitive when
// both parents were primitive and ruleid was not +-1 or +-4.
bool contextsearch::is_primitive(const bool& pa1_primitive, const bool& pa2_primitive, const int& ruleid) {
    static_cast<void>(pa1_primitive);
    static_cast<void>(pa2_primitive);
    static_cast<void>(ruleid);
    return false;
}

/*
string contextsearch::derive_formula(distr& dist) {
    string formula = "";
    if ( dist.pa1 > 0 ) {
        int r = dist.rule_num * dist.rule_num;
        distr& pa1 = L[dist.pa1];
        string paf1 = derive_formula(pa1);
        if ( dist.pa2 > 0 ) {
            distr& pa2 = L[dist.pa2];
            string paf2 = derive_formula(pa2);
            if ( dist.primitive ) formula = to_string(dist.pp);
            else {
                if ( r == 4 ) {
                    if ( paf1.length() < paf2.length() ) formula = paf1 + "*" + paf2;
                    else formula = paf2 + "*" + paf1;
                } else if ( r == 25 ) {
                    formula = paf1 + " /\\ " + paf2;
                } else if ( r == 36 ) {
                    formula = "[" + paf2 + " - " + paf1 + "]";
                } else if ( r == 49 ) {
                    formula = "[" + paf1 + " - " + paf2 + "]";
                }
            }
        } else {
            if ( r == 9 || r == 16 ) {
                formula = paf1;
            } else {
                if ( dist.primitive ) formula = to_string(dist.pp);
                else {
                    if ( r == 0 ) {
                        formula =  "[sum_{" + dec_to_text(pa1.pp.a - dist.pp.a, 0, 0) + "}" + paf1 + "]";
                    } else if ( r == 1 ) {
                        formula = "[[" + paf1 + "]/[sum_{" + dec_to_text(dist.pp.a, 0, 0) + "} " + paf1 + "]]";
                    } else if ( r == 64 ) {
                        formula = paf1;
                    }
                }
            }
        }
    } else {
        formula = to_string(dist.pp);
    }
    return formula;
}



string contextsearch::rule_name(const int& rule_num) const {

    switch ( rule_num ) {

        case 0  : return "M";
        case 1  : return "C";
        case -1  : return "C";
        case 2  : return "C";
        case -2  : return "C";
        case 3  : return "P";
        case -3 : return "P";
        case -4 : return "I-";
        case 5  : return "I+0";
        case -5 : return "I+1";

    }

    return "";
}

void contextsearch::draw(const distr& dist, const bool& recursive, derivation& d) {
    if ( dist.pa1 > 0 ) {
        distr& pa1 = L[dist.pa1];
        d.add_edge(to_string(pa1.pp), to_string(dist.pp), rule_name(dist.rule_num));
        if ( recursive ) draw(pa1, recursive, d);
        if ( dist.pa2 > 0 ) {
            distr& pa2 = L[dist.pa2];
            d.add_edge(to_string(pa2.pp), to_string(dist.pp), rule_name(dist.rule_num));
            if ( recursive ) draw(pa2, recursive, d);
        }
    }
}
*/

// Render the variables in bitmask `dec` as a comma-separated list of labels,
// in variable order 1..n. A variable also in `zero` gets " = 0" appended;
// otherwise, if it is in `one`, it gets " = 1". Returns "" for an empty set.
string contextsearch::dec_to_text(const int& dec, const int& zero, const int& one) const {
    if ( dec == 0 ) return "";
    string out;
    bool first = true;
    for ( int v = 1; v <= n; v++ ) {
        if ( !in_set(v, dec) ) continue;
        if ( !first ) out += ",";
        first = false;
        out += labels[v-1];
        if ( in_set(v, zero) ) out += " = 0";
        else if ( in_set(v, one) ) out += " = 1";
    }
    return out;
}

// Lookup key for `ps`: the four bitmask fields as decimal numbers, "a,b,c,d".
string contextsearch::make_key(const p& pp) const {
    string key = std::to_string(pp.a);
    key += ',';
    key += std::to_string(pp.b);
    key += ',';
    key += std::to_string(pp.c);
    key += ',';
    key += std::to_string(pp.d);
    return key;
}

// Human-readable form "p(A|B)", where c-set variables are shown as "= 0"
// and d-set variables as "= 1". The "|B" part is omitted when b is empty.
string contextsearch::to_string(const p& pp) const {
    string text = "p(";
    text += dec_to_text(pp.a, pp.a & pp.c, pp.a & pp.d);
    if ( pp.b != 0 ) {
        text += "|";
        text += dec_to_text(pp.b, pp.b & pp.c, pp.b & pp.d);
    }
    text += ")";
    return text;
}

// Two distributions are equal when all four bitmask fields coincide.
bool contextsearch::equal_p(const p& pp1, const p& pp2) const {
    if ( pp1.a != pp2.a || pp1.b != pp2.b ) return false;
    return pp1.c == pp2.c && pp1.d == pp2.d;
}

// Cheap pre-filter deciding whether rule `ruleid` is worth trying on p(a | b) at all.
// Only a and b are inspected; c, d and primi are currently unused.
bool contextsearch::valid_rule(const int& ruleid, const int& a, const int& b, const int& c, const int& d, const bool& primi) const {

    switch ( ruleid ) {
        // Marginalisation (0) and conditioning (1, -1): reject a single-variable left-hand side.
        case 0 :
        case 1 :
        case -1 :
            return set_size(a) != 1;

        // Conditioning (-2): reject exactly one conditioning variable.
        case -2 :
            return set_size(b) != 1;

        // Product rule (3) and deletion of observations (-4): need conditioning variables.
        case 3 :
        case -4 :
            return b != 0;

        default :
            break;
    }

    return true;
}

// Check whether rule `ruleid` may be applied to p(a | b) (value sets c, d) with subset z.
// On success sets info.valid = true and fills the rest of `info` via get_ruleinfo();
// otherwise only info.valid = false is written. Rule ids without a condition below
// are always accepted. z_ind is currently unused.
void contextsearch::apply_rule(const int &ruleid, const int &a, const int &b, const int &c, const int &d, const int &z, const int &z_ind) {

    info.valid = false;

    const bool z_within_a   = (z & a) == z;
    const bool z_within_b   = (z & b) == z;
    const bool z_outside_a  = (z & a) == 0;
    const bool z_outside_b  = (z & b) == 0;
    const bool z_within_lhs = (z & lhs) == z;

    bool admissible = true;

    switch ( ruleid ) {
        // Marginalisation: z a proper subset of the left-hand side.
        case 0 :
            admissible = z_within_a && z != a;
            break;

        // Conditioning (numerator): proper subset of the left-hand side, known as a left-hand side.
        case 1 :
        case -1 :
            admissible = z_within_a && z != a && z_within_lhs;
            break;

        // Conditioning (denominator, 2), product rule (-3) and by-case reasoning on the
        // left-hand side (7, -7): z is new to the distribution and known as a left-hand side.
        case 2 :
        case -3 :
        case 7 :
        case -7 :
            admissible = z_outside_a && z_outside_b && z_within_lhs;
            break;

        // Conditioning (denominator, -2) and product rule (3): z taken from the conditioning set.
        case -2 :
        case 3 :
            admissible = z_within_b && z_within_lhs;
            break;

        // Deletion of observations: z is part of the conditioning set only.
        case -4 :
            admissible = z_outside_a && z_within_b;
            break;

        // Insertion of observations (Z = 0): z new and free of intervention variables.
        case 5 :
            admissible = z_outside_a && z_outside_b && (z & intv_vars) == 0;
            break;

        // Insertion of observations (Z = 1): z new to the distribution.
        case -5 :
            admissible = z_outside_a && z_outside_b;
            break;

        // General-by-case reasoning (RHS): z on the left-hand side, fully in c or fully in d.
        case 6 :
            admissible = z_within_a && ( (z & c) == z || (z & d) == z ) && z_within_lhs;
            break;

        default :
            break;
    }

    if ( !admissible ) return;

    info.valid = true;
    get_ruleinfo(ruleid, a, b, c, d, z);

}


// Fill `info` with the result of applying rule `ruleid` to the source distribution
// p(y | x), whose variables in u carry value 0 and whose variables in v carry value 1
// (to_string() prints the c-set as "= 0" and the d-set as "= 1").
//
//   info.from      - the source distribution (y, x, u, v); written for every ruleid.
//   info.to        - the distribution the rule derives.
//   info.rp        - the second premise the rule needs; info.rp.a == 0 means "none".
//   info.ri        - sets for the context-specific independence check that search()
//                    runs through g->csi_criterion() whenever info.ri.xset > 0.
//   info.enumerate - true when search() should try info.rp and its variants with extra
//                    value assignments (enumerate_candidates()) instead of info.rp only.
//
// For a ruleid not listed below only info.from is written. Set union and difference are
// written with `+` and `-`; NOTE(review): this relies on the disjointness/containment
// that apply_rule() checks before calling here -- confirm when adding a rule.
void contextsearch::get_ruleinfo(const int& ruleid, const int& y, const int& x, const int& u, const int& v, const int& z) {

    info.from.a = y;     info.from.b = x; info.from.c = u;     info.from.d = v;

    switch ( ruleid ) {

        // Marginalisation
        // to: p(y - z | x) with z's value assignments dropped.
        // rp: p(y | x) with the values of z flipped (z's c-part moved to d and vice versa).
        case 0 : {

            info.to.a = y - z; info.to.b = x; info.to.c = u - (u & z);           info.to.d = v - (v & z);
            info.rp.a = y;     info.rp.b = x; info.rp.c = u - (u & z) + (v & z); info.rp.d = v - (v & z) + (u & z);

            info.ri.xset = 0;
            info.enumerate = false;

            return;

        }

        // Conditioning (Numerator)
        // to: p(y - z | x + z), same value sets.
        // rp: p(z | x) with value sets restricted to z | x.
        case 1 : {

            info.to.a = y - z; info.to.b = x + z; info.to.c = u;             info.to.d = v;
            info.rp.a = z;     info.rp.b = x;     info.rp.c = ((z | x) & u); info.rp.d = ((z | x) & v);

            info.ri.xset = 0;
            info.enumerate = false;

            return;

        }

        // Conditioning (Numerator)
        // to: p(y - z | x) with z's value assignments dropped.
        // rp: p(z | x + (y - z)), same value sets.
        case -1 : {

            info.to.a = y - z; info.to.b = x;           info.to.c = u - (u & z); info.to.d = v - (v & z);
            info.rp.a = z;     info.rp.b = x + (y - z); info.rp.c = u;           info.rp.d = v;

            info.ri.xset = 0;
            info.enumerate = false;

            return;

        }

        // Conditioning (Denominator)
        // to: p(z | x + y); rp: p(y + z | x); both with the same value sets.
        // Candidate premises are enumerated (values for z come from the chosen premise).
        case 2 : {

            info.to.a = z;     info.to.b = x + y; info.to.c = u; info.to.d = v;
            info.rp.a = y + z; info.rp.b = x;     info.rp.c = u; info.rp.d = v;

            info.ri.xset = 0;
            info.enumerate = true;

            return;

        }

        // Conditioning (Denominator)
        // to: p(z | x - z) with y's value assignments dropped.
        // rp: p(y + z | x - z), same value sets.
        case -2 : {

            info.to.a = z;     info.to.b = x - z; info.to.c = u - (u & y); info.to.d = v - (v & y);
            info.rp.a = y + z; info.rp.b = x - z; info.rp.c = u;           info.rp.d = v;

            info.ri.xset = 0;
            info.enumerate = false;

            return;

        }

        // Product rule
        // to: p(y + z | x - z), same value sets.
        // rp: p(z | x - z) with value sets restricted to x.
        case 3 : {

            info.to.a = y + z; info.to.b = x - z; info.to.c = u;     info.to.d = v;
            info.rp.a = z;     info.rp.b = x - z; info.rp.c = u & x; info.rp.d = v & x;

            info.ri.xset = 0;
            info.enumerate = false;

            return;

        }

        // Product rule
        // to: p(y + z | x); rp: p(z | x + y); both with the same value sets.
        // Candidate premises are enumerated.
        case -3 : {

            info.to.a = y + z; info.to.b = x;     info.to.c = u; info.to.d = v;
            info.rp.a = z;     info.rp.b = x + y; info.rp.c = u; info.rp.d = v;

            info.ri.xset = 0;
            info.enumerate = true;

            return;

        }

        // Deletion of observations
        // to: p(y | x - z) with z's value assignments dropped; no second premise.
        // CSI check with xset = z, yset = y, zset = x - z and the remaining values on x.
        case -4 : {

            info.to.a = y; info.to.b = x - z; info.to.c = u - (u & z); info.to.d = v - (v & z);
            info.rp.a = 0;

            info.ri.yset = y; info.ri.xset = z; info.ri.zset = x - z;
            info.ri.u = (u - (u & z)) & x; info.ri.v = (v - (v & z)) & x;
            info.enumerate = false;

            return;

        }

        // Insertion of observations (Z = 0)
        // to: p(y | x + z) with z assigned value 0 (added to c); no second premise.
        // CSI check with xset = z, yset = y, zset = x.
        case 5 : {

            info.to.a = y; info.to.b = x + z; info.to.c = u + z; info.to.d = v;
            info.rp.a = 0;

            info.ri.yset = y; info.ri.xset = z; info.ri.zset = x;
            info.ri.u = u & x; info.ri.v = v & x;
            info.enumerate = false;

            return;

        }

        // Insertion of observations (Z = 1)
        // to: p(y | x + z) with z assigned value 1 (added to d); no second premise.
        // CSI check with xset = z, yset = y, zset = x.
        case -5 : {

            info.to.a = y; info.to.b = x + z; info.to.c = u; info.to.d = v + z;
            info.rp.a = 0;

            info.ri.yset = y; info.ri.xset = z; info.ri.zset = x;
            info.ri.u = u & x; info.ri.v = v & x;
            info.enumerate = false;

            return;

        }

        // General-by-case reasoning (RHS)
        // to: p(y | x) with the values of z flipped.
        // rp: p(y - z | x) with z's value assignments dropped.
        case 6 : {

            info.to.a = y;     info.to.b = x; info.to.c = u - (u & z) + (v & z); info.to.d = v - (v & z) + (u & z);
            info.rp.a = y - z; info.rp.b = x; info.rp.c = u - (u & z);           info.rp.d = v - (v & z);

            info.ri.xset = 0;
            info.enumerate = false;

            return;
        }

        // General-by-case reasoning (LHS, Z = 0)
        // to: p(y + z | x) with z in d (value 1); rp: the same with z in c (value 0).
        case 7 : {

            info.to.a = y + z; info.to.b = x; info.to.c = u;     info.to.d = v + z;
            info.rp.a = y + z; info.rp.b = x; info.rp.c = u + z; info.rp.d = v;

            info.ri.xset = 0;
            info.enumerate = false;

            return;
        }

        // General-by-case reasoning (LHS, Z = 1)
        // to: p(y + z | x) with z in c (value 0); rp: the same with z in d (value 1).
        case -7 : {

            info.to.a = y + z; info.to.b = x; info.to.c = u + z; info.to.d = v;
            info.rp.a = y + z; info.rp.b = x; info.rp.c = u;     info.rp.d = v + z;

            info.ri.xset = 0;
            info.enumerate = false;

            return;
        }

    }
}

// contextsearch_heuristic

// Same configuration as contextsearch; only the expansion order differs.
contextsearch_heuristic::contextsearch_heuristic(const int& n_, const int& con, const int& intv, const bool& verb)
    : contextsearch(n_, con, intv, verb) {
}

// Q holds non-owning pointers into L, so nothing needs to be freed here.
contextsearch_heuristic::~contextsearch_heuristic() = default;

// Heuristic search order: ignore the index `i` and expand the distribution at the top
// of the priority queue Q (NOTE(review): presumably ordered by `score`; the comparator
// is declared elsewhere). Calling top() on an empty queue is undefined behaviour, so
// an empty queue is reported as an error instead.
distr& contextsearch_heuristic::next_distribution(const int& i) {
    static_cast<void>(i);
    if ( Q.empty() ) {
        throw std::out_of_range("contextsearch_heuristic: no distribution left to expand");
    }
    distr& top = *Q.top();
    Q.pop();
    return top;
}

// Score the new distribution, store it under the current index, make it findable
// by key and enqueue a pointer to the stored copy.
void contextsearch_heuristic::add_distribution(distr& nquery) {
    nquery.score = compute_score(nquery.pp);
    distr& stored = L[index];
    stored = nquery;
    ps[make_key(nquery.pp)] = index;
    Q.push(&stored);
}

void contextsearch_heuristic::add_known(const int& a, const int& b) {
    int con = (a | b) & ~intv_vars;
    int i_set;
    vector<int> elems;
    int e = 0;
    for ( int i = 1; i <= n; i++ ) {
        i_set = unary(i);
        if ( (i_set & con) == i_set ) {
            elems.push_back(i_set);
            e++;
        }
    }
    for ( int i = 0; i <= full_set(e); i++ ) {
        index++;
        p pp;
        pp.a = a; pp.b = b; pp.c = 0; pp.d = 0;
        for ( int j = 1; j <= e; j++ ) {
            if ( (unary(j) & i) > 0) pp.d += elems[j-1];
            else pp.c += elems[j-1];
        }
        distr iquery;
        iquery.rule_num = 0;
        iquery.pp = pp;
        iquery.pa1 = 0;
        iquery.pa2 = 0;
        iquery.index = index;
        iquery.primitive = true;
        iquery.score = compute_score(pp);
        L[index] = iquery;
        ps[make_key(pp)] = index;
        Q.push(&L[index]);
        if ( verbose ) Rcpp::Rcout << "Adding known distribution: " << to_string(pp) << endl;
        any_target(pp, index);
        if ( all_targets() ) {
            trivial_id = true;
        }
        lhs = lhs | a;
    }
}

// Heuristic for search order
int contextsearch_heuristic::compute_score(const p& pp) const {
    int score = 0;
    int common_a = pp.a & targets[0].a;
    int common_b = pp.b & targets[0].b;

    score += 10 * set_size(common_a);
    score += 5 * set_size(common_b);

    score -= 2 * set_size(targets[0].a - common_a);
    score -= 2 * set_size(pp.b - common_b);
    score -= 2 * set_size(targets[0].b - common_b);

    return score;
}