# tree functions



# Get facts about the input data matrix
# called by init_tree()
# NOTE! get_metadata() is intended to run once on all data. The only limitation for this
#      to expand into further use is the n_l=n argument to the two calls of get_cutpoint_matrix()

get_metadata <- function(data, dep_var, loss_function) {
	# Collect facts about the whole input data matrix, for use by init_tree().
	#
	# ------- arguments -------
	# data          -- matrix, one row per observation and one column per variable
	# dep_var       -- dependent variable; stored unchanged in the returned list
	# loss_function -- loss function, forwarded to get_cutpoint_matrix()
	# -------------------------

	stopifnot(is.matrix(data), NCOL(data) >= 1)

	p <- NCOL(data)
	n <- NROW(data)
# TO DO: provide options for bounding box, one of which is this:
	bottoms <- apply(data,2,min)
	tops <- apply(data,2,max)

	# The following three matrices are widely used in what follows.
	# This requires 2p sorts (of n data each).
	# Each node in the tree will store a vector of "data" rownums (and also a temporary
	#   "ordered" matrix for the subset of data if it is "available").

	# sort every column once; seq_len(p) (rather than 2:p) keeps this correct when p == 1
	column_sorts <- lapply(seq_len(p), function(j) {
		sort.int(data[,j],index.return=TRUE)
	})

	# "sorted_data" is independently sorted columns of "data", bolted together.
	#   The matrix is just a convenient storage format.
	sorted_data <- do.call(cbind, lapply(column_sorts, function(s) s$x))

	# Elements of the "ordered" matrix are "data" row numbers. The 1st element in column j is
	#   the "data" rownum with the lowest value in data[,j], and so on.
	#   The matrix is just a convenient storage format.
	ordered <- do.call(cbind, lapply(column_sorts, function(s) s$ix))

	# Elements of the "orders" matrix are "ordered" row numbers. orders[i,] contains the
	#   rownums of "ordered" which contain i.
	orders <- do.call(cbind, lapply(seq_len(p), function(j) {
		sort.int(ordered[,j],index.return=TRUE)$ix
	}))
	rm(column_sorts)

	# assemble list of cutpoint matrices, one per column
	# (n_l=n because the "node" here is the complete dataset)
	cutpoints <- lapply(seq_len(p), function(j) {
		get_cutpoint_matrix(sorted_data=sorted_data,
		                    n=n,
		                    n_l=n,
		                    colnum=j,
		                    loss_function=loss_function,
		                    dep_var=NA)
	})

	return(list(p=p,
		        n=n,
		        bottoms=bottoms,
		        tops=tops,
		        sorted_data=sorted_data,
		        ordered=ordered,
		        orders=orders,
		        dep_var=dep_var,
		        cutpoints=cutpoints))
	# this list gets used as 'metadata' argument in init_tree()
	# and maybe elsewhere
	# p and n are the column and row counts of data,
	# bottoms and tops are vectors (length p) of column minima and maxima,
	# sorted_data is an n x p matrix of same type as data
	# ordered and orders are n x p matrices of integers
	# dep_var is the dep_var argument, unchanged
	# cutpoints is a list with an element for each dimension,
	#   each element being a matrix from get_cutpoint_matrix(),
	#   with a row for every unique cutpoint and 8 columns
}




















