/*---------------------------------------------------
 * File:    sampler.c
 * purpose: Sampling and auxiliary routines
 * author:  ahollowa@uci.edu
 * date:    12/14/09
 *-------------------------------------------------*/

#include "mylib.h"
#include "sampler.h"
#include "alloc.h"
#include "graph.h"
#include "node.h"
#include "init.h"
#include "myio.h"

/*----------------------------------------------------
 * Internal Function Declarations
 *---------------------------------------------------- */
/* Rebuild node n's stick-breaking arrays (feasible/perm) from the next equivalence class. */
void construct_stick_break_distrib(struct node n, Graph *graph);
/* Depth-first enumeration of candidate paths and their log probabilities; returns the last output slot used. */
int compute_prob_dfs(int W, int L, int *p_estimate, double alpha, double beta, double eta, int parent_id, int child_id, int did, int indx, int depth, int min_depth_of_path, double prob_so_far, int path_so_far[L], double **path_probs, int ***paths, int **path_lengths, int ***dwl, int **dl, Graph *graph, int allow_new_nodes, int must_end_at_this_node);
/* Remove / add one traversal of edge (from -> to) in the stick-breaking edge counts. */
void decrement_edge_cnts( struct node *from, int to, int did);
void increment_edge_cnts( struct node *from, int to, int did);
/* Normalise log-weights in place and draw one index. */
int sample_log_space(double *probs, int size);
/* Index of the largest non-zero entry of probs (0.0 marks an unused slot). */
int set_indx( double *probs, int size );


/*---------------------------------------------------
 *  Auxiliary Functions
 *--------------------------------------------------- */

/* Construct the stick-breaking distribution from scratch */
void construct_stick_break_distrib(struct node n, Graph *graph){
	int node, k = 0;

	if( graph->num_equiv[ n.equiv_class+1 ] < graph->max_depth ){

		n.num_feasible = graph->num_equiv[n.equiv_class+1];
		for(k=0; k < n.num_feasible; k++){
			node = graph->equivalence[k][n.equiv_class+1];
			
			if( node > n.k_capacity ) {
				n.perm          = resize_one_dim_i(n.k_capacity, 2*graph->num_equiv[n.equiv_class+1], n.perm);
				n.feasible      = resize_one_dim_i(n.k_capacity, 2*graph->num_equiv[n.equiv_class+1], n.feasible);
				n.etot_k        = resize_one_dim_i(n.k_capacity, 2*graph->num_equiv[n.equiv_class+1], n.etot_k);
				n.etot_k_agg    = resize_one_dim_i(n.k_capacity, 2*graph->num_equiv[n.equiv_class+1], n.etot_k_agg);
				n.k_capacity    = 2*graph->num_equiv[n.equiv_class+1];
			}
			
			if( node == NEW_NODE ){
				n.feasible[k] = n.id;
				n.perm[n.id] = k;
			}
			else{
				n.feasible[k] = node;
				n.perm[node] = k;
			}
		}
	}
	else{
		n.num_feasible = 1;
		if( n.id > n.k_capacity){
			n.perm          = resize_one_dim_i(n.k_capacity, 2*graph->num_equiv[n.equiv_class+1], n.perm);
			n.feasible      = resize_one_dim_i(n.k_capacity, 2*graph->num_equiv[n.equiv_class+1], n.feasible);
			n.etot_k        = resize_one_dim_i(n.k_capacity, 2*graph->num_equiv[n.equiv_class+1], n.etot_k);
			n.etot_k_agg    = resize_one_dim_i(n.k_capacity, 2*graph->num_equiv[n.equiv_class+1], n.etot_k_agg);
			n.k_capacity    = 2*graph->num_equiv[n.equiv_class+1];
		}
		n.perm[n.id] = 0;
		n.feasible[0] = n.id;
	}
		
	return;
}

/* Append node_to_add as the last feasible child of *n in its stick-breaking order.
 *
 * perm[], etot_k[] and etot_k_agg[] are indexed by node id, and feasible[] by
 * position. All four share n->k_capacity, so the capacity must exceed both
 * node_to_add and n->num_feasible before the append. When it does not, the
 * arrays grow to twice the larger of the two indices (at least 1).
 * The function does not check whether node_to_add is already feasible.
 */
void update_stick_break_distrib(struct node *n, int node_to_add){
	int needed = node_to_add;
	int new_capacity;

	assert(node_to_add >= 0);  // perm[node_to_add] is written below
	if( n->num_feasible > needed ){
		needed = n->num_feasible;
	}

	// make sure there is room in the perm, feasible and edge-count arrays...
	if( needed >= n->k_capacity ){
		new_capacity = 2*needed;
		if( new_capacity <= needed ){
			new_capacity = needed + 1;  // needed == 0: doubling would not grow at all
		}
		n->perm       = resize_one_dim_i(n->k_capacity, new_capacity, n->perm);
		n->feasible   = resize_one_dim_i(n->k_capacity, new_capacity, n->feasible);
		n->etot_k     = resize_one_dim_i(n->k_capacity, new_capacity, n->etot_k);
		n->etot_k_agg = resize_one_dim_i(n->k_capacity, new_capacity, n->etot_k_agg);
		n->k_capacity = new_capacity;
	}

	// add new node to the end of this node's stick-breaking distribution
	n->feasible[n->num_feasible] = node_to_add;
	n->perm[node_to_add] = n->num_feasible;
	n->num_feasible++;

	return;
}

/* Store path_so_far[0..depth] in output slot my_indx of paths / path_probs /
 * path_lengths. The stored log probability is log_prob plus, for each position
 * i on the path, compute_prob_doc_at_node() for node path_so_far[i] at level
 * depth-i. The recorded length is `depth`. */
static void store_dfs_path(int W, int did, int depth, int my_indx, double log_prob, double eta, const int *path_so_far, double **path_probs, int ***paths, int **path_lengths, int ***dwl, int **dl, Graph *graph){
	int i;

	(*path_probs)[my_indx] = log_prob;
	for(i=0; i<=depth; i++){
		(*paths)[my_indx][i] = path_so_far[i];
		(*path_probs)[my_indx] += compute_prob_doc_at_node(graph, dwl, dl, (*paths)[my_indx][i], did, depth-i, eta, W);
	}
	(*path_lengths)[my_indx] = depth;
}

