Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exponentially faster tree depth #38

Merged
merged 5 commits into from
Mar 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 19 additions & 14 deletions R/min_depth_distribution.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,8 @@ calculate_tree_depth <- function(frame){
stop("The data frame has to contain columns called 'right daughter' and 'left daughter'!
It should be a product of the function getTree(..., labelVar = T).")
}
# Both child values of leaf nodes are 0, i.e., lower than min(node_id)
frame[["depth"]] <- calculate_tree_depth_(
node_id = seq_len(nrow(frame)),
left_child = frame[["left daughter"]],
right_child = frame[["right daughter"]]
frame[, c("left daughter", "right daughter")]
)
return(frame)
}
Expand All @@ -19,22 +16,30 @@ calculate_tree_depth_ranger <- function(frame){
stop("The data frame has to contain columns called 'rightChild' and 'leftChild'!
It should be a product of the function ranger::treeInfo().")
}
# Child nodes are zero based, so we increase them by 1
frame[["depth"]] <- calculate_tree_depth_(
node_id = frame[["nodeID"]],
left_child = frame[["leftChild"]],
right_child = frame[["rightChild"]]
frame[, c("leftChild", "rightChild")] + 1
)
return(frame)
}

# Internal function used to determine the depth of each node
calculate_tree_depth_ <- function(node_id, left_child, right_child) {
n <- length(node_id)
depth <- numeric(n)
for (i in 2:n) {
parent_node <- left_child %in% node_id[i] | right_child %in% node_id[i]
depth[i] <- depth[parent_node] + 1
# Internal function used to determine the depth of each node.
#
# Walks the tree level by level starting at the root (row 1): all nodes of
# the current level are looked up at once, so the loop runs once per tree
# level instead of once per node.
#
# Args:
#   childs: data.frame or matrix with one row per node holding the row indices
#     of its left and right child. Valid child indices lie in 1:nrow(childs);
#     entries that are NA or < 1 mean "no child" (leaf node).
#
# Returns:
#   Numeric vector of length nrow(childs) with the depth of each node (the
#   root has depth 0). Nodes that cannot be reached from the root stay NA.
#   A zero-row input yields numeric(0).
calculate_tree_depth_ <- function(childs) {
  childs <- as.matrix(childs)
  n <- nrow(childs)

  # Without this guard, assigning the root depth below would silently grow
  # the result to length 1 for an empty tree.
  if (n == 0L) {
    return(numeric(0))
  }

  depth <- rep(NA_real_, times = n)
  depth[1L] <- 0
  level <- 0
  frontier <- 1L  # nodes of the current level, initialized with the root

  # `level < n` bounds the loop even if the child indices contain a cycle.
  while (anyNA(depth) && level < n) {
    frontier <- as.integer(childs[frontier, ])
    # Leaf nodes do not have children (encoded as NA or a value below 1)
    frontier <- frontier[!is.na(frontier) & frontier >= 1L]
    if (length(frontier) == 0L) {
      # Nothing left to visit: any remaining NA belongs to unreachable nodes
      break
    }
    level <- level + 1
    depth[frontier] <- level
  }

  return(depth)
}

Expand Down
Loading