# get matrix of cutpoint facts for one variable
# called by get_metadata()
get_cutpoint_matrix <- function(sorted_data,
	                             n,
	                             n_l,
                                colnum,
                                loss_function,
                                dep_var) {
	# Build the matrix of cutpoint facts for one column of sorted_data.
	#
	# ------- arguments -------
	# sorted_data   -- see get_metadata() and split_matrices(); must be a matrix
	# n             -- n of all data i.e. NROW(data) (not used in the calculation)
	# n_l           -- n in the node
	# colnum        -- natural number, which column of sorted_data is to be processed
	# loss_function -- loss function, e.g. loss_ise() (not used in the calculation yet)
	# dep_var       -- dependent variable (numeric vector); NA entries are dropped
	# -------------------------
	# returns a matrix with one row per cutpoint between distinct values and the
	#   8 named columns listed at the end of this function; when no cutpoint is
	#   possible the result is a single row of NA

	column_names <- c("sl_cut_ordered_index",
	                  "sl_cut_left_count",
	                  "sl_cut_right_count",
	                  "sl_cut_left_sum",
	                  "sl_cut_right_sum",
	                  "sl_cut_left_sumsq",
	                  "sl_cut_right_sumsq",
	                  "sl_cut_values")

	# values of this column which have not been blanked (i.e. are still in the
	#   subspace of the current node), in sorted order
	in_node <- !is.na(sorted_data[,colnum])
	node_values <- sorted_data[in_node,colnum]
	n_trimmed <- length(node_values)

	# reduced version of dep_var for cutpoint calculation
	dep_var_trimmed <- dep_var[!is.na(dep_var)]

	# dups[i] is FALSE where node_values[i+1] differs from node_values[i],
	#   i.e. where there is a potential cutpoint just to the right of row i.
	# With fewer than two values there is nothing to compare (2:n_trimmed would
	#   count downwards), so leave dups empty.
	if(n_trimmed >= 2) {
		dups <- node_values[2:n_trimmed] == node_values[1:(n_trimmed-1)]
	} else {
		dups <- logical(0)
	}

	# capture situation where there are no possible cutpoints because
	#   all data in that dimension have identical values (or there are < 2 of them)
	if(all(dups)) {
		return(matrix(NA_integer_, nrow=1, ncol=length(column_names),
		              dimnames=list(NULL, column_names)))
	}

	# vector of the FALSE row numbers -- immediately to the left of a cut
	sl_cut_ordered_index <- which(!dups)
	n_cuts <- length(sl_cut_ordered_index)

	### TO DO: select these based on loss_function
	# cumulative 1st and 2nd moments of dep_var, read off at each cut
	sl_cut_cumsum_1 <- cumsum(dep_var_trimmed^1)[sl_cut_ordered_index]
	sl_cut_cumsum_2 <- cumsum(dep_var_trimmed^2)[sl_cut_ordered_index]

	# number of observations immediately to the left
	#   the kth element is (count to left of cut k) - (count to left of cut k-1)
	sl_cut_left_count <- sl_cut_ordered_index -
	                     c(0, sl_cut_ordered_index[-n_cuts])

	# number of observations immediately to the right
	sl_cut_right_count <- c(sl_cut_left_count[-1], n_l-sum(sl_cut_left_count))

	# vector of cutpoint values (midway between data)
	sl_cut_values <- (node_values[sl_cut_ordered_index] +
	                  node_values[c(sl_cut_ordered_index[-1], n_l)]) / 2

	### TO DO: select these based on loss_function
	# sum and sum of squares immediately to the left: difference of successive
	#   cumulative values (each moment differenced against itself)
	sl_cut_left_sum <- sl_cut_cumsum_1 - c(0, sl_cut_cumsum_1[-n_cuts])
	sl_cut_left_sumsq <- sl_cut_cumsum_2 - c(0, sl_cut_cumsum_2[-n_cuts])
	# NOTE(review): the two "right" quantities subtract a sum from a count (n_l),
	#   unlike sl_cut_right_count above -- kept as originally written; confirm intent
	sl_cut_right_sum <- n_l - sl_cut_left_sum
	sl_cut_right_sumsq <- n_l - sl_cut_left_sumsq

	# combine in matrix (column names are relied on by best_cut_univar())
	sl_cut <- cbind(sl_cut_ordered_index=sl_cut_ordered_index,
		            sl_cut_left_count=sl_cut_left_count,
		            sl_cut_right_count=sl_cut_right_count,
		            sl_cut_left_sum=sl_cut_left_sum,
		            sl_cut_right_sum=sl_cut_right_sum,
		            sl_cut_left_sumsq=sl_cut_left_sumsq,
		            sl_cut_right_sumsq=sl_cut_right_sumsq,
		            sl_cut_values=sl_cut_values)
	# this is a matrix relating to one variable/dimension, with:
	# - a row for each cutpoint between unique values
	# - a column for:
	#    - row number (within the node's sorted values) immediately left of the cut
	#    - number of observations immediately to the left
	#    - number of observations immediately to the right
	#    - sum of dep_var immediately to the left
	#    - "right" sum (see review note above)
	#    - sum of squares of dep_var immediately to the left
	#    - "right" sum of squares (see review note above)
	#    - cutpoint values (midway between data)
	return(sl_cut)
}



















# first of a collection of loss_* functions for one node
# to be supplied as argument to init_tree and growing functions
# call predict_* function
# returns predicted, observed and loss
# Generate the ISE loss statistic, to be used at the first potential split in any dimension
# Requires n_all, the n for the complete dataset (other target stats may require other moments)
#   which is in tree_list$metadata$n
# All loss functions should return a list with elements called 'predicted' and 'loss'
# Those node-specific losses must *add* to the total tree loss (hence we don't work
#   with log-negative-loss)
#
# The arguments will differ between loss functions, so we send a list, allowing
#    generic functional programming in init_tree and best_cut_univar
loss_ise <- function(loss_input) {
	# ISE loss and predicted density for a single node.
	#
	# Arguments:
	#-------------
	# - loss_input: a list (so that every loss function can share one calling
	#                 convention), from which loss_ise() reads:
	#    - n_all: n for whole tree / all data
	#    - n_l: n in this node
	#    - node_bottoms: bottoms vector for this node
	#    - node_tops: tops vector for this node
	#
	# Returns a list with elements 'loss' and 'predicted'.

	### TO DO: issue warning if loss gets too small or large and threatens digital rounding error.
	###        Alternatively, rescale volume within this function and deal with proportion instead of n_l

	# volume of the node's box: product of its side lengths
	side_lengths <- loss_input$node_tops - loss_input$node_bottoms
	volume <- prod(side_lengths)
	count <- loss_input$n_l

	# note: in theory, if we calculated loss for both sides of the prospective split together,
	#       we needn't calculate both volumes
	list(loss=-(count^2)/volume,
	     predicted=count / (loss_input$n_all * volume))
}

