
import java.io.*;
import java.util.*;

/**************************************************************************************************************************************
 * Class DecisionTree
 *****/

/**
 * A binary decision tree attached to a {@link Boost} instance.
 *
 * <p>The tree is created empty by the constructor, receives its root in {@link #init()} and is
 * expanded by {@link #grow_heavy_first()}. Once grown it maps an {@link Example} to the
 * prediction stored in the node the example reaches.
 */
class DecisionTree implements Debuggable{
    public static boolean DISPLAY_INTERNAL_NODES_CLASSIFICATION = false;

    int name, depth;

    DecisionTreeNode root;
    // root of the tree

    Vector leaves;
    // list of leaves of the tree (potential growth here)

    Boost myBoost;

    int max_size;

    int split_CV;

    int number_nodes;
    //includes leaves

    double tree_p_t, tree_psi_t, tree_p_t_star;

    /**
     * Creates an empty tree; {@link #init()} must be called before the tree is grown or queried.
     *
     * @param nn    identifier of the tree
     * @param bb    boosting instance the tree belongs to
     * @param maxs  node count at which {@link #grow_heavy_first()} stops splitting
     * @param split index of the cross-validation split whose training fold is used
     */
    DecisionTree(int nn, Boost bb, int maxs, int split){
        name = nn;
        root = null;
        myBoost = bb;

        max_size = maxs;

        split_CV = split;

        // -1.0 marks the statistics as "not computed yet"
        tree_p_t = tree_psi_t = tree_p_t_star = -1.0;
    }


    /**
     * Returns a textual description of the tree: a header line, the display of the root (when the
     * tree has one) and the names of the nodes currently stored in {@link #leaves}.
     *
     * <p>Safe to call before {@link #init()}: a missing root or leaf list is simply skipped.
     */
    @Override
    public String toString(){
        StringBuilder v = new StringBuilder();
        v.append("(name = #").append(name)
            .append(" | depth = ").append(depth)
            .append(" | #nodes = ").append(number_nodes)
            .append(")\n");

        if (root != null)
            v.append(root.display(new HashSet <Integer> ()));

        v.append("Leaves:");

        if (leaves != null){
            Iterator it = leaves.iterator();
            while(it.hasNext()){
                DecisionTreeNode leaf = (DecisionTreeNode) it.next();
                v.append(" #").append(leaf.name).append(leaf.observations_string());
            }
        }
        v.append(".\n");

        return v.toString();
    }

    /**
     * Inserts {@code nn} into {@code all_nodes}, keeping the list sorted by non-increasing number
     * of training examples held by each node ({@code train_fold_indexes_in_node.length}).
     * A node is placed before the first element that does not hold strictly more examples.
     *
     * @param all_nodes list already sorted by this method (may be empty)
     * @param nn        node to insert
     */
    public static void INSERT_WEIGHT_ORDER(Vector <DecisionTreeNode> all_nodes, DecisionTreeNode nn){
        int k = 0;
        // skip every node holding strictly more examples than nn
        while( (k < all_nodes.size()) && (nn.train_fold_indexes_in_node.length < all_nodes.elementAt(k).train_fold_indexes_in_node.length) )
            k++;
        // k == size() appends, which also covers the empty list
        all_nodes.insertElementAt(nn, k);
    }