/* Compute the probability of each path using a modified depth-first search.
 *
 * Starting at child_id, which sits at position `depth` of path_so_far, this
 * enumerates every continuation allowed by the stick-breaking distributions.
 * Continuations are existing children, plus a NEW_NODE child when
 * allow_new_nodes == TRUE. The log transition probability accumulates in
 * prob_so_far.
 *
 * Each call owns output slot `indx` (my_indx), and every recursive call is
 * given the next slot. When the current node is an admissible endpoint, the
 * path and its log probability (transition part plus document log likelihood)
 * are stored in that slot. Otherwise the slot is left untouched; it is
 * presumably 0.0 from the caller's allocation, which sample_log_space treats
 * as "no path". TODO: confirm dvec/resize zero-fill.
 *
 * Endpoint rules:
 *  - With must_end_at_this_node == FREE_PATH, any node at depth >=
 *    min_depth_of_path may end a path.
 *  - Otherwise a path may end only at a node whose id equals
 *    must_end_at_this_node. Cases three and four then skip that node's
 *    subtree.
 *
 * The output arrays are doubled (via *p_estimate) when my_indx reaches their
 * capacity. Returns the last slot index used by this subtree. parent_id is
 * forwarded for interface compatibility but never read.
 */
int compute_prob_dfs(int W, int L, int *p_estimate, double alpha, double beta, double eta, int parent_id, int child_id, int did, int indx, int depth, int min_depth_of_path, double prob_so_far, int path_so_far[L], double **path_probs, int ***paths, int **path_lengths, int ***dwl, int **dl, Graph *graph, int allow_new_nodes, int must_end_at_this_node){
	int i, x, exit_child;
	int compute_but_dont_recurse = FALSE;
	double a, b;
	int my_indx = indx;
	int *order;
	double prod, tmp;
	/* Initialised so that a missing exit child yields log(0) = -inf, i.e. a
	 * zero-probability path. Previously this was an indeterminate read. */
	double exit_prob = 0.0;

	//Ensure enough capacity in paths, path_probs and path_lengths
	if( my_indx >= *p_estimate ){
		(*paths)        = resize_two_dim_i(*p_estimate, 2*(*p_estimate), L, (*paths));
		(*path_probs)   = resize_one_dim_d(*p_estimate, 2*(*p_estimate), (*path_probs));
		(*path_lengths) = resize_one_dim_i(*p_estimate, 2*(*p_estimate), (*path_lengths));
		(*p_estimate) *= 2;
	}

	if( child_id == NEW_NODE ){

		/*** CASE ONE: Going from a NEW_NODE to nodes in next lower equivalence class ***/
		if( depth+1 < L ){

			prod = 1.0;
			order = randperm(graph->num_equiv[depth+1]);  //sample a permutation over the nodes at the next level
			for(i=0; i < graph->num_equiv[depth+1]; i++){
				x = graph->equivalence[order[i]][depth+1];

				// The NEW_NODE entry is standing in for the exit child
				if( x == NEW_NODE ){
					exit_prob = alpha/(alpha+beta) * prod;
					prod *= beta/(alpha+beta);
					continue;
				}

				// Prior stick-breaking weight of x; flush underflow to exactly zero
				tmp = alpha/(alpha+beta) * prod;
				prod *= beta/(alpha+beta);
				if( tmp < DBL_MIN || prod < DBL_MIN ){ prod = 0; tmp = 0; }

				path_so_far[depth+1] = x;
				indx++;
				assert(tmp >= 0);
				indx = compute_prob_dfs(W, L, p_estimate, alpha, beta, eta, child_id, x, did, indx, depth+1, min_depth_of_path, prob_so_far + log(tmp), path_so_far, path_probs, paths, path_lengths, dwl, dl, graph, allow_new_nodes, must_end_at_this_node);
			}

			//Recurse on a new node if allowed: the remaining stick mass, truncated
			//to the number of node slots still available in the next class
			if( allow_new_nodes == TRUE ){
				assert(graph->num_equiv[depth+1] <= graph->limit_equiv[depth+1]);
				tmp = prod * ( 1 - pow(beta/(alpha+beta), graph->limit_equiv[depth+1]-graph->num_equiv[depth+1]));
				path_so_far[depth+1] = NEW_NODE;
				indx++;
				assert(tmp >= 0);
				indx = compute_prob_dfs(W, L, p_estimate, alpha, beta, eta, child_id, NEW_NODE, did, indx, depth+1, min_depth_of_path, prob_so_far + log(tmp), path_so_far, path_probs, paths, path_lengths, dwl, dl, graph, allow_new_nodes, must_end_at_this_node);
			}

			//Save this path if we can end at any node and this node is deep enough...
			if( must_end_at_this_node == FREE_PATH && depth >= min_depth_of_path ){
				assert(exit_prob >= 0);
				store_dfs_path(W, did, depth, my_indx, prob_so_far + log(exit_prob), eta, path_so_far, path_probs, paths, path_lengths, dwl, dl, graph);
			}

			free(order);
		}

		/*** CASE TWO: Going from a NEW_NODE to its exit node (maximum depth reached). ***/
		else{
			if( must_end_at_this_node == FREE_PATH && depth >= min_depth_of_path ){
				store_dfs_path(W, did, depth, my_indx, prob_so_far + log(1), eta, path_so_far, path_probs, paths, path_lengths, dwl, dl, graph);
			}
		}
	}
	else{
		/*** CASE THREE: Going from existing node to nodes at next equivalence class ***/
		if( graph->nodes[child_id].num_feasible > 1 ){

			//The path must end at this node, so there is no need to recurse on its children.
			compute_but_dont_recurse = FALSE;
			if( path_so_far[depth] == must_end_at_this_node ){
				compute_but_dont_recurse = TRUE;
			}

			prod = 1.0;
			for(i=0; i < graph->nodes[child_id].num_feasible; i++){
				x = graph->nodes[child_id].feasible[i];
				a = alpha + graph->nodes[child_id].etot_k[x];
				b = beta  + graph->nodes[child_id].etot_k_agg[x];

				// Probability of exiting here (x is this node's own id)
				if( x == child_id ){
					exit_prob = a/(a+b) * prod;
					prod *= b/(a+b);
					if( compute_but_dont_recurse == TRUE ){ break; }
					else{ continue; }
				}

				// Probability of transitioning to x; flush underflow to exactly zero
				tmp = a/(a+b) * prod;
				prod *= b/(a+b);
				if( tmp < DBL_MIN || prod < DBL_MIN ){ prod = 0; tmp = 0; }

				//Either recurse on child, or don't recurse and keep going
				if( compute_but_dont_recurse == FALSE ){
					assert(tmp >= 0);
					path_so_far[depth+1] = x;
					indx++;
					indx = compute_prob_dfs(W, L, p_estimate, alpha, beta, eta, child_id, x, did, indx, depth+1, min_depth_of_path, prob_so_far + log(tmp), path_so_far, path_probs, paths, path_lengths, dwl, dl, graph, allow_new_nodes, must_end_at_this_node);
				}
			}

			//The path must exit at this node...
			if( compute_but_dont_recurse == TRUE ){
				assert(exit_prob >= 0);
				store_dfs_path(W, did, depth, my_indx, prob_so_far + log(exit_prob), eta, path_so_far, path_probs, paths, path_lengths, dwl, dl, graph);
			}
			else{

				//Recurse on a new node if allowed..
				if( allow_new_nodes == TRUE ){
					assert(graph->num_equiv[depth+1] <= graph->limit_equiv[depth+1]);
					tmp = prod * ( 1 - pow(beta/(alpha+beta), graph->limit_equiv[depth+1]-graph->num_equiv[depth+1]));
					assert(tmp >= 0);
					path_so_far[depth+1] = NEW_NODE;
					indx++;
					indx = compute_prob_dfs(W, L, p_estimate, alpha, beta, eta, child_id, NEW_NODE, did, indx, depth+1, min_depth_of_path, prob_so_far + log(tmp), path_so_far, path_probs, paths, path_lengths, dwl, dl, graph, allow_new_nodes, must_end_at_this_node);
				}

				//Save this path if we can end at any node and this node is deep enough...
				if( must_end_at_this_node == FREE_PATH && depth >= min_depth_of_path ){
					store_dfs_path(W, did, depth, my_indx, prob_so_far + log(exit_prob), eta, path_so_far, path_probs, paths, path_lengths, dwl, dl, graph);
				}
			}
		}

		/*** CASE FOUR: The only feasible nodes are (1) exit child and (2) a new node ***/
		else if( depth+1 < L ){
			exit_child = graph->nodes[child_id].feasible[0];
			a = alpha + graph->nodes[child_id].etot_k[exit_child];
			b = beta  + graph->nodes[child_id].etot_k_agg[exit_child];

			//The path must exit at this node...
			if( path_so_far[depth] == must_end_at_this_node ){
				store_dfs_path(W, did, depth, my_indx, prob_so_far + log(a/(a+b)), eta, path_so_far, path_probs, paths, path_lengths, dwl, dl, graph);
			}
			else{
				//Probability of creating a new node
				if( allow_new_nodes == TRUE ){
					assert(graph->num_equiv[depth+1] <= graph->limit_equiv[depth+1]);
					tmp = b/(b+a) * ( 1 - pow(beta/(alpha+beta), graph->limit_equiv[depth+1]-graph->num_equiv[depth+1]));
					path_so_far[depth+1] = NEW_NODE;
					indx++;
					assert(tmp >= 0);
					indx = compute_prob_dfs(W, L, p_estimate, alpha, beta, eta, child_id, NEW_NODE, did, indx, depth+1, min_depth_of_path, prob_so_far + log(tmp), path_so_far, path_probs, paths, path_lengths, dwl, dl, graph, allow_new_nodes, must_end_at_this_node);
				}

				//Save this path if we can end at any node and this node is deep enough...
				if( must_end_at_this_node == FREE_PATH && depth >= min_depth_of_path ){
					store_dfs_path(W, did, depth, my_indx, prob_so_far + log(a/(a+b)), eta, path_so_far, path_probs, paths, path_lengths, dwl, dl, graph);
				}
			}
		}

		/*** CASE FIVE: The only feasible node is the exit child because we've reached maximum depth ***/
		else{
			// **EITHER** path must end at this node **OR** path can end at any node as long as it is deep enough
			if( must_end_at_this_node == path_so_far[depth] || (must_end_at_this_node == FREE_PATH && depth >= min_depth_of_path) ){
				store_dfs_path(W, did, depth, my_indx, prob_so_far + log(1), eta, path_so_far, path_probs, paths, path_lengths, dwl, dl, graph);
			}
		}
	}

	return indx;
}

