File: xgb.plot.deepness.R

package info (click to toggle)
xgboost 1.2.1-1
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 8,472 kB
  • sloc: cpp: 32,873; python: 12,641; java: 2,926; xml: 1,024; sh: 662; ansic: 448; makefile: 306; javascript: 19
file content (150 lines) | stat: -rw-r--r-- 6,227 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#' Plot model trees deepness
#'
#' Visualizes distributions related to depth of tree leafs.
#' \code{xgb.plot.deepness} uses base R graphics, while \code{xgb.ggplot.deepness} uses the ggplot backend.
#'
#' @param model either an \code{xgb.Booster} model generated by the \code{xgb.train} function
#'        or a data.table result of the \code{xgb.model.dt.tree} function.
#' @param plot (base R barplot) whether a barplot should be produced.
#'        If FALSE, only a data.table is returned.
#' @param which which distribution to plot (see details).
#' @param ... other parameters passed to \code{barplot} or \code{plot}.
#'
#' @details
#'
#' When \code{which="2x1"}, two distributions with respect to the leaf depth
#' are plotted on top of each other:
#' \itemize{
#'  \item the distribution of the number of leafs in a tree model at a certain depth;
#'  \item the distribution of average weighted number of observations ("cover")
#'        ending up in leafs at certain depth.
#' }
#' Those could be helpful in determining sensible ranges of the \code{max_depth}
#' and \code{min_child_weight} parameters.
#'
#' When \code{which="max.depth"} or \code{which="med.depth"}, plots of either maximum or median depth
#' per tree with respect to tree number are created. With \code{which="med.weight"}, the plot shows how
#' a tree's median absolute leaf weight changes through the iterations.
#'
#' This function was inspired by the blog post
#' \url{https://github.com/aysent/random-forest-leaf-visualization}.
#'
#' @return
#'
#' Other than producing plots (when \code{plot=TRUE}), the \code{xgb.plot.deepness} function
#' silently returns a processed data.table where each row corresponds to a terminal leaf in a tree model,
#' and contains information about leaf's depth, cover, and weight (which is used in calculating predictions).
#'
#' The \code{xgb.ggplot.deepness} silently returns either a list of two ggplot graphs when \code{which="2x1"}
#' or a single ggplot graph for the other \code{which} options.
#'
#' @seealso
#'
#' \code{\link{xgb.train}}, \code{\link{xgb.model.dt.tree}}.
#'
#' @examples
#'
#' data(agaricus.train, package='xgboost')
#'
#' # Change max_depth to a higher number to get a more significant result
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 6,
#'                eta = 0.1, nthread = 2, nrounds = 50, objective = "binary:logistic",
#'                subsample = 0.5, min_child_weight = 2)
#'
#' xgb.plot.deepness(bst)
#' xgb.ggplot.deepness(bst)
#'
#' xgb.plot.deepness(bst, which='max.depth', pch=16, col=rgb(0,0,1,0.3), cex=2)
#'
#' xgb.plot.deepness(bst, which='med.weight', pch=16, col=rgb(0,0,1,0.3), cex=2)
#'
#' @rdname xgb.plot.deepness
#' @export
xgb.plot.deepness <- function(model = NULL, which = c("2x1", "max.depth", "med.depth", "med.weight"),
                              plot = TRUE, ...) {

  # Accept either a fitted booster or an already parsed tree table.
  if (!(inherits(model, "xgb.Booster") || inherits(model, "data.table")))
    stop("model: Has to be either an xgb.Booster model generated by the xgb.train function\n",
         "or a data.table result of the xgb.model.dt.tree function")

  # igraph is only needed here (path lengths in get.leaf.depth), so it is checked lazily.
  if (!requireNamespace("igraph", quietly = TRUE))
    stop("igraph package is required for plotting the graph deepness.", call. = FALSE)

  which <- match.arg(which)

  dt_tree <- model
  if (inherits(model, "xgb.Booster"))
    dt_tree <- xgb.model.dt.tree(model = model)

  # "Quality" is read below as the leaf weight, so it is validated together
  # with the structural columns rather than failing later inside data.table.
  required_cols <- c("Feature", "Tree", "ID", "Yes", "No", "Cover", "Quality")
  if (!all(required_cols %in% colnames(dt_tree)))
    stop("Model tree columns are not as expected!\n",
         "  Note that this function works only for tree models.")

  # One row per leaf: its depth plus the cover and weight taken from the tree table.
  dt_depths <- merge(get.leaf.depth(dt_tree), dt_tree[, .(ID, Cover, Weight = Quality)], by = "ID")
  setkeyv(dt_depths, c("Tree", "ID"))
  # count by depth levels, and also calculate average cover at a depth
  dt_summaries <- dt_depths[, .(.N, Cover = mean(Cover)), Depth]
  setkey(dt_summaries, "Depth")

  if (plot) {
    if (which == "2x1") {
      # Restore the caller's graphical parameters even if one of the plotting
      # calls fails (e.g. because of an invalid argument passed through `...`).
      op <- par(no.readonly = TRUE)
      on.exit(par(op), add = TRUE)
      par(mfrow = c(2, 1),
          oma = c(3, 1, 3, 1) + 0.1,
          mar = c(1, 4, 1, 0) + 0.1)

      # Top panel: number of leafs per depth (no bar labels are passed here).
      dt_summaries[, barplot(N, border = NA, ylab = "Number of leafs", ...)]

      # Bottom panel: mean cover per depth, with depth labels on the bars.
      dt_summaries[, barplot(Cover, border = NA, ylab = "Weighted cover", names.arg = Depth, ...)]

      title("Model complexity", xlab = "Leaf depth", outer = TRUE, line = 1)
    } else if (which == "max.depth") {
      # jitter() is applied to the per-tree depth values before plotting
      dt_depths[, max(Depth), Tree][
                , plot(jitter(V1, amount = 0.1) ~ Tree, ylab = "Max tree leaf depth", xlab = "tree #", ...)]
    } else if (which == "med.depth") {
      dt_depths[, median(as.numeric(Depth)), Tree][
                , plot(jitter(V1, amount = 0.1) ~ Tree, ylab = "Median tree leaf depth", xlab = "tree #", ...)]
    } else if (which == "med.weight") {
      dt_depths[, median(abs(Weight)), Tree][
                , plot(V1 ~ Tree, ylab = "Median absolute leaf weight", xlab = "tree #", ...)]
    }
  }
  # Returned invisibly so that calling the function for its plot prints nothing.
  invisible(dt_depths)
}