# placeholders for further loss functions: no arguments, empty body (return NULL)
loss_xentropy <- function() NULL
loss_sse <- function() NULL
# From Ram & Gray:
# "The ISE gives a notion of overall distance between the estimated and the true density 
# and is a favored choice in nonparametric density estimation for its inherent robustness 
# in comparison to maximum-likelihood-based loss functions [8]. 
# However, other distance functions such as the KL-divergence can be used as the loss function."





















# loss adjusting functions
# placeholders, one per loss function: no arguments, empty body (return NULL)
loss_adjust_ise <- function() NULL
loss_adjust_xentropy <- function() NULL
loss_adjust_sse <- function() NULL
# see endnote 7, Efron & Hastie CASI, p.130, equating the parent loss to the children's losses and another term




















# Initialise tree list with root (all data in the bounding box
# is represented by one predicted value)
init_tree <- function(data, loss_function, dep_var) {
	# Build the starting tree: a single root node holding all of the data.
	#
	# ------- arguments -------
	# data          -- data matrix, handed to get_metadata()
	# loss_function -- loss function, e.g. loss_ise()
	# dep_var       -- dependent variable; replaced by a vector of 1s when
	#                    loss_function is loss_ise
	# -------------------------
	# returns a list of two items: 'nodes' (a list holding just the root node)
	#   and 'metadata' (the output of get_metadata() for the whole tree)

	# loss_ise works on counts, so whatever dep_var was provided is ignored
	uses_ise <- identical(loss_function,loss_ise)
	if(uses_ise) {
		dep_var <- rep(1,NROW(data))
	}

	metadata <- get_metadata(data, dep_var, loss_function)

	# the node supplied to the loss function here is the entire dataset,
	#   so n_l equals n_all and the node box is the bounding box
	root_loss <- loss_function(list(n_all=metadata$n,
	                                n_l=metadata$n,
	                                node_bottoms=metadata$bottoms,
	                                node_tops=metadata$tops))

	# the root has no parent and (so far) no children
	root_node <- list(parent=NA_integer_,
	                  children=rep(NA_integer_,2),
	                  children_cutpoint=NA_integer_,
	                  children_dimension=NA_integer_,
	                  n_l=metadata$n,
	                  p=metadata$p,
	                  sorted_data=metadata$sorted_data,
	                  ordered=metadata$ordered,
	                  orders=metadata$orders,
	                  bottoms=metadata$bottoms,
	                  tops=metadata$tops,
	                  predicted=root_loss$predicted,
	                  loss=root_loss$loss,
	                  dep_var=dep_var,
	                  available=TRUE,
	                  leaf=TRUE,
	                  cutpoints=metadata$cutpoints)

	# in this root, the nodes$cutpoints == metadata$cutpoints, but later, as we
	#   subdivide the nodes, the nodes$cutpoints will change and be used in best_cut()
	list(nodes=list(root_node),
	     metadata=metadata)
}

















get_availability <- function(current_tree, grow_parameters) {
	# One logical per node of current_tree$nodes: TRUE when the node is flagged
	#   available and holds more than 2*min_n_l observations.
	node_is_open <- function(node) {
		big_enough <- node$n_l > (2*grow_parameters$min_n_l)
		(node$available) & big_enough
	}
	flags <- lapply(current_tree$nodes, node_is_open)
	unlist(flags)
}
# TO DO: allow for other grow_parameters





















