1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
|
#' Plot model trees deepness
#'
#' Visualizes distributions related to the depth of tree leaves.
#' \code{xgb.plot.deepness} uses base R graphics, while \code{xgb.ggplot.deepness} uses the ggplot backend.
#'
#' @param model either an \code{xgb.Booster} model generated by the \code{xgb.train} function
#'        or a data.table result of the \code{xgb.model.dt.tree} function.
#' @param plot (base R barplot) whether a barplot should be produced.
#'        If FALSE, only a data.table is returned.
#' @param which which distribution to plot (see details).
#' @param ... other parameters passed to \code{barplot} or \code{plot}.
#'        When \code{which="2x1"}, they are passed to both \code{barplot} calls.
#'
#' @details
#'
#' When \code{which="2x1"}, two distributions with respect to the leaf depth
#' are plotted on top of each other:
#' \itemize{
#'  \item the distribution of the number of leaves in a tree model at a certain depth;
#'  \item the distribution of average weighted number of observations ("cover")
#'        ending up in leaves at a certain depth.
#' }
#' Those could be helpful in determining sensible ranges of the \code{max_depth}
#' and \code{min_child_weight} parameters.
#'
#' When \code{which="max.depth"} or \code{which="med.depth"}, plots of either maximum or median depth
#' per tree with respect to tree number are created. And \code{which="med.weight"} shows how
#' a tree's median absolute leaf weight changes through the iterations.
#'
#' The \code{igraph} package must be installed: it is used to compute the leaf depths,
#' and an error is raised when it is not available (even when \code{plot=FALSE}).
#'
#' This function was inspired by the blog post
#' \url{https://github.com/aysent/random-forest-leaf-visualization}.
#'
#' @return
#'
#' Other than producing plots (when \code{plot=TRUE}), the \code{xgb.plot.deepness} function
#' silently returns a processed data.table where each row corresponds to a terminal leaf in a tree model,
#' and contains information about leaf's depth, cover, and weight (which is used in calculating predictions).
#'
#' The \code{xgb.ggplot.deepness} silently returns either a list of two ggplot graphs when \code{which="2x1"}
#' or a single ggplot graph for the other \code{which} options.
#'
#' @seealso
#'
#' \code{\link{xgb.train}}, \code{\link{xgb.model.dt.tree}}.
#'
#' @examples
#'
#' data(agaricus.train, package='xgboost')
#'
#' # Change max_depth to a higher number to get a more significant result
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 6,
#'                eta = 0.1, nthread = 2, nrounds = 50, objective = "binary:logistic",
#'                subsample = 0.5, min_child_weight = 2)
#'
#' xgb.plot.deepness(bst)
#' xgb.ggplot.deepness(bst)
#'
#' xgb.plot.deepness(bst, which='max.depth', pch=16, col=rgb(0,0,1,0.3), cex=2)
#'
#' xgb.plot.deepness(bst, which='med.weight', pch=16, col=rgb(0,0,1,0.3), cex=2)
#'
#' @rdname xgb.plot.deepness
#' @export
xgb.plot.deepness <- function(model = NULL,
                              which = c("2x1", "max.depth", "med.depth", "med.weight"),
                              plot = TRUE, ...) {

  if (!(inherits(model, "xgb.Booster") || is.data.table(model))) {
    stop("model: Has to be either an xgb.Booster model generated by the xgb.train function\n",
         "or a data.table result of the xgb.model.dt.tree function")
  }

  if (!requireNamespace("igraph", quietly = TRUE)) {
    stop("igraph package is required for plotting the graph deepness.", call. = FALSE)
  }

  which <- match.arg(which)

  # A data.table input is used as-is; a booster is first converted to its tree table.
  dt_tree <- model
  if (inherits(model, "xgb.Booster")) {
    dt_tree <- xgb.model.dt.tree(model = model)
  }

  # "Quality" is read below (as the leaf weight), so it must be validated here too;
  # otherwise a table lacking it fails later with an obscure "object not found" error.
  required_cols <- c("Feature", "Tree", "ID", "Yes", "No", "Cover", "Quality")
  if (!all(required_cols %in% colnames(dt_tree))) {
    stop("Model tree columns are not as expected!\n",
         "  Note that this function works only for tree models.")
  }

  # One row per leaf: its depth joined with its cover and weight.
  dt_depths <- merge(get.leaf.depth(dt_tree),
                     dt_tree[, .(ID, Cover, Weight = Quality)],
                     by = "ID")
  setkeyv(dt_depths, c("Tree", "ID"))

  # count by depth levels, and also calculate average cover at a depth
  dt_summaries <- dt_depths[, .(.N, Cover = mean(Cover)), Depth]
  setkey(dt_summaries, "Depth")

  if (plot) {
    if (which == "2x1") {
      op <- par(no.readonly = TRUE)
      # Restore the caller's graphics settings even if one of the plotting calls fails.
      on.exit(par(op), add = TRUE)
      par(mfrow = c(2, 1),
          oma = c(3, 1, 3, 1) + 0.1,
          mar = c(1, 4, 1, 0) + 0.1)

      # Only the bottom panel gets depth labels on its x axis (names.arg).
      dt_summaries[, barplot(N, border = NA, ylab = 'Number of leafs', ...)]
      dt_summaries[, barplot(Cover, border = NA, ylab = "Weighted cover", names.arg = Depth, ...)]

      title("Model complexity", xlab = "Leaf depth", outer = TRUE, line = 1)

    } else if (which == "max.depth") {
      # Jitter separates overlapping points, since depths are integer-valued.
      dt_depths[, max(Depth), Tree][
        , plot(jitter(V1, amount = 0.1) ~ Tree, ylab = 'Max tree leaf depth', xlab = "tree #", ...)]

    } else if (which == "med.depth") {
      dt_depths[, median(as.numeric(Depth)), Tree][
        , plot(jitter(V1, amount = 0.1) ~ Tree, ylab = 'Median tree leaf depth', xlab = "tree #", ...)]

    } else if (which == "med.weight") {
      dt_depths[, median(abs(Weight)), Tree][
        , plot(V1 ~ Tree, ylab = 'Median absolute leaf weight', xlab = "tree #", ...)]
    }
  }

  invisible(dt_depths)
}
# Extract path depths from root to leaf
# from a data.table containing the nodes and edges of the trees.
# Internal utility function.
#
# Args:
#   dt_tree: data.table of tree nodes; the columns Feature, ID, Yes, No and Tree
#            are used. Rows with Feature == "Leaf" are treated as leaves, all
#            other rows as split nodes whose children are given by Yes and No.
#
# Returns:
#   A data.table with columns Tree, Depth and ID, one row per leaf that is the
#   target of an edge. Depth is the number of vertices on the path that
#   igraph::shortest_paths reports from the tree's root to that leaf.
#   Trees without any split rows produce no edges and hence no output rows.
get.leaf.depth <- function(dt_tree) {
  # extract tree graph's edges: one edge per (split node, child) pair
  dt_edges <- rbindlist(list(
    dt_tree[Feature != "Leaf", .(ID, To = Yes, Tree)],
    dt_tree[Feature != "Leaf", .(ID, To = No, Tree)]
  ))

  # whether "To" is a leaf:
  dt_edges <-
    merge(dt_edges,
          dt_tree[Feature == "Leaf", .(ID, Leaf = TRUE)],
          all.x = TRUE, by.x = "To", by.y = "ID")
  # Edges whose target did not match a leaf row come out of the left join as NA.
  dt_edges[is.na(Leaf), Leaf := FALSE]

  dt_edges[, {
    graph <- igraph::graph_from_data_frame(.SD[, .(ID, To)])
    leaf_ids <- To[Leaf]
    # min(ID) in a tree is a root node
    paths_tmp <- igraph::shortest_paths(graph, from = min(ID), to = leaf_ids)
    # list of paths to each leaf in a tree
    paths <- lapply(paths_tmp$vpath, names)
    # combine into a resulting path lengths table for a tree;
    # vapply guarantees an integer column even when there are no paths
    data.table(Depth = vapply(paths, length, integer(1)), ID = leaf_ids)
  }, by = Tree]
}
# Register the symbols below as known globals so that the CRAN check does not
# report them as undefined: they are never declared as variables and are mainly
# column names resolved by data.table inside `[`.
globalVariables(c(
  ".N", "N", "Depth", "Quality", "Cover", "Tree",
  "ID", "Yes", "No", "Feature", "Leaf", "Weight"
))
|