/* Smoothed probability of word wid under node n:
 *     (eta + n.cp[wid]) / (weta + n.ztot)
 * Callers in this file pass weta = W*eta (symmetric prior over W words). */
double compute_prob_word_at_node( struct node n, int wid, double eta, double weta){
	double numerator   = eta + n.cp[wid];
	double denominator = weta + n.ztot;

	return numerator / denominator;
}

/* Log probability of the words that document did assigns to level `depth`,
 * under node node_id's topic counts with a symmetric Dirichlet(eta) prior
 * over W words (Dirichlet-multinomial predictive form):
 *
 *   sum_w [lgamma(eta + c_w + n_w) - lgamma(eta + c_w)]
 *     + lgamma(W*eta + C) - lgamma(W*eta + C + N)
 *
 * Here c_w/C are the node's counts (cp/ztot), n_w = dwl[did][w][depth] and
 * N = dl[did][depth]. A NEW_NODE has no counts yet, so c_w = C = 0 and
 * graph->nodes is not touched. */
double compute_prob_doc_at_node(Graph *graph, int ***dwl, int **dl, int node_id, int did, int depth, double eta, int W){
	const int is_new = (node_id == NEW_NODE);
	double log_lik = 0.0;
	double prior_w, prior_tot;
	int wid, count;

	for(wid = 0; wid < W; wid++){
		count = dwl[did][wid][depth];
		if( count == 0 ){ continue; }

		prior_w  = is_new ? eta : eta + graph->nodes[node_id].cp[wid];
		log_lik += lgamma(prior_w + count) - lgamma(prior_w);
	}

	prior_tot = is_new ? W*eta : W*eta + graph->nodes[node_id].ztot;
	log_lik  += lgamma(prior_tot) - lgamma(prior_tot + dl[did][depth]);

	return log_lik;
}


/*---------------------------------------------------
 *  Incrementing and Decrementing Functions
 *--------------------------------------------------- */

/* Remove one document's path from the stick-breaking edge counts.
 *
 * path[0..length] lists the visited nodes. Each edge path[pos] -> path[pos+1]
 * loses one traversal, and so does the self-edge path[length] -> path[length]
 * that records where the path stopped. `did` is only forwarded to
 * decrement_edge_cnts. */
void decrement_cnts_sb(Graph *graph, int *path, int did, int length){
	int pos;
	int last = path[length];

	for(pos = 0; pos < length; pos++){
		decrement_edge_cnts(&(graph->nodes[path[pos]]), path[pos+1], did);
	}
	decrement_edge_cnts(&(graph->nodes[last]), last, did);
}