# main tree-growing function
# rather than exhaustive recursion down one branch before moving to the next, 
#   this considers a split in each node, and then passes back through them all 
#   (if still available) repeatedly until none are available
grow_tree <- function(current_tree, loss_function, grow_parameters) {
	# Grow the tree by repeatedly passing over every available node and splitting
	#   it at the cut returned by best_cut(), until get_availability() reports no
	#   available node or the node count reaches grow_parameters$max_nodes
	#   (the count is checked once per pass, not after every split).
	#
	# ------- arguments -------
	# current_tree    -- tree list (items 'nodes' and 'metadata'), e.g. from init_tree()
	# loss_function   -- loss function, passed on to best_cut() and get_cutpoint_matrix()
	# grow_parameters -- list; min_n_l and max_nodes are read here
	# -------------------------
	# returns the updated tree list

	n_all <- current_tree$metadata$n
	# a node with fewer observations than this is treated as unsplittable
	min_splittable_n <- 2*grow_parameters$min_n_l

	# replace the big per-node objects with NA once they are no longer needed
	blank_matrices <- function(node) {
		node$sorted_data <- NA_integer_
		node$ordered <- NA_integer_
		node$orders <- NA_integer_
		node$cutpoints <- NA_integer_
		node
	}

	# list of cutpoint matrices for one node, one element per dimension
	cutpoint_list <- function(node_sorted_data, node_n_l, p) {
		lapply(seq_len(p), function(i) {
			get_cutpoint_matrix(sorted_data=node_sorted_data,
			                    n=n_all,
			                    n_l=node_n_l,
			                    colnum=i,
			                    loss_function=loss_function,
			                    dep_var=NA)
		})
	}

	# assemble one child node from its share of the parent's matrices
	build_child <- function(sorted_data, ordered, orders, dep_var,
	                        parent_node_id, p, bottoms, tops, child_loss) {
		child <- list()
		child$sorted_data <- sorted_data
		child$ordered <- ordered
		child$orders <- orders
		child$dep_var <- dep_var
		child$parent <- parent_node_id
		child$children <- rep(NA_integer_, 2)
		child$children_cutpoint <- NA_real_
		child$children_dimension <- NA_integer_
		# rows not in the node are blanked with NA, so count what is left
		child$n_l <- sum(!is.na(sorted_data[,1]))
		child$p <- p
		child$bottoms <- bottoms
		child$tops <- tops
		child$predicted <- child_loss$predicted
		child$loss <- child_loss$loss
		child$available <- TRUE
		child$leaf <- TRUE
# TO DO: skip any leaf-dimension combinations where the parent node's cutpoints matrix is a line of NAs (indicating no possible cutpoint already)
		if(child$n_l < min_splittable_n) {
			# the child is unsplittable: drop the big matrices and do not
			#   compute cutpoints that would be thrown away at once
			child <- blank_matrices(child)
		} else {
			child$cutpoints <- cutpoint_list(sorted_data, child$n_l, p)
		}
		child
	}

	avail_nodes_bool <- get_availability(current_tree, grow_parameters)

	while(any(avail_nodes_bool) && (length(current_tree$nodes) < grow_parameters$max_nodes)) {

		### TO DO: change to lapply()
		for(l in which(avail_nodes_bool)) {

			# get the best cut in this node:
			node_best_cut <- best_cut(current_tree, l, loss_function, grow_parameters)
			parent_node_id <- l
			parent <- current_tree$nodes[[parent_node_id]]

			# best_cut() returns -1 (not a list) when no cut is possible/satisfactory
			if(!is.list(node_best_cut)) {
				# the node has no children, so it remains a leaf; it is only
				#   withdrawn from further split attempts
				parent$available <- FALSE
				parent$leaf <- TRUE
# CONSIDER: deleting data if a node can't be split is questionable. it might interfere with pruning.
				current_tree$nodes[[parent_node_id]] <- blank_matrices(parent)
				next
			}

			left_child_node_id <- length(current_tree$nodes)+1
			right_child_node_id <- length(current_tree$nodes)+2

			# split parent matrices at colnum and c
			child_matrices <- split_matrices(sorted_data=parent$sorted_data,
			                                 ordered=parent$ordered,
			                                 orders=parent$orders,
			                                 dep_var=parent$dep_var,
			                                 colnum=node_best_cut$j,
			                                 c=node_best_cut$c,
			                                 n_l=parent$n_l,
			                                 n_all=n_all,
			                                 p=parent$p)

			# change one element of tops in the left child, and one element of bottoms in the right
			left_tops <- parent$tops
			left_tops[node_best_cut$j] <- node_best_cut$cut_value
			right_bottoms <- parent$bottoms
			right_bottoms[node_best_cut$j] <- node_best_cut$cut_value

			current_tree$nodes[[left_child_node_id]] <-
				build_child(sorted_data=child_matrices$left_matrices$sorted_data_left,
				            ordered=child_matrices$left_matrices$ordered_left,
				            orders=child_matrices$left_matrices$orders_left,
				            dep_var=child_matrices$left_matrices$dep_var_left,
				            parent_node_id=parent_node_id,
				            p=parent$p,
				            bottoms=parent$bottoms,
				            tops=left_tops,
				            child_loss=node_best_cut$left_loss)
			current_tree$nodes[[right_child_node_id]] <-
				build_child(sorted_data=child_matrices$right_matrices$sorted_data_right,
				            ordered=child_matrices$right_matrices$ordered_right,
				            orders=child_matrices$right_matrices$orders_right,
				            dep_var=child_matrices$right_matrices$dep_var_right,
				            parent_node_id=parent_node_id,
				            p=parent$p,
				            bottoms=right_bottoms,
				            tops=parent$tops,
				            child_loss=node_best_cut$right_loss)

			# amend parent node in tree list: it now has children, so it is
			#   neither available nor a leaf
			parent$available <- FALSE
			parent$leaf <- FALSE
# CONSIDER: deleting data once a node has been split is questionable. it might interfere with pruning.
			parent <- blank_matrices(parent)
			parent$children <- c(left_child_node_id, right_child_node_id)
			parent$children_cutpoint <- node_best_cut$cut_value
			parent$children_dimension <- node_best_cut$j
			current_tree$nodes[[parent_node_id]] <- parent
		}
		avail_nodes_bool <- get_availability(current_tree, grow_parameters)
	}
	return(current_tree)
}



















