File: xgb.plot.shap.R

package info (click to toggle)
xgboost 1.2.1-1
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 8,472 kB
  • sloc: cpp: 32,873; python: 12,641; java: 2,926; xml: 1,024; sh: 662; ansic: 448; makefile: 306; javascript: 19
file content (218 lines) | stat: -rw-r--r-- 10,615 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
#' SHAP contribution dependency plots
#'
#' Visualizing the SHAP feature contribution to prediction dependencies on feature value.
#'
#' @param data data as a \code{matrix} or \code{dgCMatrix}.
#' @param shap_contrib a matrix of SHAP contributions that was computed earlier for the above
#'          \code{data}. When it is NULL, it is computed internally using \code{model} and \code{data}.
#' @param features a vector of either column indices or of feature names to plot. When it is NULL,
#'          feature importance is calculated, and \code{top_n} high ranked features are taken.
#' @param top_n when \code{features} is NULL, top_n [1, 100] most important features in a model are taken.
#' @param model an \code{xgb.Booster} model. It has to be provided when either \code{shap_contrib}
#'          or \code{features} is missing.
#' @param trees passed to \code{\link{xgb.importance}} when \code{features = NULL}.
#' @param target_class is only relevant for multiclass models. When it is set to a 0-based class index,
#'          only SHAP contributions for that specific class are used.
#'          If it is not set, SHAP importances are averaged over all classes.
#' @param approxcontrib passed to \code{\link{predict.xgb.Booster}} when \code{shap_contrib = NULL}.
#' @param subsample a random fraction of data points to use for plotting. When it is NULL,
#'          it is set so that up to 100K data points are used.
#' @param n_col a number of columns in a grid of plots.
#' @param col color of the scatterplot markers.
#' @param pch scatterplot marker.
#' @param discrete_n_uniq a maximal number of unique values in a feature to consider it as discrete.
#' @param discrete_jitter an \code{amount} parameter of jitter added to discrete features' positions.
#' @param ylab a y-axis label in 1D plots.
#' @param plot_NA whether the contributions of cases with missing values should also be plotted.
#' @param col_NA a color of marker for missing value contributions.
#' @param pch_NA a marker type for NA values.
#' @param pos_NA a relative position of the x-location where NA values are shown:
#'          \code{min(x) + (max(x) - min(x)) * pos_NA}.
#' @param plot_loess whether to plot loess-smoothed curves. The smoothing is only done for features with
#'          more than 5 distinct values.
#' @param col_loess a color to use for the loess curves.
#' @param span_loess the \code{span} parameter in \code{\link[stats]{loess}}'s call.
#' @param which whether to do univariate or bivariate plotting. NOTE: only 1D is implemented so far.
#' @param plot whether a plot should be drawn. If FALSE, only a list of matrices is returned.
#' @param ... other parameters passed to \code{plot}.
#'
#' @details
#'
#' These scatterplots represent how SHAP feature contributions depend on feature values.
#' The similarity to partial dependency plots is that they also give an idea for how feature values
#' affect predictions. However, in partial dependency plots, we usually see marginal dependencies
#' of model prediction on feature value, while SHAP contribution dependency plots display the estimated
#' contributions of a feature to model prediction for each individual case.
#'
#' When \code{plot_loess = TRUE} is set, feature values are rounded to 3 significant digits and
#' weighted LOESS is computed and plotted, where weights are the numbers of data points
#' at each rounded value.
#'
#' Note: SHAP contributions are shown on the scale of model margin. E.g., for a logistic binomial objective,
#' the margin is prediction before a sigmoidal transform into probability-like values.
#' Also, since SHAP stands for "SHapley Additive exPlanation" (model prediction = sum of SHAP
#' contributions for all features + bias), depending on the objective used, transforming SHAP
#' contributions for a feature from the marginal to the prediction space is not necessarily
#' a meaningful thing to do.
#'
#' @return
#'
#' In addition to producing plots (when \code{plot=TRUE}), it silently returns a list of two matrices:
#' \itemize{
#'  \item \code{data} the values of selected features;
#'  \item \code{shap_contrib} the contributions of selected features.
#' }
#'
#' @references
#'
#' Scott M. Lundberg, Su-In Lee, "A Unified Approach to Interpreting Model Predictions", NIPS Proceedings 2017, \url{https://arxiv.org/abs/1705.07874}
#'
#' Scott M. Lundberg, Su-In Lee, "Consistent feature attribution for tree ensembles", \url{https://arxiv.org/abs/1706.06060}
#'
#' @examples
#'
#' data(agaricus.train, package='xgboost')
#' data(agaricus.test, package='xgboost')
#'
#' bst <- xgboost(agaricus.train$data, agaricus.train$label, nrounds = 50,
#'                eta = 0.1, max_depth = 3, subsample = .5,
#'                method = "hist", objective = "binary:logistic", nthread = 2, verbose = 0)
#'
#' xgb.plot.shap(agaricus.test$data, model = bst, features = "odor=none")
#' contr <- predict(bst, agaricus.test$data, predcontrib = TRUE)
#' xgb.plot.shap(agaricus.test$data, contr, model = bst, top_n = 12, n_col = 3)
#'
#' # multiclass example - plots for each class separately:
#' nclass <- 3
#' nrounds <- 20
#' x <- as.matrix(iris[, -5])
#' set.seed(123)
#' is.na(x[sample(nrow(x) * 4, 30)]) <- TRUE # introduce some missing values
#' mbst <- xgboost(data = x, label = as.numeric(iris$Species) - 1, nrounds = nrounds,
#'                 max_depth = 2, eta = 0.3, subsample = .5, nthread = 2,
#'                 objective = "multi:softprob", num_class = nclass, verbose = 0)
#' trees0 <- seq(from=0, by=nclass, length.out=nrounds)
#' col <- rgb(0, 0, 1, 0.5)
#' xgb.plot.shap(x, model = mbst, trees = trees0, target_class = 0, top_n = 4,
#'               n_col = 2, col = col, pch = 16, pch_NA = 17)
#' xgb.plot.shap(x, model = mbst, trees = trees0 + 1, target_class = 1, top_n = 4,
#'               n_col = 2, col = col, pch = 16, pch_NA = 17)
#' xgb.plot.shap(x, model = mbst, trees = trees0 + 2, target_class = 2, top_n = 4,
#'               n_col = 2, col = col, pch = 16, pch_NA = 17)
#'
#' @rdname xgb.plot.shap
#' @export
xgb.plot.shap <- function(data, shap_contrib = NULL, features = NULL, top_n = 1, model = NULL,
                          trees = NULL, target_class = NULL, approxcontrib = FALSE,
                          subsample = NULL, n_col = 1, col = rgb(0, 0, 1, 0.2), pch = '.',
                          discrete_n_uniq = 5, discrete_jitter = 0.01, ylab = "SHAP",
                          plot_NA = TRUE, col_NA = rgb(0.7, 0, 1, 0.6), pch_NA = '.', pos_NA = 1.07,
                          plot_loess = TRUE, col_loess = 2, span_loess = 0.5,
                          which = c("1d", "2d"), plot = TRUE, ...) {
  # Draws one scatterplot per selected feature (feature value on x, SHAP
  # contribution on y) and invisibly returns list(data, shap_contrib) holding
  # the sub-sampled values of the selected features.

  # ---- argument validation -------------------------------------------------
  if (!is.matrix(data) && !inherits(data, "dgCMatrix"))
    stop("data: must be either matrix or dgCMatrix")

  # inherits() is FALSE for NULL, so this also covers a missing model
  has_model <- inherits(model, "xgb.Booster")

  if (is.null(shap_contrib) && !has_model)
    stop("when shap_contrib is not provided, one must provide an xgb.Booster model")

  if (is.null(features) && !has_model)
    stop("when features are not provided, one must provide an xgb.Booster model to rank the features")

  # a precomputed contribution matrix has one row per data row and one column
  # more than `data`
  if (!is.null(shap_contrib) &&
      (!is.matrix(shap_contrib) || nrow(shap_contrib) != nrow(data) || ncol(shap_contrib) != ncol(data) + 1))
    stop("shap_contrib is not compatible with the provided data")

  # ---- sub-sampling of the rows ---------------------------------------------
  nsample <- if (is.null(subsample)) min(100000, nrow(data)) else as.integer(subsample * nrow(data))
  idx <- sample.int(nrow(data), nsample)
  # drop = FALSE: keep a matrix even when a single row or column remains,
  # otherwise the column subsetting further below fails
  data <- data[idx, , drop = FALSE]

  if (is.null(shap_contrib)) {
    shap_contrib <- predict(model, data, predcontrib = TRUE, approxcontrib = approxcontrib)
  } else {
    shap_contrib <- shap_contrib[idx, , drop = FALSE]
  }

  # the local copy gets its own name so that base::which() is not shadowed
  plot_kind <- match.arg(which)
  if (plot_kind == "2d")
    stop("2D plots are not implemented yet")

  # ---- feature selection ----------------------------------------------------
  if (is.null(features)) {
    # validate top_n before the importance computation; the bounds are
    # joined with ||, since a value cannot be both < 1 and > 100
    top_n <- as.integer(top_n[1])
    if (is.na(top_n) || top_n < 1 || top_n > 100)
      stop("top_n: must be an integer within [1, 100]")
    imp <- xgb.importance(model = model, trees = trees)
    features <- imp$Feature[seq_len(min(top_n, NROW(imp)))]
  }

  if (is.character(features)) {
    if (is.null(colnames(data)))
      stop("Either provide `data` with column names or provide `features` as column indices")
    features <- match(features, colnames(data))
  }

  if (n_col > length(features)) n_col <- length(features)

  if (is.list(shap_contrib)) { # multiclass: either choose a class or merge
    shap_contrib <- if (!is.null(target_class)) {
      shap_contrib[[target_class + 1]]
    } else {
      Reduce("+", lapply(shap_contrib, abs))
    }
  }

  shap_contrib <- shap_contrib[, features, drop = FALSE]
  data <- data[, features, drop = FALSE]

  # column labels: prefer those of data, then of shap_contrib, else X1, X2, ...
  cols <- colnames(data)
  if (is.null(cols)) cols <- colnames(shap_contrib)
  if (is.null(cols)) cols <- paste0('X', seq_len(ncol(data)))
  colnames(data) <- cols
  colnames(shap_contrib) <- cols

  # ---- plotting (only 1D is implemented; "2d" has already stopped above) ----
  if (plot) {
    op <- par(mfrow = c(ceiling(length(features) / n_col), n_col),
              oma = c(0, 0, 0, 0) + 0.2,
              mar = c(3.5, 3.5, 0, 0) + 0.1,
              mgp = c(1.7, 0.6, 0))
    # restore the graphical parameters even if drawing fails midway
    on.exit(par(op), add = TRUE)

    for (f in cols) {
      # sort by feature value; order() places the NA values last
      ord <- order(data[, f])
      x <- data[, f][ord]
      y <- shap_contrib[, f][ord]
      x_lim <- range(x, na.rm = TRUE)
      y_lim <- range(y, na.rm = TRUE)
      do_na <- plot_NA && any(is.na(x))
      if (do_na) {
        # reserve an x-location for the contributions of missing values
        x_range <- diff(x_lim)
        loc_na <- min(x, na.rm = TRUE) + x_range * pos_NA
        x_lim <- range(c(x_lim, loc_na))
      }
      x_uniq <- unique(x)
      x2plot <- x
      # add small jitter for discrete features (at most discrete_n_uniq
      # distinct values), scaled by the smallest gap between distinct values
      if (length(x_uniq) <= discrete_n_uniq) {
        gaps <- diff(x_uniq)
        gaps <- gaps[is.finite(gaps)]
        # a constant feature has no gap: min() of nothing would be Inf and
        # the jittered positions NaN, so such a feature is left unjittered
        if (length(gaps) > 0)
          x2plot <- jitter(x, amount = discrete_jitter * min(gaps))
      }
      plot(x2plot, y, pch = pch, xlab = f, col = col, xlim = x_lim, ylim = y_lim, ylab = ylab, ...)
      grid()
      if (plot_loess) {
        # compress x to 3 significant digits, and mean-aggregate y;
        # N (the count per rounded value) serves as the loess weight
        zz <- data.table(x = signif(x, 3), y)[, .(.N, y = mean(y)), x]
        if (nrow(zz) <= 5) {
          # too few distinct values to smooth: just connect the means
          lines(zz$x, zz$y, col = col_loess)
        } else {
          lo <- stats::loess(y ~ x, data = zz, weights = zz$N, span = span_loess)
          zz$y_lo <- predict(lo, zz, type = "link")
          lines(zz$x, zz$y_lo, col = col_loess)
        }
      }
      if (do_na) {
        i_na <- which(is.na(x))
        x_na <- rep(loc_na, length(i_na))
        x_na <- jitter(x_na, amount = x_range * 0.01)
        points(x_na, y[i_na], pch = pch_NA, col = col_NA)
      }
    }
  }

  invisible(list(data = data, shap_contrib = shap_contrib))
}