/* Remove cnt occurrences of word wid from node n's topic counts: both the
 * per-word count cp[wid] and the node total ztot. Asserts that neither goes
 * negative. */
void decrement_cp(struct node *n, int wid, int cnt){
	n->cp[wid] = n->cp[wid] - cnt;
	n->ztot    = n->ztot - cnt;

	assert(n->cp[wid] >= 0);
	assert(n->ztot >= 0);
}

/* Remove one traversal of the edge (from -> to).
 *
 * etot_k[to] loses one count. Every feasible child ordered before `to` in
 * from's stick-breaking permutation (perm[child] < perm[to]) also loses one
 * count in etot_k_agg, its "passed over" tally. Asserts that no count goes
 * negative. `did` is not used. */
void decrement_edge_cnts( struct node *from, int to, int did){
	int cutoff, pos, child;

	from->etot_k[to] -= 1;
	assert(from->etot_k[to] >= 0);

	cutoff = from->perm[to];
	for(pos = 0; pos < from->num_feasible; pos++){
		child = from->feasible[pos];
		if( from->perm[child] >= cutoff ){ continue; }

		from->etot_k_agg[child] -= 1;
		assert(from->etot_k_agg[child] >= 0);
	}
}

/* Add one document's path to the stick-breaking edge counts.
 *
 * path[0..length] lists the visited nodes. Each edge path[pos] -> path[pos+1]
 * gains one traversal, and so does the self-edge path[length] -> path[length]
 * that records where the path stopped.
 * NOTE: `did` is forwarded to increment_edge_cnts, which does not use it.
 * It was earlier flagged for removal. */
void increment_cnts_sb(Graph *graph, int *path, int did, int length){
	int pos;
	int last = path[length];

	for(pos = 0; pos < length; pos++){
		increment_edge_cnts(&(graph->nodes[path[pos]]), path[pos+1], did);
	}
	increment_edge_cnts(&(graph->nodes[last]), last, did);
}

/* Add cnt occurrences of word wid to node n's topic counts: both the
 * per-word count cp[wid] and the node total ztot. */
void increment_cp(struct node *n, int wid, int cnt){
	n->cp[wid] = n->cp[wid] + cnt;
	n->ztot    = n->ztot + cnt;
}

/* Add one traversal of the edge (from -> to).
 *
 * etot_k[to] gains one count. Every feasible child ordered before `to` in
 * from's stick-breaking permutation (perm[child] < perm[to]) also gains one
 * count in etot_k_agg, its "passed over" tally. `did` is not used. */
void increment_edge_cnts( struct node *from, int to, int did){
	int cutoff, pos, child;

	from->etot_k[to] += 1;

	cutoff = from->perm[to];
	for(pos = 0; pos < from->num_feasible; pos++){
		child = from->feasible[pos];
		if( from->perm[child] >= cutoff ){ continue; }

		from->etot_k_agg[child] += 1;
	}
}


/*-------------------------------------------------------------
 *  Sampling from multinomials - includes log space sampling
 *------------------------------------------------------------- */

/* Draw an index from unnormalised weights by inverse-CDF sampling.
 *
 * The threshold is drand()*totprob. The result is the first index whose
 * cumulative weight probs[0] + ... + probs[idx] reaches the threshold.
 * NOTE(review): there is no length argument, so the walk relies on the
 * cumulative sum reaching the threshold. If totprob exceeds the true sum
 * (e.g. through rounding), the loop reads past the end of probs. */
int sample_unnormalized_mn(double *probs, double totprob){
	double threshold = drand()*totprob;
	double cumulative;
	int pid = 0;

	for(cumulative = probs[0]; cumulative < threshold; cumulative += probs[pid]){
		pid++;
	}

	return pid;
}

/* Sample an index from log-space weights.
 *
 * probs[0..size-1] holds log weights. Entries exactly equal to 0.0 are treated
 * as unused slots and are never drawn. The steps are:
 *  1. Normalise with the log-sum-exp trick around the largest entry
 *     (set_indx). Relative weights below DBL_MIN are dropped.
 *  2. Overwrite probs IN PLACE with linear probabilities that sum to 1.
 *     Entries whose probability is <= DBL_MIN become 0.
 *  3. Draw one index with sample_unnormalized_mn.
 *
 * If every entry is 0.0 there is nothing to normalise. In that case the index
 * chosen by set_indx is returned as-is, instead of dividing 0 by 0.
 */
int sample_log_space(double *probs, int size){
	int base, j, pid;
	double log_totprob, sum_rel, rel, totprob;

	base = set_indx(probs, size);
	log_totprob = probs[base];

	// sum_rel = sum_j exp(probs[j] - probs[base]); the base itself contributes exp(0) = 1
	sum_rel = 1.0;
	for(j = 0; j < size; j++){
		if( probs[j] == 0.0 || j == base ){ continue; }

		rel = exp(probs[j] - probs[base]);
		assert(rel < DBL_MAX);
		if( rel < DBL_MIN ){ continue; }
		sum_rel += rel;
	}

	// sum_rel >= 1, so log(sum_rel) >= 0; it must also be finite.
	// (Previously checked with integer abs(), which truncated the double.)
	assert(isfinite(log(sum_rel)));
	log_totprob += log(sum_rel);

	//Normalize in log space - then transform back (stored in the same array)
	totprob = 0;
	for(j = 0; j < size; j++){
		if( probs[j] == 0.0 ){ continue; }

		rel = exp(probs[j] - log_totprob);
		assert(rel >= 0.0);
		assert(rel <= 1.0);
		if( rel <= DBL_MIN ){
			probs[j] = 0;
		}
		else{
			probs[j] = rel;
			totprob += rel;
		}
	}

	if( totprob <= 0.0 ){
		return base;  // no usable entry: avoid 0/0
	}

	for(j = 0; j < size; j++){
		probs[j] /= totprob;
	}

	pid = sample_unnormalized_mn(probs, 1.0);
	assert(pid < size);
	return pid;
}

/* Return the index of the largest entry of probs[0..size-1], used by
 * sample_log_space as the log-sum-exp base.
 *
 * Entries equal to 0.0 (the "unused slot" sentinel) and NaN entries are
 * ignored. Ties go to the lowest index. The search is seeded from the first
 * candidate, so an entry of -inf or -DBL_MAX can still be selected. That
 * matters because log(0) = -inf is a legitimate path weight. Returns 0 when
 * there is no candidate at all.
 */