    /**
     * Grows the tree by repeatedly splitting the heaviest leaf that admits a split.
     *
     * <p>Leaves are kept ordered by {@link #INSERT_WEIGHT_ORDER}. At each round the first leaf whose
     * {@code bestSplit()} is non-null is split into two children; leaves for which
     * {@code bestSplit()} returns null are discarded from the candidate list. Growth stops when no
     * candidate is left or when {@code number_nodes >= max_size}. The size test is made after the
     * split attempt, so one split is attempted even if the tree already meets {@code max_size}.
     *
     * <p>On return {@link #leaves} holds the remaining candidate leaves.
     * NOTE(review): leaves found unsplittable are therefore absent from {@link #leaves} afterwards;
     * confirm that this is intended ("potential growth" list) rather than an omission.
     */
    public void grow_heavy_first(){
        Vector <DecisionTreeNode> try_leaves = new Vector<>(); // leaves that will be tried to grow
        for (int j = 0;j < leaves.size();j++)
            DecisionTree.INSERT_WEIGHT_ORDER(try_leaves, (DecisionTreeNode) leaves.elementAt(j));

        boolean stop = false;
        while(!stop){
            DecisionTreeNode nn = null;
            Vector candidate_split = null;

            // Find the heaviest leaf with a valid split. The emptiness guard matters: the list can
            // be empty on entry (e.g. a previous call exhausted every leaf), in which case
            // elementAt(0) would throw.
            while(!try_leaves.isEmpty()){
                nn = try_leaves.elementAt(0);
                candidate_split = nn.bestSplit();
                if (candidate_split != null)
                    break;
                try_leaves.removeElementAt(0);
            }

            if (candidate_split == null)
                stop = true;
            else{
                try_leaves.removeElementAt(0);
                split_leaf(nn, candidate_split, try_leaves);
            }

            if (number_nodes >= max_size)
                stop = true;
        }

        // updates leaves in tree
        leaves = new Vector <DecisionTreeNode> (try_leaves);
    }

    /**
     * Turns leaf {@code nn} into an internal node using the split description {@code vin} and
     * queues its two new children in {@code try_leaves}.
     *
     * <p>Entries of {@code vin} read here: 2 and 3 = feature index and test index; 4 and 5 =
     * example lists of the left and right child; 6..9 = positive/negative counts (left then
     * right); 10..13 = positive/negative weights (left then right).
     */
    private void split_leaf(DecisionTreeNode nn, Vector vin, Vector <DecisionTreeNode> try_leaves){
        nn.is_leaf = false;

        nn.feature_node_index = ((Integer) vin.elementAt(2)).intValue();
        nn.feature_node_test_index = ((Integer) vin.elementAt(3)).intValue();

        int pos_left = ((Integer) vin.elementAt(6)).intValue();
        int neg_left = ((Integer) vin.elementAt(7)).intValue();
        int pos_right = ((Integer) vin.elementAt(8)).intValue();
        int neg_right = ((Integer) vin.elementAt(9)).intValue();

        double wpos_left = ((Double) vin.elementAt(10)).doubleValue();
        double wneg_left = ((Double) vin.elementAt(11)).doubleValue();
        double wpos_right = ((Double) vin.elementAt(12)).doubleValue();
        double wneg_right = ((Double) vin.elementAt(13)).doubleValue();

        DecisionTreeNode leftnn = new_child(nn, true, (Vector) vin.elementAt(4), pos_left, neg_left, wpos_left, wneg_left);
        DecisionTree.INSERT_WEIGHT_ORDER(try_leaves, leftnn);

        DecisionTreeNode rightnn = new_child(nn, false, (Vector) vin.elementAt(5), pos_right, neg_right, wpos_right, wneg_right);
        DecisionTree.INSERT_WEIGHT_ORDER(try_leaves, rightnn);

        if (nn.depth+1 > depth)
            depth = nn.depth+1;

        nn.left_child = leftnn;
        nn.right_child = rightnn;
    }

    /**
     * Creates one child of {@code parent}: allocates the next node number, builds the node,
     * computes its prediction and derives its continuous-feature split indexes from the parent.
     *
     * @param left true for the left child, false for the right child
     */
    private DecisionTreeNode new_child(DecisionTreeNode parent, boolean left, Vector examples, int pos, int neg, double wpos, double wneg){
        number_nodes++;
        DecisionTreeNode child = new DecisionTreeNode(this, number_nodes, parent.depth + 1, split_CV, examples, pos, neg, wpos, wneg);

        child.compute_prediction();

        child.continuous_features_indexes_for_split_copy_from(parent);
        if (left)
            child.continuous_features_indexes_for_split_update_child(parent.feature_node_index, parent.feature_node_test_index, DecisionTreeNode.LEFT_CHILD);
        else
            child.continuous_features_indexes_for_split_update_child(parent.feature_node_index, parent.feature_node_test_index, DecisionTreeNode.RIGHT_CHILD);

        return child;
    }

