File: xgb.plot.importance.R

package info (click to toggle)
xgboost 1.2.1-1
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 8,472 kB
  • sloc: cpp: 32,873; python: 12,641; java: 2,926; xml: 1,024; sh: 662; ansic: 448; makefile: 306; javascript: 19
file content (125 lines) | stat: -rw-r--r-- 5,589 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#' Plot feature importance as a bar graph
#'
#' Represents previously calculated feature importance as a bar graph.
#' \code{xgb.plot.importance} uses base R graphics, while \code{xgb.ggplot.importance} uses the ggplot backend.
#'
#' @param importance_matrix a \code{data.table} returned by \code{\link{xgb.importance}}.
#'        The object passed in is not modified.
#' @param top_n maximal number of top features to include into the plot.
#' @param measure the name of importance measure to plot; it must name a numeric column.
#'        When \code{NULL}, 'Gain' is used for trees and 'Weight' is used for gblinear.
#' @param rel_to_first whether importance values should be represented as relative to the highest ranked feature.
#'        See Details.
#' @param left_margin (base R barplot) adjusts the left margin size to fit feature names.
#'        When it is NULL, the existing \code{par('mar')} is used.
#' @param cex (base R barplot) passed as \code{cex.names} parameter to \code{barplot}.
#'        When \code{NULL}, it shrinks as the number of plotted features grows.
#' @param plot (base R barplot) whether a barplot should be produced.
#'        If FALSE, only a data.table is returned.
#' @param n_clusters (ggplot only) a \code{numeric} vector containing the min and the max range
#'        of the possible number of clusters of bars.
#' @param ... other parameters passed to \code{barplot} (except horiz, border, cex.names, names.arg, and las).
#'        The bar appearance arguments \code{width}, \code{space}, \code{col}, \code{density},
#'        \code{angle}, \code{offset} and \code{xpd} are also applied when the bars are redrawn over the grid.
#'
#' @details
#' The graph represents each feature as a horizontal bar of length proportional to the importance of a feature.
#' Features are shown ranked in a decreasing importance order.
#' It works for importances from both \code{gblinear} and \code{gbtree} models.
#' If a feature occupies several rows of \code{importance_matrix}, those rows are first
#' collapsed into a single row by summing their numeric columns.
#'
#' When \code{rel_to_first = FALSE}, the values are plotted as they were in \code{importance_matrix}.
#' For gbtree model, that means being normalized to the total of 1
#' ("what is feature's importance contribution relative to the whole model?").
#' For linear models, \code{rel_to_first = FALSE} shows actual values of the coefficients.
#' Setting \code{rel_to_first = TRUE} shows the picture from the perspective of
#' "what is feature's importance contribution relative to the most important feature?"
#'
#' The ggplot-backend method also performs 1-D clustering of the importance values,
#' with bar colors corresponding to different clusters that have somewhat similar importance values.
#'
#' @return
#' The \code{xgb.plot.importance} function creates a \code{barplot} (when \code{plot=TRUE})
#' and silently returns a processed copy of \code{importance_matrix} with an added \code{Importance}
#' column, holding at most \code{top_n} features sorted by decreasing absolute importance.
#'
#' The \code{xgb.ggplot.importance} function returns a ggplot graph which could be customized afterwards.
#' E.g., to change the title of the graph, add \code{+ ggtitle("A GRAPH NAME")} to the result.
#'
#' @seealso
#' \code{\link[graphics]{barplot}}.
#'
#' @examples
#' data(agaricus.train)
#'
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 3,
#'                eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
#'
#' importance_matrix <- xgb.importance(colnames(agaricus.train$data), model = bst)
#'
#' xgb.plot.importance(importance_matrix, rel_to_first = TRUE, xlab = "Relative importance")
#'
#' (gg <- xgb.ggplot.importance(importance_matrix, measure = "Frequency", rel_to_first = TRUE))
#' gg + ggplot2::ylab("Frequency")
#'
#' @rdname xgb.plot.importance
#' @export
xgb.plot.importance <- function(importance_matrix = NULL, top_n = NULL, measure = NULL,
                                rel_to_first = FALSE, left_margin = 10, cex = NULL, plot = TRUE, ...) {
  check.deprecation(...)
  if (!is.data.table(importance_matrix))  {
    stop("importance_matrix: must be a data.table")
  }

  imp_names <- colnames(importance_matrix)
  if (is.null(measure)) {
    if (all(c("Feature", "Gain") %in% imp_names)) {
      measure <- "Gain"
    } else if (all(c("Feature", "Weight") %in% imp_names)) {
      measure <- "Weight"
    } else {
      stop("Importance matrix column names are not as expected!")
    }
  } else {
    if (!measure %in% imp_names)
      stop("Invalid `measure`")
    if (!"Feature" %in% imp_names)
      stop("Importance matrix column names are not as expected!")
  }
  # Summing below requires a numeric measure (this also rejects measure = "Feature").
  if (!is.numeric(importance_matrix[[measure]])) {
    stop("Invalid `measure`")
  }

  # Aggregate, in case the values were not yet summed up by feature.
  # Both branches produce a new data.table, so the `:=` below never adds a
  # column to the caller's object by reference.
  if (anyDuplicated(importance_matrix$Feature) > 0L) {
    # Collapse duplicated features into one row by summing every numeric column.
    is_num <- vapply(importance_matrix, is.numeric, logical(1))
    num_cols <- setdiff(imp_names[is_num], "Feature")
    importance_matrix <- importance_matrix[, lapply(.SD, sum), .SDcols = num_cols, by = Feature]
  } else {
    importance_matrix <- data.table::copy(importance_matrix)
  }
  importance_matrix[, Importance := get(measure)]

  # make sure it's ordered
  importance_matrix <- importance_matrix[order(-abs(Importance))]

  if (!is.null(top_n)) {
    top_n <- min(top_n, nrow(importance_matrix))
    importance_matrix <- head(importance_matrix, top_n)
  }
  n_features <- nrow(importance_matrix)

  # Guarded so an empty table does not hit max() of an empty vector.
  if (rel_to_first && n_features > 0L) {
    importance_matrix[, Importance := Importance / max(abs(Importance))]
  }
  if (is.null(cex)) {
    cex <- 2.5 / log2(1 + n_features)
  }

  if (plot) {
    if (n_features == 0L) {
      warning("importance_matrix: no features to plot")
    } else {
      op <- par(no.readonly = TRUE)
      # Restore graphics parameters even if barplot() fails.
      on.exit(par(op), add = TRUE)
      mar <- op$mar
      if (!is.null(left_margin))
        mar[2] <- left_margin
      par(mar = mar)

      # reverse the order of rows to have the highest ranked at the top
      bar_values <- rev(importance_matrix$Importance)
      bar_names <- rev(importance_matrix$Feature)
      barplot(bar_values, horiz = TRUE, border = NA, cex.names = cex,
              names.arg = bar_names, las = 1, ...)
      grid(NULL, NA)

      # Redraw the bars over the grid. Forward the caller's bar geometry/colour
      # arguments so the redraw does not misplace or recolour the bars.
      dots <- list(...)
      bar_style_args <- c("width", "space", "col", "density", "angle", "offset", "xpd")
      redraw_args <- dots[intersect(names(dots), bar_style_args)]
      do.call(barplot, c(list(bar_values, horiz = TRUE, border = NA, add = TRUE), redraw_args))
    }
  }

  invisible(importance_matrix)
}

# Column names referenced as bare symbols inside data.table expressions look
# like undefined global variables to R CMD check; registering them here keeps
# the check free of "no visible binding" notes for these names.
utils::globalVariables(c("Feature", "Importance"))