int set_indx( double *probs, int size ){
	int best = 0;
	int have_best = 0;
	int i;

	for(i = 0; i < size; i++){
		if( probs[i] == 0.0 || isnan(probs[i]) ){ continue; }
		if( !have_best || probs[i] > probs[best] ){
			best = i;
			have_best = 1;
		}
	}

	return best;
}



/*-------------------------------------------------------------
 *  Main Gibbs sampling and Metropolis-Hastings functions
 *------------------------------------------------------------- */

/* Resample the path of every non-fixed document (collapsed Gibbs step).
 *
 * For each document i with docconcept[i] != FIXED_PATH:
 *   1. remove its edge counts, and (for nodes whose keep_cp_fixed is not
 *      TRUE) its word counts, from the nodes on p[i];
 *   2. score every candidate path with compute_prob_dfs, constrained by
 *      min_depth[i] and docconcept[i];
 *   3. draw one path with sample_log_space, creating graph nodes for any
 *      NEW_NODE placeholder on it, and copy it into p[i] / lengths[i];
 *   4. add the counts back along the new path, then run SWAPS_PER_DOCUMENT
 *      rounds of swap_cluster_labels.
 *
 * Word counts at level l of dwl/dl belong to node p[i][lengths[i]-l], i.e. l
 * steps above the end of the path. p_estimate is the initial number of path
 * slots. compute_prob_dfs may grow it, and the grown value is reused for
 * later documents. ntot, d and w are not used here.
 */
void sample_nonparam_z(int L, int W, int D, int p_estimate, double alpha, double beta, double eta, int ntot, int *d, int *w, int **p, int *lengths, int ***dwl, int **dl, int *min_depth, Graph *graph, int allow_new_nodes, int *docconcept){
	enum { SWAPS_PER_DOCUMENT = 10 };
	int i, j, k, l, node, v, pid;
	int print = FALSE;
	int *start_path = (int *)calloc(L, sizeof(int)); assert(start_path);
	int *path_lengths;
	double *path_probs;
	int **paths;

	for(i = 0; i < D; i++){
		if( docconcept[i] == FIXED_PATH ){ continue; }

		//Decrement counts: edges along the current path, then word counts per level
		decrement_cnts_sb(graph, p[i], i, lengths[i]);
		for(l=0; l <= lengths[i]; l++){
			node = p[i][lengths[i]-l];
			if( graph->nodes[node].keep_cp_fixed == TRUE ){ continue; }
			for(v=0; v < W; v++){
				if( dwl[i][v][l] == 0 ){ continue; }
				decrement_cp(&(graph->nodes[node]), v, dwl[i][v][l]);
			}
		}

		//Compute the (log) probability of each candidate path
		if(print){printf("Document %d\n", i);}
		paths        = imat(p_estimate, L);
		path_lengths = ivec(p_estimate);
		path_probs   = dvec(p_estimate);
		compute_prob_dfs(W, L, &p_estimate, alpha, beta, eta, -1, 0, i, 0, 0, min_depth[i], 0, start_path, &path_probs, &paths, &path_lengths, dwl, dl, graph, allow_new_nodes, docconcept[i]);

		//Draw one path (p_estimate may have grown inside compute_prob_dfs)
		pid = sample_log_space(path_probs, p_estimate);

		//Copy it into p[i], turning any NEW_NODE placeholder into a real node.
		//NOTE(review): position 0 is not copied; this assumes p[i][0] already
		//holds the root that every path starts from (start_path[0] == 0). Confirm.
		for(j = 1; j <= path_lengths[pid]; j++){
			if( paths[pid][j] == NEW_NODE ){
				if(print){printf("Adding new node %d\n", graph->next_avail_id);}
				paths[pid][j] = graph->next_avail_id;
				add_new_node(graph, graph->next_avail_id, j, L, W, D);
				update_next_avail_id(graph);
			}
			p[i][j] = paths[pid][j];
		}
		lengths[i] = path_lengths[pid];

		//Release the per-document scratch arrays; imat's rows are freed through
		//paths[0] (presumably one contiguous block; TODO confirm against alloc.c)
		free(paths[0]);
		free(paths);
		free(path_probs);
		free(path_lengths);

		if(print){
			printf("min_depth=%d\n", min_depth[i]);
			printf("pid = %d\n", pid);
			printf("Sampled path: ");
			for(j=0; j <= lengths[i]; j++){
				printf("%d ", p[i][j]);
			}
			printf("\n");
			printf("Equivalence Classes:\n");
			for(j=0; j < graph->max_depth; j++){
				printf("\tClass %d: ", j);
				for(k=0; k < graph->num_equiv[j]; k++){
					printf(" %d", graph->equivalence[k][j]);
				}
				printf("\n");
			}
			printf( "\n");
		}

		//Increment counts along the newly sampled path
		increment_cnts_sb(graph, p[i], i, lengths[i]);
		for(l=0; l <= lengths[i]; l++){
			node = p[i][lengths[i] - l];
			if( graph->nodes[node].keep_cp_fixed == TRUE ){ continue; }
			for(v=0; v < W; v++){
				if( dwl[i][v][l] == 0 ){ continue; }
				increment_cp(&(graph->nodes[node]), v, dwl[i][v][l]);
			}
		}

		//Mix over the stick-breaking orders of the children
		for(j=0; j < SWAPS_PER_DOCUMENT; j++){
			swap_cluster_labels(graph, alpha, beta);
		}
	}

	free(start_path);
	return;
}

/* Resample each non-fixed document's truncated-Poisson level rate pi_d[i].
 *
 * Draws 1000 candidates lambda from the Gamma(gamma_k, gamma_theta) prior.
 * Each candidate is weighted by the log likelihood of the document's level
 * counts dl[i][0..L-1] under Poisson(lambda) truncated to levels
 * 0..path_lengths[i], and one candidate is picked with sample_log_space.
 *
 * gsl_sf_gamma_inc(a, x) is the (unnormalised) upper incomplete gamma
 * function, so P(X <= n) = Gamma(n+1, lambda)/n! for X ~ Poisson(lambda).
 * Truncation divides each of the Nd[i] word probabilities by that value,
 * which adds Nd[i]*(log n! - log Gamma(n+1, lambda)) to the log weight. The
 * previous code subtracted this term (a sign error). That disagreed with the
 * level sampler in sample_level_assignments_poisson, which multiplies by
 * n!/Gamma(n+1, lambda).
 *
 * NOTE(review): factorial[] is an int array; k! overflows int for k >= 13.
 * ntot, d, word_levels and graph are not used here.
 */