best_cut <- function(current_tree, l, loss_function, grow_parameters) {
	# Find the best cut for node l across all dimensions.
	#
	# ------- arguments -------
	# current_tree    -- tree list (items 'nodes' and 'metadata')
	# l               -- ID of the node to evaluate (index into current_tree$nodes)
	# loss_function   -- loss function, passed on to best_cut_univar()
	# grow_parameters -- parameters, passed on to best_cut_univar()
	# -------------------------
	# returns either (-1) or the list describing the winning cut

	node <- current_tree$nodes[[l]]

	### TO DO: mlapply here if parallel option is chosen:
	# best cut in each dimension separately
	dimensions <- as.list(1:current_tree$metadata$p)
	candidates <- lapply(dimensions,
	                     best_cut_univar,
	                     n_l=node$n_l,
	                     node_tops=node$tops,
	                     node_bottoms=node$bottoms,
	                     cutpoints=node$cutpoints,
	                     loss_function=loss_function,
	                     grow_parameters=grow_parameters,
	                     tree_metadata=current_tree$metadata,
	                     l=l)

	# a dimension is feasible when best_cut_univar() returned a list rather than -1
	is_feasible <- sapply(candidates,is.list)
	if(!any(is_feasible)) {
		return(-1)
	}

	# among feasible dimensions, take the first one attaining the lowest total loss
	feasible <- candidates[is_feasible]
	totals <- sapply(feasible, function(candidate){candidate$total_loss})
	winner <- feasible[[which(totals==min(totals))[1]]]

	# a cut that does not beat the parent's own loss is rejected
	if(winner$total_loss>=node$loss) {
		return(-1)
	}
	winner
}


















	### TO DO: what happens to cutpoints matrix if no cutpoint is possible in dimension j?
	###        This should lead to best_cut_univar skipping straight to return -1

best_cut_univar <- function(j,
	                         n_l,
	                         node_tops, 
	                         node_bottoms,  
	                         cutpoints, 
	                         loss_function, 
	                         grow_parameters,
	                         tree_metadata,
	                         l) {
	# Find the best cut of one node in one dimension.
	#
	# arguments:
	#------------
	# j: number of variable (data column & element of tops/bottoms) under consideration 
	# n_l: count of observations in parent
	# node_tops: vector of tops for parent
	# node_bottoms: vector of bottoms for parent
	# cutpoints: list of cutpoints matrices (soon trimmed to just the variable under consideration)
	# loss_function: loss function passed down from tree()
	### TO DO:	# adjust_loss_function: adjust loss function passed down from tree()
	# grow_parameters: parameters passed down from tree(); min_n_l is read here
	# tree_metadata: metadata list on whole tree; n is read here
	# l : node ID, useful for debugging (not used in the calculation)
	#
	# returns (-1) when no acceptable cut exists, otherwise the list from
	#   evaluate_potential_cut() with the lowest total_loss (the earliest cutpoint
	#   wins ties)

	### TO DO: - use adjust_loss function 

	min_n_l <- grow_parameters$min_n_l

	# if no cut was possible (not enough n_l or no cutpoints available), return -1
	# (|| short-circuits, so cutpoints[[j]] is not touched for a node that is
	#   already too small -- its cutpoints may have been blanked to a single NA)
	if((n_l < 2*min_n_l) || all(is.na(cutpoints[[j]]))) {
		return(-1)
	}

	# use only the cutpoints matrix for the variable under consideration
	cutpoints <- cutpoints[[j]]

	# vectors of counts anywhere on left/right of cutpoint
	cumsum_left_count <- cumsum(cutpoints[,"sl_cut_left_count"])
	cumsum_right_count <- n_l-cumsum_left_count

	### TO DO: include other grow_parameters
	# Find start/end cutpoints (row numbers in cutpoints matrix): the first cut
	#   leaving at least min_n_l on the left and the last one leaving at least
	#   min_n_l on the right
	start_cutpoint <- which(cumsum_left_count >= min_n_l)[1]
	end_cutpoint <- rev(which(cumsum_right_count >= min_n_l))[1]
### TO DO: this could be more efficient, not search all cutpoints

	# another check for hopeless dimensions where there are no acceptable cuts:
	if(is.na(start_cutpoint) || is.na(end_cutpoint) || (start_cutpoint > end_cutpoint)) {
		return(-1)
	}

	# loss of cutting this node at row 'cutnum' of the cutpoints matrix
	evaluate <- function(cutnum) {
		evaluate_potential_cut(j=j,
		                       cutnum=cutnum,
		                       parent_tops=node_tops,
		                       parent_bottoms=node_bottoms,
		                       n_all=tree_metadata$n,
		                       cumsum_left_count=cumsum_left_count,
		                       cumsum_right_count=cumsum_right_count,
		                       cutpoints=cutpoints,
		                       loss_function=loss_function)
	}

	# evaluate start cutpoint, then keep any later cutpoint that is strictly better
	best_cut_univar_loss <- evaluate(start_cutpoint)

	# skip this if there is only one cutpoint:
	if(end_cutpoint > start_cutpoint) {
		for(i in (start_cutpoint+1):(end_cutpoint)) {
			next_cut_univar_loss <- evaluate(i)
			if(next_cut_univar_loss$total_loss < best_cut_univar_loss$total_loss) {
				best_cut_univar_loss <- next_cut_univar_loss
			}
		}
	}

	# return best univariate cut
	return(best_cut_univar_loss)
}














