File: xgb.plot.multi.trees.R

package info (click to toggle)
xgboost 1.2.1-1
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 8,472 kB
  • sloc: cpp: 32,873; python: 12,641; java: 2,926; xml: 1,024; sh: 662; ansic: 448; makefile: 306; javascript: 19
file content (148 lines) | stat: -rw-r--r-- 6,214 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#' Project all trees on one tree and plot it
#'
#' Visualization of the ensemble of trees as a single collective unit.
#'
#' @param model a tree-based model produced by the \code{xgb.train} function.
#' @param feature_names names of each feature as a \code{character} vector;
#'        passed on to \code{\link{xgb.model.dt.tree}}.
#' @param features_keep maximum number of features to display at each node position
#'        of the multi-tree. The features with the largest summed Gain at that
#'        position are kept. Must be a single number greater than or equal to 1.
#' @param plot_width width in pixels of the graph to produce
#'        (passed to \code{DiagrammeR::render_graph}).
#' @param plot_height height in pixels of the graph to produce
#'        (passed to \code{DiagrammeR::render_graph}).
#' @param render a logical flag for whether the graph should be rendered (see Value).
#' @param ... currently not used (only passed to \code{check.deprecation}).
#'
#' @details
#'
#' This function tries to capture the complexity of a gradient boosted tree model
#' in a cohesive way by compressing an ensemble of trees into a single tree-graph representation.
#' The goal is to improve the interpretability of a model generally seen as a black box.
#'
#' Note: this function is applicable to tree booster-based models only, and it
#' requires the \pkg{DiagrammeR} package.
#'
#' It takes advantage of the fact that the shape of a binary tree is only defined by
#' its depth (therefore, in a boosting model, all trees have similar shape).
#'
#' Moreover, the trees tend to reuse the same features.
#'
#' The function projects each tree onto one. Nodes of different trees are merged
#' when they are reached by the same sequence of Yes/No branches from the root.
#' For each position it keeps the \code{features_keep} features with the highest
#' Gain per feature measure (the \code{Quality} column of
#' \code{\link{xgb.model.dt.tree}}, summed over all trees).
#'
#' This function is inspired by this blog post:
#' \url{https://wellecks.wordpress.com/2015/02/21/peering-into-the-black-box-visualizing-lambdamart/}
#'
#' @return
#'
#' When \code{render = TRUE}:
#' returns a rendered graph object which is an \code{htmlwidget} of class \code{grViz}.
#' Similar to ggplot objects, it needs to be printed to see it when not running from command line.
#'
#' When \code{render = FALSE}:
#' silently returns a graph object which is of DiagrammeR's class \code{dgr_graph}.
#' This could be useful if one wants to modify some of the graph attributes
#' before rendering the graph with \code{\link[DiagrammeR]{render_graph}}.
#'
#' @examples
#'
#' data(agaricus.train, package='xgboost')
#'
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 15,
#'                eta = 1, nthread = 2, nrounds = 30, objective = "binary:logistic",
#'                min_child_weight = 50, verbose = 0)
#'
#' p <- xgb.plot.multi.trees(model = bst, features_keep = 3)
#' print(p)
#'
#' \dontrun{
#' # Below is an example of how to save this plot to a file.
#' # Note that for `export_graph` to work, the DiagrammeRsvg and rsvg packages must also be installed.
#' library(DiagrammeR)
#' gr <- xgb.plot.multi.trees(model=bst, features_keep = 3, render=FALSE)
#' export_graph(gr, 'tree.pdf', width=1500, height=600)
#' }
#'
#' @export
xgb.plot.multi.trees <- function(model, feature_names = NULL, features_keep = 5, plot_width = NULL, plot_height = NULL,
                                 render = TRUE, ...) {
  check.deprecation(...)
  if (!requireNamespace("DiagrammeR", quietly = TRUE)) {
    stop("DiagrammeR package is required for xgb.plot.multi.trees", call. = FALSE)
  }
  if (!is.numeric(features_keep) || length(features_keep) != 1 ||
      is.na(features_keep) || features_keep < 1) {
    stop("'features_keep' must be a single number >= 1", call. = FALSE)
  }
  # fractional values behave like the former `1:features_keep` indexing
  n.keep <- floor(features_keep)

  tree.matrix <- xgb.model.dt.tree(feature_names = feature_names, model = model)

  # An absolute node position is the node's tree prefix followed by its path from
  # the root: the root keeps its ID ("<tree>-0"), and every step down appends
  # "_0" for the Yes branch or "_1" for the No branch.
  # The regex is anchored so that only node 0 of each tree is picked as a root.
  root.nodes <- tree.matrix[stri_detect_regex(ID, "^\\d+-0$"), ID]
  tree.matrix[ID %in% root.nodes, abs.node.position := root.nodes]

  precedent.nodes <- root.nodes

  # Walk the trees level by level until every node has a position.
  while (anyNA(tree.matrix[["abs.node.position"]])) {
    parents <- tree.matrix[abs.node.position %in% precedent.nodes]
    yes.parents <- parents[!is.na(Yes)]
    no.parents <- parents[!is.na(No)]

    child.ids <- c(yes.parents[["Yes"]], no.parents[["No"]])
    child.pos <- c(paste0(yes.parents[["abs.node.position"]], "_0"),
                   paste0(no.parents[["abs.node.position"]], "_1"))

    # Without this guard a node unreachable from any root would loop forever.
    if (length(child.ids) == 0) {
      stop("Some tree nodes could not be reached from a root node", call. = FALSE)
    }

    # Pair each child ID with the position derived from its own parent. Matching
    # by ID keeps the pairing correct even when children's IDs are not ordered
    # like their parents' rows (a plain `ID %in% ...` mask would assign the
    # positions in table order instead).
    data.table::set(tree.matrix,
                    i = match(child.ids, tree.matrix[["ID"]]),
                    j = "abs.node.position",
                    value = child.pos)
    precedent.nodes <- child.pos
  }

  tree.matrix[!is.na(Yes), Yes := paste0(abs.node.position, "_0")]
  tree.matrix[!is.na(No), No := paste0(abs.node.position, "_1")]

  # Drop the "<tree>-" prefix so equal positions from different trees collapse together.
  remove.tree <- function(x) {
    stri_replace_first_regex(x, pattern = "^\\d+-", replacement = "")
  }

  tree.matrix[, `:=`(abs.node.position = remove.tree(abs.node.position),
                     Yes = remove.tree(Yes),
                     No = remove.tree(No))]

  # One label per position: the n.keep features with the largest summed Quality.
  # `order()` is stable, so ties keep their order of first appearance.
  nodes.dt <- tree.matrix[
        , .(Quality = sum(Quality))
        , by = .(abs.node.position, Feature)
      ][, {
          ord <- order(Quality, decreasing = TRUE)
          top <- ord[seq_len(min(length(ord), n.keep))]
          .(Text = paste0(Feature[top],
                          " (",
                          format(Quality[top], digits = 5),
                          ")",
                          collapse = "\n"))
        }
        , by = abs.node.position]

  # Distinct (parent position -> child position) pairs over both branches.
  split.rows <- tree.matrix[Feature != "Leaf"]
  edges.dt <- unique(rbindlist(list(
    split.rows[, .(From = abs.node.position, To = Yes)],
    split.rows[, .(From = abs.node.position, To = No)]
  )))

  nodes <- DiagrammeR::create_node_df(
    n = nrow(nodes.dt),
    label = nodes.dt[["Text"]]
  )

  edges <- DiagrammeR::create_edge_df(
    from = match(edges.dt[["From"]], nodes.dt[["abs.node.position"]]),
    to = match(edges.dt[["To"]], nodes.dt[["abs.node.position"]]),
    rel = "leading_to")

  graph <- DiagrammeR::create_graph(
      nodes_df = nodes,
      edges_df = edges,
      attr_theme = NULL
      ) %>%
    DiagrammeR::add_global_graph_attrs(
      attr_type = "graph",
      attr  = c("layout", "rankdir"),
      value = c("dot", "LR")
      ) %>%
    DiagrammeR::add_global_graph_attrs(
      attr_type = "node",
      attr  = c("color", "fillcolor", "style", "shape", "fontname"),
      value = c("DimGray", "beige", "filled", "rectangle", "Helvetica")
      ) %>%
    DiagrammeR::add_global_graph_attrs(
      attr_type = "edge",
      attr  = c("color", "arrowsize", "arrowhead", "fontname"),
      value = c("DimGray", "1.5", "vee", "Helvetica"))

  if (!render) return(invisible(graph))

  DiagrammeR::render_graph(graph, width = plot_width, height = plot_height)
}

# Names that R CMD check would otherwise report as undefined globals, mostly
# data.table columns referenced through non-standard evaluation in this file.
globalVariables(c(".N", "N", "From", "To", "Text", "Feature", "no.nodes.abs.pos",
                  "ID", "Yes", "No", "Tree", "yes.nodes.abs.pos", "abs.node.position",
                  "Quality"))