void sample_nonparam_lambda(int D, int L, int ntot, double gamma_k, double gamma_theta, int *d, double *pi_d, int *path_lengths, int *word_levels, int *Nd, int **dl, int *docconcept, Graph *graph, gsl_rng *rgen, int *factorial){
	const int nsamples = 1000;
	int i, j, k, length, chosen_index;
	double *log_probs, *vals;

	vals      = dvec(nsamples);
	log_probs = dvec(nsamples);
	for(i=0; i < D; i++){
		if( docconcept[i] == FIXED_PATH ){ continue; }
		length = path_lengths[i];

		//Approximate the posterior with weighted samples from the prior
		for(j=0; j < nsamples; j++){
			vals[j] = gsl_ran_gamma(rgen, gamma_k, gamma_theta);
			log_probs[j] = 0.0;

			//Untruncated Poisson log likelihood of the per-level word counts
			for(k=0; k < L; k++){
				log_probs[j] += dl[i][k]*(k*log(vals[j]) - log(factorial[k]) - vals[j]);
			}

			//Truncation to levels 0..length: divide by P(X <= length) once per word
			log_probs[j] += Nd[i]*(log(factorial[length]) - log(gsl_sf_gamma_inc(length+1, vals[j])));
		}
		chosen_index = sample_log_space(log_probs, nsamples);
		pi_d[i] = vals[chosen_index];
	}

	free(vals);
	free(log_probs);
	return;
}

/* Resample each non-fixed document's truncated-geometric level parameter pi_d[i].
 *
 * Draws 1000 candidates from the Beta(a, b) prior. Each candidate is weighted
 * by the log likelihood of the document's level counts dl[i][0..L-1] under a
 * geometric distribution, P(k) = pi * (1-pi)^k, truncated to levels
 * 0..path_lengths[i]. The truncation divides each of the Nd[i] word
 * probabilities by 1 - (1-pi)^(length+1). One candidate is picked with
 * sample_log_space. ntot, d, word_levels, sum_levels and graph are not used.
 */
void sample_nonparam_pi(int D, int L, int ntot, double a, double b, int *d, double *pi_d, int *path_lengths, int *word_levels, int *sum_levels, int *Nd, int **dl, int *docconcept, Graph *graph, gsl_rng *rgen){
	const int nsamples = 1000;
	double *candidates = dvec(nsamples);
	double *weights    = dvec(nsamples);
	int doc, s, lvl;
	double cand, logw;

	for(doc = 0; doc < D; doc++){
		if( docconcept[doc] == FIXED_PATH ){ continue; }

		for(s = 0; s < nsamples; s++){
			cand = gsl_ran_beta(rgen, a, b);
			logw = 0.0;
			for(lvl = 0; lvl < L; lvl++){
				logw += dl[doc][lvl]*(log(cand) + lvl*log(1-cand));
			}
			logw -= Nd[doc] * log(1-pow(1-cand, path_lengths[doc]+1));

			candidates[s] = cand;
			weights[s]    = logw;
		}
		pi_d[doc] = candidates[sample_log_space(weights, nsamples)];
	}

	free(candidates);
	free(weights);
}

/* Resample the level of every word token under a truncated-Poisson level prior.
 *
 * Token i (document d[i], word w[i]) sits at level word_levels[i], meaning
 * node p[did][length-level], i.e. `level` steps above the end of the path.
 * The token's counts are removed. A new level k in 0..length is then drawn
 * with weight
 *     Poisson(k; pi_d[did]) / P(X <= length)  *  p(w | node at level k)
 * and the counts are added back at the new level. Tokens of FIXED_PATH
 * documents are skipped. Nodes with keep_cp_fixed == TRUE keep their topic
 * counts, but dwl/dl are always updated.
 *
 * Afterwards min_depth[i] is set to the largest level j with dl[i][j] != 0.
 * It is left unchanged if the document has no words.
 *
 * The truncation factor exp(-lambda) * length! / Gamma(length+1, lambda) does
 * not depend on k. It is therefore evaluated once per token rather than once
 * per candidate level, which saves length incomplete-gamma calls per token.
 * A single scratch buffer of L weights is reused for all tokens.
 */
void sample_level_assignments_poisson(int L, int W, int D, int ntot, double eta, double *pi_d, int **p, int *d, int *w, int **dl, int ***dwl, int *path_lengths, int *word_levels, int *Nd, int *min_depth, int *docconcept, int *factorial, Graph *graph){
	int i, k, j, did, wid, level, length, node;
	double lambda, totprob, trunc_const;
	double *probs = dvec(L);  // path_lengths[] never exceeds L-1 (asserted below)

	for(i=0; i < ntot; i++){
		if( docconcept[d[i]] == FIXED_PATH ){ continue; }

		did    = d[i];
		wid    = w[i];
		lambda = pi_d[did];
		level  = word_levels[i];
		length = path_lengths[did];
		assert(length < L);

		//Decrement counts
		node = p[did][length-level];
		if( graph->nodes[node].keep_cp_fixed == FALSE ){ decrement_cp(&(graph->nodes[node]), wid, 1); }
		dwl[did][wid][level]--;
		dl[did][level]--;

		//Sample new level
		trunc_const = exp(-lambda) * factorial[length] / gsl_sf_gamma_inc(length+1, lambda);
		totprob = 0;
		for(k=0; k<=length; k++){
			node = p[did][length-k];
			probs[k]  = pow(lambda,k)/factorial[k] * trunc_const;
			probs[k] *= compute_prob_word_at_node(graph->nodes[node], wid, eta, W*eta);
			totprob  += probs[k];
		}
		level = sample_unnormalized_mn(probs, totprob);
		assert(level <= length);

		//Increment counts
		node = p[did][length - level];
		word_levels[i] = level;
		if( graph->nodes[node].keep_cp_fixed == FALSE ){ increment_cp(&(graph->nodes[node]), wid, 1); }
		dwl[did][wid][level]++;
		dl[did][level]++;
	}
	free(probs);

	//Update min_depth array: the deepest level that still holds words
	for(i=0; i < D; i++){
		for(j=L-1; j >= 0; j--){
			if( dl[i][j] != 0 ){
				min_depth[i] = j;
				break;
			}
		}
	}

	return;
}