evaluate_potential_cut <- function(j,
	                                cutnum,
	                                parent_tops,
	                                parent_bottoms,
	                                n_all,
	                                cumsum_left_count,
	                                cumsum_right_count,
	                                cutpoints,
	                                loss_function) {
	# Loss of splitting a node in dimension j at row 'cutnum' of its cutpoints matrix.
	#
	# arguments:
	#------------
	# j: dimension being cut (element of tops/bottoms to replace)
	# cutnum: row of 'cutpoints' to evaluate
	# parent_tops, parent_bottoms: box of the node being split
	# n_all: n for all data, handed to loss_function
	# cumsum_left_count, cumsum_right_count: counts on either side, indexed by cutnum
	# cutpoints: cutpoints matrix for dimension j; column 8 holds the cut value
	# loss_function: called with a list of n_all, n_l, node_tops, node_bottoms
	#
	# returns a list: j, cutnum, c (count on the left), cut_value, left_loss,
	#   right_loss and total_loss (sum of the two 'loss' elements)

	cut_at <- cutpoints[cutnum,8]

	# left child keeps the parent's bottoms; its top in dimension j becomes the cut
	left_tops <- parent_tops
	left_tops[j] <- cut_at
	left_loss <- loss_function(list(n_all=n_all,
	                                n_l=cumsum_left_count[cutnum],
	                                node_tops=left_tops,
	                                node_bottoms=parent_bottoms))

	# right child keeps the parent's tops; its bottom in dimension j becomes the cut
	right_bottoms <- parent_bottoms
	right_bottoms[j] <- cut_at
	right_loss <- loss_function(list(n_all=n_all,
	                                 n_l=cumsum_right_count[cutnum],
	                                 node_tops=parent_tops,
	                                 node_bottoms=right_bottoms))

	list(j=j,
	     cutnum=cutnum,
	     c=cumsum_left_count[cutnum],
	     cut_value=as.numeric(cut_at),
	     left_loss=left_loss,
	     right_loss=right_loss,
	     total_loss=left_loss$loss+right_loss$loss)
}