    /**
     * Builds the root from every example of the training fold of split {@code split_CV} and
     * resets the tree to that single node (depth 0, one node, one leaf).
     *
     * <p>Positive/negative counts and boosting weights are accumulated using the noisy class of
     * each example ({@code is_positive_noisy()}).
     */
    public void init(){
        int i, ne = myBoost.myDomain.myDS.train_size(split_CV);
        Vector indexes = new Vector();
        Example e;
        int pos = 0, neg = 0;
        double wpos = 0.0, wneg = 0.0;

        for (i=0;i<ne;i++){
            indexes.addElement(Integer.valueOf(i));
            e = myBoost.myDomain.myDS.train_example(split_CV, i);
            if (e.is_positive_noisy()){
                pos++;
                wpos += e.current_boosting_weight;
            }else{
                neg++;
                wneg += e.current_boosting_weight;
            }
        }

        number_nodes = 1;

        root = new DecisionTreeNode(this, number_nodes, 0, split_CV, indexes, pos, neg, wpos, wneg);
        root.init_continuous_features_indexes_for_split();
        depth = 0;

        root.compute_prediction();

        leaves = new Vector();
        leaves.addElement(root);
    }

    /**
     * Returns the child of internal node {@code nn} that example {@code ee} is routed to,
     * according to the feature test stored in {@code nn}.
     */
    private DecisionTreeNode child_reached_by(DecisionTreeNode nn, Example ee){
        Feature f = (Feature) myBoost.myDomain.myDS.features.elementAt(myBoost.myDomain.myDS.index_observation_features_to_index_features[nn.feature_node_index]);
        if (f.example_goes_left(ee, nn.feature_node_index, nn.feature_node_test_index))
            return nn.left_child;
        return nn.right_child;
    }

    /**
     * Returns the leaf reached by the example when descending from the root.
     */
    public DecisionTreeNode get_leaf(Example ee){
        DecisionTreeNode nn = root;
        while(!nn.is_leaf)
            nn = child_reached_by(nn, ee);
        return nn;
    }

    /**
     * Returns the monotonic node reached by the example: among the internal nodes visited on the
     * way from the root to a leaf, the one with the strictly largest absolute prediction (the root
     * when no later node beats it). This is a prediction node in the corresponding
     * MonotonicTreeGraph.
     *
     * <p>NOTE(review): the leaf finally reached is never compared against the best prediction, so
     * it can only be returned when it is the root. Confirm this is the intended definition.
     */
    public DecisionTreeNode get_leaf_MonotonicTreeGraph(Example ee){
        DecisionTreeNode nn = root;
        DecisionTreeNode ret = root;
        double best_prediction = Math.abs(nn.node_prediction_from_boosting_weights);

        while(!nn.is_leaf){
            double current = Math.abs(nn.node_prediction_from_boosting_weights);
            if (current > best_prediction){
                best_prediction = current;
                ret = nn;
            }
            nn = child_reached_by(nn, ee);
        }

        return ret;
    }


    /**
     * Returns the leveraging coefficient of the tree for the loss of {@link #myBoost}.
     * Only the log-loss is supported; any other loss name is reported through
     * {@code Dataset.perror} and, should that call return, -1.0 is returned.
     */
    public double leveraging_mu(){
        if (myBoost.name.equals(Boost.KEY_NAME_LOG_LOSS))
            return leveraging_mu_log_loss();
        else
            Dataset.perror("DecisionTree.class :: no loss " + myBoost.name);

        return -1.0;
    }