/* Resample the level of every word token under a truncated-geometric level prior.
 *
 * Token i (document d[i], word w[i]) sits at level word_levels[i], meaning
 * node p[did][length-level]. The token's counts are removed. A new level k in
 * 0..length is then drawn with weight
 *     pi * (1-pi)^k / (1 - (1-pi)^(length+1))  *  p(w | node at level k)
 * and the counts are added back. sum_levels[did] tracks the sum of the
 * document's token levels and is kept in sync. Tokens of FIXED_PATH documents
 * are skipped. Nodes with keep_cp_fixed == TRUE keep their topic counts.
 *
 * Afterwards min_depth[i] is set to the largest level j with dl[i][j] != 0.
 * It is left unchanged if the document has no words.
 *
 * The truncation factor pi / (1 - (1-pi)^(length+1)) does not depend on k and
 * is computed once per token. A single scratch buffer of L weights is reused
 * for all tokens.
 */
void sample_level_assignments_geometric(int L, int W, int D, int ntot, double eta, double *pi_d, int **p, int *d, int *w, int **dl, int ***dwl, int *path_lengths, int *word_levels, int *sum_levels, int *Nd, int *min_depth, int *docconcept, Graph *graph){
	int i, k, wid, did, level, length, node, j;
	double totprob, pi, trunc_const;
	double *probs = dvec(L);  // path_lengths[] never exceeds L-1 (asserted below)

	for(i=0; i < ntot; i++){
		if( docconcept[d[i]] == FIXED_PATH ){ continue; }

		did    = d[i];
		wid    = w[i];
		pi     = pi_d[did];
		level  = word_levels[i];
		length = path_lengths[did];
		assert(length < L);

		//Decrement counts
		node = p[did][length - level];
		if( graph->nodes[node].keep_cp_fixed == FALSE ){ decrement_cp(&(graph->nodes[node]), wid, 1); }
		sum_levels[did] -= level;
		dwl[did][wid][level]--;
		dl[did][level]--;

		//Sample new level
		trunc_const = pi / (1-pow(1-pi, length+1));
		totprob = 0;
		for(k=0; k<=length; k++){
			node = p[did][length-k];
			probs[k]  = pow(1-pi, k) * trunc_const;
			probs[k] *= compute_prob_word_at_node(graph->nodes[node], wid, eta, W*eta);
			totprob  += probs[k];
		}
		level = sample_unnormalized_mn(probs, totprob);
		assert(level <= length);

		//Increment counts
		node = p[did][length - level];
		word_levels[i] = level;
		if( graph->nodes[node].keep_cp_fixed == FALSE ){ increment_cp(&(graph->nodes[node]), wid, 1); }
		sum_levels[did] += level;
		dwl[did][wid][level]++;
		dl[did][level]++;
	}
	free(probs);

	//Update min_depth array: the deepest level that still holds words
	for(i=0; i < D; i++){
		for(j=L-1; j >= 0; j--){
			if( dl[i][j] != 0 ){
				min_depth[i] = j;
				break;
			}
		}
	}

	return;
}

/* Metropolis-Hastings move over the stick-breaking order of each node's children.
 *
 * For every in-use node with more than one feasible child, two positions
 * (cluster1, cluster2) in its feasible[] list are drawn with drand(). If they
 * differ, swapping the two children's positions is proposed. The swap is
 * accepted when N_1 == N_2, when prob2 >= 1, or with probability prob2.
 *
 * On acceptance, the following are updated:
 *  - the etot_k_agg counts of the two swapped children;
 *  - the etot_k_agg counts of every child strictly between them;
 *  - perm[] and feasible[] for the two children.
 * etot_k itself is not changed by a swap.
 *
 * NOTE(review): the locals j, diff, node, cluster_probs, totprob, prob,
 * prob_old, prob_new and print are declared but never used.
 * NOTE(review): prob2 is a plain running product. For large N_1/N_2 it can
 * overflow or underflow before the clamps below; a log-space product would be
 * more robust. TODO: confirm the acceptance-ratio derivation.
 */
void swap_cluster_labels(Graph *graph, double alpha, double beta){
	int i, j, k, diff, node, cluster1, cluster2, node_at_cluster1, node_at_cluster2, N_1, N_2, N_grt_1, N_grt_2, tmp, x, print=FALSE;
	double *cluster_probs, totprob, prob,prob_old,prob_new, prob2;
	
	for(i=0; i < graph->capacity; i++){
		// Skip unused slots and nodes with nothing to reorder
		if( graph->nodes[i].id == NOT_IN_USE){ continue; }
		if( graph->nodes[i].num_feasible  <= 1){ continue; }
		
		// Two positions in feasible[], truncated to int
		// (assumes drand() returns values in [0,1); TODO confirm)
		cluster1 = (graph->nodes[i].num_feasible*drand());
		cluster2 = (graph->nodes[i].num_feasible*drand());
		
		if( cluster1 == cluster2 ){
			//accept with probability 1
			continue;
		}
		
		// Gather the counts needed for the acceptance ratio of the swap
		node_at_cluster1 = graph->nodes[i].feasible[cluster1];
		node_at_cluster2 = graph->nodes[i].feasible[cluster2];
		N_1              = graph->nodes[i].etot_k[node_at_cluster1];
		N_2              = graph->nodes[i].etot_k[node_at_cluster2];


		//We fill in cluster1 counts first ALWAYS and then cluster2 counts
		if( cluster1 < cluster2 ){
			N_grt_1 = graph->nodes[i].etot_k_agg[node_at_cluster1] - N_2;
			N_grt_2 = graph->nodes[i].etot_k_agg[node_at_cluster2];
		}
		else{
			N_grt_1 = graph->nodes[i].etot_k_agg[node_at_cluster1];
			N_grt_2 = graph->nodes[i].etot_k_agg[node_at_cluster2] - N_1 + N_2;
		}


		//Compute acceptance probability (a ratio, accumulated as a plain product)
		prob2 = 1.0;
		for(k=0; k < N_1; k++){
			prob2 *= (alpha+beta+N_grt_1+k)/(alpha+beta+N_grt_2+k);
		}
		for(k=0; k < N_2; k++){
			prob2 *= (alpha+beta+N_grt_2+k)/(alpha+beta+N_grt_1+k);
		}


		// Clamp overflow to 1 and underflow to 0
		if( prob2 > DBL_MAX){
			prob2 = 1;
		}
		if( prob2 < DBL_MIN){
			prob2 = 0;
		}
		
		// Accept the swap
		if( N_1 == N_2 || prob2 >= 1 || drand() < prob2){

			// Children strictly between the two positions trade the moved-earlier
			// child's count for the moved-later child's count in their etot_k_agg;
			// the two swapped children exchange their etot_k_agg values (adjusted).
			if( cluster1 < cluster2 ){
				for( k = cluster1+1; k < cluster2; k++){
					x = graph->nodes[i].feasible[k];
					graph->nodes[i].etot_k_agg[x] -= N_2;
					graph->nodes[i].etot_k_agg[x] += N_1;
				}
				tmp = graph->nodes[i].etot_k_agg[node_at_cluster1];
				graph->nodes[i].etot_k_agg[node_at_cluster1] = graph->nodes[i].etot_k_agg[node_at_cluster2];
				graph->nodes[i].etot_k_agg[node_at_cluster2] = tmp - N_2 + N_1;			
			}
			else{
				for(k= cluster2+1; k < cluster1; k++){
					x = graph->nodes[i].feasible[k];
					graph->nodes[i].etot_k_agg[x] -= N_1;
					graph->nodes[i].etot_k_agg[x] += N_2;
				}
				tmp = graph->nodes[i].etot_k_agg[node_at_cluster2];
				graph->nodes[i].etot_k_agg[node_at_cluster2] = graph->nodes[i].etot_k_agg[node_at_cluster1];
				graph->nodes[i].etot_k_agg[node_at_cluster1] = tmp - N_1 + N_2;			
			}
		
			
			// PERM ARRAY: each child now records the other's position
			tmp = cluster1;
			graph->nodes[i].perm[node_at_cluster1] = cluster2;
			graph->nodes[i].perm[node_at_cluster2] = tmp;
			
			
			// FEASIBLE ARRAY: exchange the two children in stick-breaking order
			tmp = node_at_cluster1;
			graph->nodes[i].feasible[cluster1] = node_at_cluster2;
			graph->nodes[i].feasible[cluster2] = tmp;
						
		}
		// else we reject the swap
	}

	return;
}