# this function takes the parent node (or root) matrices (sorted_data,
#   ordered and orders) and generates them for the children, without sorting
split_matrices <- function(sorted_data,
	                        ordered,
	                        orders,
	                        dep_var,
	                        colnum,
	                        c,
	                        n_l,
	                        n_all,
	                        p) {
   # inputs:
   # -------
   # sorted_data: matrix from the parent node, see get_metadata()
   # ordered: matrix from the parent node, data rownums, see get_metadata()
   # orders: matrix from the parent node, see get_metadata()
   # dep_var: the dependent variable from the parent node
   # colnum: the column where the cut is made
   # c: rownum in the parent node's ordered[,colnum] where the cut is made,
   #          specifically, the left child is 1:c and the right is (c+1:n_l)
   # n_l: number of observations in the parent node
   # p: number of variables/columns


   # more information on matrices, from get_metadata():
   # --------------------------------------------------
   # "sorted_data" is independently sorted columns of "data", bolted together. 
	#   The matrix is just a convenient storage format.
	
	# Elements of the "ordered" matrix are "data" row numbers. The 1st element in column j is
	#   the "data" rownum with the lowest value in data[,j], and so on.
	#   The matrix is just a convenient storage format.
	
	# Elements of the "orders" matrix are "ordered" row numbers. orders[i,] contains the
	#   rownums of "ordered" which contain i.


	# more information on how they are used:
	# --------------------------------------	
	# Elements of the "ordered" matrix are "data" row numbers. The 1st element in column j is
	#   the "data" rownum with the lowest value in data[,j], and so on. When we split a node
	#   into two child nodes, the split will be made on a particular column j. 
	#   All the data rownums from ordered[1,j] to ordered[c,j] will go into the left child
	#   node, and all those from ordered[c+1,j] to ordered[n,j] will go into the right.
	#   (n_l is the number of observations in node (or leaf) l but n is the number of rows;
	#     this is because observations not in the node are blanked with NA)
	# The next task is to separate the entries in the other columns of ordered --- because they
	#   do not have a shared meaningful order. (The columns are to be used independently, and 
	#   the matrix is merely a convenient storage format.) For this, we use "orders". Suppose that
	#   ordered[1,j]=42, so we must eliminate any cell containing 42 from the non-j columns
	#   of "ordered", to make a new "ordered" for the right child node. We can find it easily
	#   in column k!=j : it is ordered[orders[ordered[1,j],k],k]
	# Elements of the "orders" matrix are ranks of "data" values, within each column.
	#   We might also say that they are rownumbers in "ordered". 
	# "sorted_data" is independently sorted columns, bolted together. Here, also, the
	#   matrix is just a convenient storage format.


   # find data row numbers in the two child nodes
   data_rownums_left <- ordered[!is.na(ordered[,colnum]),colnum][1:c]
   data_rownums_right <- ordered[!is.na(ordered[,colnum]),colnum][(c+1):n_l]

   # use these to blank "orders"
   #orders_left <- orders[data_rownums_left,]
   #orders_right <- orders[data_rownums_right,]
   orders_left <- orders_right <- orders
   orders_left[data_rownums_right,] <- NA_integer_
   orders_right[data_rownums_left,] <- NA_integer_

   # likewise dep_var
   dep_var_left <- dep_var_right <- dep_var
   dep_var_left[data_rownums_right] <- NA
   dep_var_right[data_rownums_left] <- NA # don't assume type of dep_var


   # find left row numbers in "orders" and put them in a 2-column matrix for subsetting
   orders_subset_rownums_left <- as.vector(t(orders[data_rownums_left,]))
   orders_subset_matrix_left <- matrix(c(orders_subset_rownums_left,
   	                                    rep((1:p),length(data_rownums_left))),
                                       ncol=2)
   # orders_subset_rownums_right <- as.vector(t(orders[data_rownums_right,]))
   # orders_subset_matrix_right <- matrix(c(orders_subset_rownums_right,
   # 	                                    rep((1:p),length(data_rownums_right))),
   #                                    ncol=2)

   # use this two-column matrix to create a Boolean matrix that will select
   #   (in the right row order within each column) from "ordered" and "sorted_data"
   ordered_subset_boolean_left <- matrix(FALSE, nrow=n_all, ncol=p)
   ordered_subset_boolean_left[orders_subset_matrix_left] <- TRUE

	# right subset is exclusive and exhaustive with left
   ordered_subset_boolean_right <- !(ordered_subset_boolean_left)

   # use these to blank "ordered" and "sorted_data"
   ordered_left <- ordered_right <- ordered
   ordered_left[ordered_subset_boolean_right] <- NA_integer_
   ordered_right[ordered_subset_boolean_left] <- NA_integer_

   sorted_data_left <- sorted_data_right <- sorted_data
   sorted_data_left[ordered_subset_boolean_right] <- NA_integer_
   sorted_data_right[ordered_subset_boolean_left] <- NA_integer_


   # output:
   # list of 2 lists of matrices, for left and right child nodes
   return(list(left_matrices=list(orders_left=orders_left,
   	                            ordered_left=ordered_left,
   	                            sorted_data_left=sorted_data_left,
   	                            dep_var_left=dep_var_left),
               right_matrices=list(orders_right=orders_right,
   	                            ordered_right=ordered_right,
   	                            sorted_data_right=sorted_data_right,
   	                            dep_var_right=dep_var_right)))
}



# get integer vector of the node IDs in current_tree which are leaves
#   current_tree: list with a "nodes" element (a list of nodes); each node is
#                 expected to carry a Boolean "leaf" flag
#   returns: positions in current_tree$nodes of the nodes whose flag is TRUE
#            (integer(0) if there are none)
# One Boolean is computed per node so that the returned IDs always line up with
#   positions in "nodes"; a node whose flag is missing or NA counts as a non-leaf.
#   (Flattening the flags with unlist() would drop NULL flags and shift later IDs.)
leaves <- function(current_tree) {
	is_leaf <- vapply(current_tree$nodes,
		              function(node){ isTRUE(node$leaf) },
		              logical(1))
	which(is_leaf)
}