# Extract path depths from root to leaf
# from data.table containing the nodes and edges of the trees.
# internal utility function
#
# dt_tree: data.table of tree nodes; the columns read here are
#   Feature, ID, Yes, No and Tree. Rows with Feature == "Leaf" are leaves,
#   all other rows are split nodes whose children are given by Yes and No.
#
# Returns a data.table with columns Tree, Depth and ID, one row per edge that
# ends in a leaf. Depth is the number of nodes on the path from the root to
# that leaf (root and leaf included). Trees without any split node contribute
# no edges and therefore no rows.
get.leaf.depth <- function(dt_tree) {
  # extract tree graph's edges: every split node points to its Yes and No child
  dt_edges <- rbindlist(list(
      dt_tree[Feature != "Leaf", .(ID, To = Yes, Tree)],
      dt_tree[Feature != "Leaf", .(ID, To = No, Tree)]
    ))
  # whether "To" is a leaf:
  dt_edges <-
    merge(dt_edges,
          dt_tree[Feature == "Leaf", .(ID, Leaf = TRUE)],
          all.x = TRUE, by.x = "To", by.y = "ID")
  # edges whose target did not match a leaf row get NA from the left join
  dt_edges[is.na(Leaf), Leaf := FALSE]

  dt_edges[, {
    graph <- igraph::graph_from_data_frame(.SD[, .(ID, To)])
    # Leaf has no NA at this point, so it can be used directly as a mask
    leaf_ids <- To[Leaf]
    # min(ID) in a tree is a root node
    paths_tmp <- igraph::shortest_paths(graph, from = min(ID), to = leaf_ids)
    # number of vertices on the path to each leaf; vapply keeps the result an
    # integer vector regardless of how many paths there are
    depths <- vapply(paths_tmp$vpath, function(p) length(names(p)), integer(1))
    # combine into a resulting path lengths table for a tree
    data.table(Depth = depths, ID = leaf_ids)
  }, by = Tree]
}

# Register the column names that are used through data.table's non-standard
# evaluation in this file. They are never assigned as ordinary variables, so
# without this declaration the CRAN check reports them as undefined globals.
globalVariables(c(
  ".N", "N", "Depth",
  "Quality", "Cover", "Tree", "ID",
  "Yes", "No", "Feature",
  "Leaf", "Weight"
))