    /**
     * Computes the leveraging coefficient for the log-loss over the training fold and stores the
     * associated statistics in {@code tree_p_t}, {@code tree_psi_t} and {@code tree_p_t_star}.
     *
     * <p>With rho = sum_i(w_i * h(x_i) * y_i) / (sum_i(w_i) * max_i|h(x_i)|), the method sets
     * tree_p_t = (1 + rho) / 2, tree_psi_t = log((1 + rho) / (1 - rho)),
     * tree_p_t_star = 1 / (1 + exp(-max|h|)) and returns tree_psi_t / max|h|.
     * A negative tree_psi_t is reported through {@code Dataset.perror}.
     *
     * <p>NOTE(review): when the total weight or the largest absolute output is zero (including an
     * empty training fold) the normalisation divides by zero and the result is NaN, which the
     * negativity check does not catch. Left unchanged; confirm whether this can occur.
     */
    public double leveraging_mu_log_loss(){
        int i, ne = myBoost.myDomain.myDS.train_size(split_CV);
        double rho_j = 0.0, max_absolute_pred = -1.0, output_e, tot_weight = 0.0;
        Example e;

        for (i=0;i<ne;i++){
            e = myBoost.myDomain.myDS.train_example(split_CV, i);
            output_e = output_boosting(e);
            if ( (i==0) || (Math.abs(output_e) >  max_absolute_pred) )
                max_absolute_pred = Math.abs(output_e);

            rho_j += ( e.current_boosting_weight * output_e * e.noisy_normalized_class );
            tot_weight += e.current_boosting_weight;
        }
        rho_j /= (tot_weight * max_absolute_pred);

        tree_p_t = (1.0 + rho_j) / 2.0;
        tree_psi_t = Math.log((1.0 + rho_j)/(1.0 - rho_j));

        tree_p_t_star = 1.0 / (1.0 + Math.exp(-max_absolute_pred));

        if (tree_psi_t < 0.0)
            Dataset.perror("DecisionTree.class :: negative value tree_psi_t = " + tree_psi_t);

        // same logarithm as tree_psi_t, so reuse it instead of recomputing
        return tree_psi_t / max_absolute_pred;
    }

    /**
     * Returns the leveraging coefficient actually applied to the tree.
     * For the log-loss no scaling is applied, so {@code mu} is returned unchanged and
     * {@code allzs} is not used.
     */
    public double leveraging_alpha(double mu, Vector <Double> allzs){
        // log-loss: no scaling
        return mu;
    }

    /**
     * Returns the prediction stored in the leaf reached by the example, after running the
     * node's {@code checkForOutput()} sanity check.
     */
    public double output_boosting(Example ee){
        DecisionTreeNode nn = get_leaf(ee);

        nn.checkForOutput();

        return (nn.node_prediction_from_boosting_weights);
    }

    /**
     * Returns the prediction stored in the node selected by
     * {@link #get_leaf_MonotonicTreeGraph(Example)}, after running the node's
     * {@code checkForOutput_MonotonicTreeGraph()} sanity check.
     */
    public double output_boosting_MonotonicTreeGraph(Example ee){
        DecisionTreeNode nn = get_leaf_MonotonicTreeGraph(ee);

        nn.checkForOutput_MonotonicTreeGraph();

        return (nn.node_prediction_from_boosting_weights);
    }

    /**
     * Returns y * this(ee), using the noisy class of the example (if no noise, just the regular
     * class).
     */
    public double unweighted_edge_training(Example ee){
        return ( (output_boosting(ee)) * (ee.noisy_normalized_class) );
    }

    /**
     * Same as {@link #unweighted_edge_training(Example)} but with the MonotonicTreeGraph output.
     */
    public double unweighted_edge_training_MonotonicTreeGraph(Example ee){
        return ( (output_boosting_MonotonicTreeGraph(ee)) * (ee.noisy_normalized_class) );
    }

    // safe checks methods

    /**
     * Consistency check: every training example recorded in a node of {@link #leaves} must reach
     * that same node (compared by name) when routed through the tree with
     * {@link #get_leaf(Example)}. Mismatches are reported through {@code Dataset.perror}.
     */
    public void check(){
        int i, j;
        DecisionTreeNode leaf1, leaf2;
        Example ee;

        for (i=0;i<leaves.size();i++){
            leaf1 = (DecisionTreeNode) leaves.elementAt(i);
            for (j=0;j<leaf1.train_fold_indexes_in_node.length;j++){
                ee = myBoost.myDomain.myDS.train_example(split_CV, leaf1.train_fold_indexes_in_node[j]);

                leaf2 = get_leaf(ee);
                if (leaf1.name != leaf2.name){
                    Dataset.perror("DecisionTree.class :: Example " + ee + " reaches leaf #" + leaf2 + " but is recorded for leaf #" + leaf1);
                }
            }
        }
    }
}