# get list of leaves' tops matrix, bottoms matrix, density vector
#   current_tree: tree list whose "nodes" element holds the nodes; each leaf node
#                 supplies "tops", "bottoms" and "predicted"
#   returns: list(tops, bottoms, densities); row i of tops and bottoms, and the
#            i-th entry of densities, come from the i-th leaf in the order
#            given by leaves()
#   stops with an error if the tree has no leaves
store_tree <- function(current_tree) {
	leaves_ids <- leaves(current_tree)
	if(length(leaves_ids) == 0) {
		stop("current_tree contains no leaf nodes")
	}

	# drop any list names so they cannot leak into dimnames / names of the output
	leaf_nodes <- unname(current_tree$nodes[leaves_ids])

	# build one single-row matrix per leaf and bind them in a single call,
	#   instead of growing the matrices one rbind() at a time
	tops <- do.call(rbind,
		            lapply(leaf_nodes, function(node){ matrix(node$tops,nrow=1) }))
	bottoms <- do.call(rbind,
		               lapply(leaf_nodes, function(node){ matrix(node$bottoms,nrow=1) }))
	densities <- do.call(c,
		                 lapply(leaf_nodes, function(node){ node$predicted }))

	return(list(tops=tops,
		        bottoms=bottoms,
		        densities=densities))
}















#' Create a density estimation tree
#' 
#' @param data matrix with observations in rows and variables in columns (float or integer; defined on a metric space)
#' @param goal "density" is the only valid option at present; later, "classify" and "regress" will be added to make a general-purpose CART++, and it might accept a user-specified loss function; any other string stops with an error
#' @param grow_parameters a list of parameters controlling growth: min_n_l (natural number, minimum number of observations per node) and max_nodes (natural number, maximum number of nodes -- not necessarily leaves) are the only ones in use at present; set it to anything not a list (canonically, NA_integer_) to prevent growth (and pruning)
#' @param prune_parameters not in use yet; set it to anything not a list (canonically, NA_integer_) to prevent pruning
#' @param dep_var not in use yet, so set it to something small like NA_integer_; later, it will be a numeric vector of length NROW(data) containing a dependent variable
#' @param full_tree Boolean: whether to return a large list of metadata and all nodes; default is FALSE
#' 
#' @return if full_tree is FALSE, a list containing a matrix of leaf tops, a matrix of leaf bottoms, a vector of leaf densities; if full_tree is TRUE, the complete tree object. In both cases the attributes center_vector, scale_vector_1, rotate_matrix and scale_vector_2 of data (where present) are copied onto the returned object.
#' @export
#' 
#' @examples
#' irisdet <- tree(data=as.matrix(iris[,-5]),
#' 	             goal="density",
#' 	             grow_parameters=list(min_n_l=5, max_nodes=100),
#' 	             prune_parameters=NA_integer_,
#' 	             dep_var=NA_integer_)
tree <- function(data, goal, grow_parameters, prune_parameters, dep_var, full_tree=FALSE) {
	# grow_parameters: add minimum loss gain per split (not top priority)
	#                       - maximum depth (not priority)
	#                       - maximum number of leaves (not nodes) (not top priority)

	# data attributes from kudzu::normalise() -- passed through harmlessly
	passthrough_attrs <- c("center_vector",
		                   "scale_vector_1",
		                   "rotate_matrix",
		                   "scale_vector_2")

	# choose loss_function; only the matching branch is evaluated, and an
	#   unrecognised goal is reported here rather than surfacing later as a
	#   NULL loss function
	loss_function <- switch(goal,
		                    density = loss_ise,
		                    classify = loss_xentropy,
		                    regress = loss_ssq,
		                    stop("unknown goal '", goal,
		                         "': expected \"density\", \"classify\" or \"regress\"",
		                         call.=FALSE))

	current_tree <- init_tree(data, loss_function, dep_var)
	if(is.list(grow_parameters)) {
		current_tree <- grow_tree(current_tree, loss_function, grow_parameters)
		# pruning is only attempted on a tree that has been grown
		if(is.list(prune_parameters)) {
			current_tree <- prune_tree(current_tree, loss_function, prune_parameters)
		}
	}

	# return current_tree or stored_tree list
	result <- if(full_tree) { current_tree } else { store_tree(current_tree) }

	# add attributes back in; an attribute absent from data is NULL, and
	#   assigning NULL leaves the result without that attribute
	for(attr_name in passthrough_attrs) {
		attr(result, attr_name) <- attr(data, attr_name)
	}

	return(result)
}