/*-------------------------------------------------------------
 *  Metropolis-Hastings samplers for the hyperparameters
 *------------------------------------------------------------- */

/* One Metropolis-Hastings update of the topic smoothing parameter eta.
 *
 * A proposal is drawn with gsl_ran_exponential(rgen, eta_prior), independently
 * of the current eta. The log acceptance ratio is the sum over every in-use
 * node with ztot > 0 of the log ratio of its symmetric Dirichlet-multinomial
 * likelihood (proposal vs eta), plus eta_prior*(proposal-eta).
 *
 * The proposal is returned when that log ratio is positive, or with
 * probability exp(log ratio) otherwise. A ratio below DBL_MIN is treated as
 * a rejection. Otherwise eta is returned unchanged.
 *
 * NOTE(review): gsl_ran_exponential's second argument is the mean. For an
 * independence proposal with that density, the Hastings correction would be
 * (proposal-eta)/eta_prior, not eta_prior*(proposal-eta). Confirm the
 * intended prior/proposal pair.
 */
double sample_eta(double eta_prior, double eta, int W, Graph *graph, gsl_rng *rgen){
	double proposal = gsl_ran_exponential(rgen, eta_prior);
	double log_ratio = 0.0;
	double ratio;
	int n, wid;
	int used_nodes = 0;

	for(n = 0; n < graph->capacity; n++){
		if( graph->nodes[n].id == NOT_IN_USE || graph->nodes[n].ztot == 0 ){ continue; }

		for(wid = 0; wid < W; wid++){
			log_ratio += lgamma(proposal + graph->nodes[n].cp[wid]) - lgamma(eta + graph->nodes[n].cp[wid]);
		}
		log_ratio += lgamma(W*eta + graph->nodes[n].ztot) - lgamma(W*proposal + graph->nodes[n].ztot);
		used_nodes++;
	}

	// Dirichlet normalising constants, one set per counted node
	log_ratio += used_nodes*W*(lgamma(eta) - lgamma(proposal));
	log_ratio += used_nodes*(lgamma(W*proposal) - lgamma(W*eta));
	log_ratio += eta_prior*(proposal-eta);

	if( log_ratio > 0 ){ return proposal; }  // ratio > 1: always accept

	ratio = exp(log_ratio);
	if( ratio < DBL_MIN ){ return eta; }

	return ( drand() < ratio ) ? proposal : eta;
}

/* One Metropolis-Hastings update of the stick-breaking parameter beta.
 *
 * A proposal is drawn with gsl_ran_exponential(rgen, beta_prior). For every
 * in-use node, the feasible children are visited from the last stick position
 * to the first. For each recorded traversal of a child, the log ratio
 * (proposal vs beta) of two probabilities is accumulated:
 *  - passing over the `pos` earlier sticks;
 *  - stopping at the child's own stick.
 * The counts used are the traversal index so far and the child's etot_k_agg.
 * beta_prior*(proposal-beta) is then added.
 *
 * The proposal is returned when the log ratio is positive, or with
 * probability exp(log ratio) otherwise. A ratio below DBL_MIN is treated as a
 * rejection. Otherwise beta is returned unchanged.
 *
 * NOTE(review): as in sample_eta, gsl_ran_exponential takes the mean, so the
 * beta_prior*(proposal-beta) correction term should be double-checked.
 */
double sample_beta(double beta_prior, double beta, double alpha, Graph *graph, gsl_rng *rgen){
	double proposal = gsl_ran_exponential(rgen, beta_prior);
	double log_ratio = 0;
	double ratio, skip_term, pick_term;
	int n, slot, prior_visits, passed, child, pos;

	for(n = 0; n < graph->capacity; n++){
		if( graph->nodes[n].id == NOT_IN_USE ){ continue; }

		for(slot = graph->nodes[n].num_feasible - 1; slot >= 0; slot--){
			child  = graph->nodes[n].feasible[slot];
			pos    = graph->nodes[n].perm[child];
			passed = graph->nodes[n].etot_k_agg[child];

			for(prior_visits = 0; prior_visits < graph->nodes[n].etot_k[child]; prior_visits++){
				// Not choosing any of the `pos` earlier sticks
				skip_term = (pos*(log(proposal + prior_visits + passed) - log(alpha + proposal + prior_visits + passed))) - (pos*(log(beta + prior_visits + passed) - log(alpha + beta + prior_visits + passed)));
				// Choosing this child's stick; both sides share alpha + prior_visits in the numerator
				pick_term = (log(alpha + prior_visits) - log(alpha + proposal + prior_visits + passed)) - (log(alpha + prior_visits) - log(alpha + beta + prior_visits + passed));

				log_ratio += skip_term;
				log_ratio += pick_term;
			}
		}
	}
	log_ratio += beta_prior*(proposal-beta);

	if( log_ratio > 0 ){ return proposal; }  // ratio > 1: always accept

	ratio = exp(log_ratio);
	if( ratio < DBL_MIN ){ return beta; }

	return ( drand() < ratio ) ? proposal : beta;
}


