1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356
|
#' Marginal plots of fitted gbm objects
#'
#' Plots the marginal effect of the selected variables by "integrating" out the
#' other variables.
#'
#' \code{plot.gbm} produces low dimensional projections of the
#' \code{\link{gbm.object}} by integrating out the variables not included in
#' the \code{i.var} argument. The function selects a grid of points and uses
#' the weighted tree traversal method described in Friedman (2001) to do the
#' integration. Based on the variable types included in the projection,
#' \code{plot.gbm} selects an appropriate display choosing amongst line plots,
#' contour plots, and \code{\link[lattice:Lattice]{lattice}} plots. If the default
#' graphics are not sufficient the user may set \code{return.grid = TRUE}, store
#' the result of the function, and develop another graphic display more
#' appropriate to the particular example.
#'
#' @param x A \code{\link{gbm.object}} that was fit using a call to
#' \code{\link{gbm}}.
#'
#' @param i.var Vector of indices or the names of the variables to plot. If
#' using indices, the variables are indexed in the same order that they appear
#' in the initial \code{gbm} formula. If \code{length(i.var)} is between 1 and
#' 3 then \code{plot.gbm} produces the plots. Otherwise, \code{plot.gbm}
#' returns only the grid of evaluation points and their average predictions.
#'
#' @param n.trees Integer specifying the number of trees to use to generate the
#' plot. Default is to use \code{x$n.trees} (i.e., the entire ensemble).
#'
#' @param continuous.resolution Integer specifying the number of equally spaced
#' points at which to evaluate continuous predictors.
#'
#' @param return.grid Logical indicating whether to produce graphics
#' (\code{FALSE}) or to return only the grid of evaluation points and their
#' average predictions (\code{TRUE}). This is useful for customizing the
#' graphics for special variable types, or for higher dimensional graphs.
#'
#' @param type Character string specifying the type of prediction to plot on the
#' vertical axis. See \code{\link{predict.gbm}} for details.
#'
#' @param level.plot Logical indicating whether or not to use a false color
#' level plot (\code{TRUE}) or a 3-D surface (\code{FALSE}). Default is
#' \code{TRUE}.
#'
#' @param contour Logical indicating whether or not to add contour lines to the
#' level plot. Only used when \code{level.plot = TRUE}. Default is \code{FALSE}.
#'
#' @param number Integer specifying the number of conditional intervals to use
#' for the continuous panel variables. See \code{\link[graphics:coplot]{co.intervals}}
#' and \code{\link[lattice:shingles]{equal.count}} for further details.
#'
#' @param overlap The fraction of overlap of the conditioning variables. See
#' \code{\link[graphics:coplot]{co.intervals}} and \code{\link[lattice:shingles]{equal.count}}
#' for further details.
#'
#' @param col.regions Color vector to be used if \code{level.plot} is
#' \code{TRUE}. Defaults to the wonderful Matplotlib 'viridis' color map
#' provided by the \code{viridis} package. See \code{\link[viridis:reexports]{viridis}}
#' for details.
#'
#' @param ... Additional optional arguments to be passed onto
#' \code{\link[graphics:plot.default]{plot}}.
#'
#' @return If \code{return.grid = TRUE}, a grid of evaluation points and their
#' average predictions. Otherwise, a plot is returned.
#'
#' @note More flexible plotting is available using the
#' \code{\link[pdp]{partial}} and \code{\link[pdp]{plotPartial}} functions.
#'
#' @seealso \code{\link[pdp]{partial}}, \code{\link[pdp]{plotPartial}},
#' \code{\link{gbm}}, and \code{\link{gbm.object}}.
#'
#' @references J. H. Friedman (2001). "Greedy Function Approximation: A Gradient
#' Boosting Machine," Annals of Statistics 29(4).
#'
#' @references B. M. Greenwell (2017). "pdp: An R Package for Constructing
#' Partial Dependence Plots," The R Journal 9(1), 421--436.
#' \url{https://journal.r-project.org/articles/RJ-2017-016/index.html}.
#'
#' @export plot.gbm
#' @export
plot.gbm <- function(x, i.var = 1, n.trees = x$n.trees,
                     continuous.resolution = 100, return.grid = FALSE,
                     type = c("link", "response"), level.plot = TRUE,
                     contour = FALSE, number = 4, overlap = 0.1,
                     col.regions = viridis::viridis, ...) {

  # Match type argument
  type <- match.arg(type)

  # Sanity checks --------------------------------------------------------------

  # Translate variable names into their positions in x$var.names
  if (is.character(i.var)) {
    i <- match(i.var, x$var.names)
    if (anyNA(i)) {
      # List every missing name, separated so they remain readable
      stop("Requested variables not found in ", deparse(substitute(x)), ": ",
           paste(i.var[is.na(i)], collapse = ", "))
    }
    i.var <- i
  }

  # An out-of-range index cannot be used to look up x$var.levels below, so fail
  # here with a clear message instead of continuing
  n.vars <- length(x$var.names)
  if ((min(i.var) < 1) || (max(i.var) > n.vars)) {
    stop("i.var must be between 1 and ", n.vars)
  }

  # Never use more trees than the model contains
  if (n.trees > x$n.trees) {
    warning("n.trees exceeds the number of tree(s) in the model: ",
            x$n.trees, ". Using ", x$n.trees, " tree(s) instead.")
    n.trees <- x$n.trees
  }

  # Plots are only available for up to three predictors
  if (length(i.var) > 3) {
    warning("plot.gbm() will only create up to (and including) 3-way ",
            "interaction plots.\nBeyond that, plot.gbm() will only return ",
            "the plotting data structure.")
    return.grid <- TRUE
  }

  # Generate grid of predictor values on which to compute the partial
  # dependence values
  grid.levels <- lapply(i.var, function(j) {
    lev <- x$var.levels[[j]]
    if (is.numeric(lev)) {  # continuous: equally spaced points over the range
      seq(from = min(lev), to = max(lev), length.out = continuous.resolution)
    } else {  # categorical: zero-based level codes
      as.numeric(factor(lev, levels = lev)) - 1
    }
  })
  X <- expand.grid(grid.levels)
  names(X) <- paste0("X", seq_along(i.var))

  # For compatibility with gbm version 1.6
  if (is.null(x$num.classes)) {
    x$num.classes <- 1
  }

  # Compute partial dependence values
  y <- .Call("gbm_plot", X = as.double(data.matrix(X)),
             cRows = as.integer(nrow(X)), cCols = as.integer(ncol(X)),
             n.class = as.integer(x$num.classes),
             i.var = as.integer(i.var - 1), n.trees = as.integer(n.trees),
             initF = as.double(x$initF), trees = x$trees,
             c.splits = x$c.splits, var.type = as.integer(x$var.type),
             PACKAGE = "gbm")

  # Put the partial dependence values on the requested scale
  dist.name <- x$distribution$name
  if (dist.name == "multinomial") {
    # Reshape into a matrix with one column per class
    pd <- matrix(y, ncol = x$num.classes)
    colnames(pd) <- x$classes
    # Convert to class probabilities (if requested)
    if (type == "response") {
      pd <- exp(pd)
      pd <- pd / rowSums(pd)
    }
  } else if (type == "response") {
    if (is.element(dist.name, c("bernoulli", "pairwise"))) {
      pd <- 1 / (1 + exp(-y))
    } else if (dist.name == "poisson") {
      pd <- exp(y)
    } else {
      warning("`type = \"response\"` only implemented for \"bernoulli\", ",
              "\"poisson\", \"multinomial\", and \"pairwise\" distributions. ",
              "Ignoring." )
      # "Ignoring" the request means falling back to the link scale; without
      # this assignment the grid would have no `y` column at all
      pd <- y
    }
  } else {
    pd <- y
  }
  X$y <- pd

  # Transform categorical variables back to factors
  for (i in seq_along(i.var)) {
    lev <- x$var.levels[[i.var[i]]]
    if (!is.numeric(lev)) {
      X[[i]] <- factor(lev[X[[i]] + 1], levels = lev)
    }
  }

  # Return original variable names
  names(X)[seq_along(i.var)] <- x$var.names[i.var]

  # Return grid only (if requested)
  if (return.grid) {
    return(X)
  }

  # Determine which type of plot to draw based on the number of predictors
  nx <- length(i.var)
  if (nx == 1L) {
    # Single predictor
    plotOnePredictorPDP(X, ...)
  } else if (nx == 2L) {
    # Two predictors
    plotTwoPredictorPDP(X, level.plot = level.plot, contour = contour,
                        col.regions = col.regions, ...)
  } else {
    # Three predictors (paneled version of plotTwoPredictorPDP)
    plotThreePredictorPDP(X, nx = nx, level.plot = level.plot,
                          contour = contour, col.regions = col.regions,
                          number = number, overlap = overlap, ...)
  }
}
#' @keywords internal
plotOnePredictorPDP <- function(X, ...) {
  # Plot the `y` column of X against its first column; any further arguments
  # are forwarded to the lattice plotting function.

  # A non-numeric first column gets a Cleveland dot plot
  if (!is.numeric(X[[1L]])) {
    return(
      lattice::dotplot(stats::as.formula(paste("y ~", names(X)[1L])),
                       data = X, xlab = names(X)[1L], ...)
    )
  }

  # A numeric first column gets a line plot
  lattice::xyplot(stats::as.formula(paste("y ~", names(X)[1L])),
                  data = X, type = "l", ...)
}
#' @keywords internal
plotTwoPredictorPDP <- function(X, level.plot, contour, col.regions, ...) {
  # Plot the `y` column of X against its first two columns. The display is
  # chosen from how many of those two columns are factors; any further
  # arguments are forwarded to the lattice plotting function.
  vars <- names(X)[1L:2L]
  n.factors <- is.factor(X[[1L]]) + is.factor(X[[2L]])

  # Two factors: Cleveland dot plot of the first, paneled by the second
  if (n.factors == 2L) {
    return(
      lattice::dotplot(stats::as.formula(
        paste("y ~", paste(names(X)[1L:2L], collapse = "|"))
      ), data = X, xlab = names(X)[1L], ...)
    )
  }

  # One factor: line plot of the numeric predictor, paneled by the factor
  if (n.factors == 1L) {
    if (is.factor(X[[1L]])) {
      vars <- rev(vars)  # the factor always goes on the conditioning side
    }
    form <- stats::as.formula(paste("y ~", paste(vars, collapse = "|")))
    return(lattice::xyplot(form, data = X, type = "l", ...))
  }

  # No factors: three-dimensional surface over both predictors
  form <- stats::as.formula(paste("y ~", paste(vars, collapse = "*")))
  if (!level.plot) {
    # Wireframe plot
    return(lattice::wireframe(form, data = X, ...))
  }
  # False color level plot
  lattice::levelplot(form, data = X, col.regions = col.regions,
                     contour = contour, ...)
}
#' @keywords internal
plotThreePredictorPDP <- function(X, nx, level.plot, contour, col.regions,
                                  number, overlap, ...) {
  # Plot the `y` column of X against its first three columns as a paneled
  # lattice display. `nx` is the number of predictor columns. `number` and
  # `overlap` are passed to lattice::equal.count() when the third predictor
  # has to be cut into intervals. Further arguments are forwarded to the
  # lattice plotting function.

  # Move non-factor predictors in front of factor predictors, keeping the
  # relative order within each group (order() leaves ties in place). Whole
  # columns are reordered so each column keeps its own name; assigning
  # permuted values into X[, 1:3] would leave the names behind and mislabel
  # the axes and panel strips.
  is.fac <- vapply(X[1L:3L], is.factor, logical(1L))
  X <- X[c(order(is.fac), setdiff(seq_along(X), 1L:3L))]

  # Convert third predictor to a factor using the equal count algorithm
  if (is.numeric(X[[3L]])) {
    X[[3L]] <- lattice::equal.count(X[[3L]], number = number,
                                    overlap = overlap)
  }

  if (is.factor(X[[1L]]) && is.factor(X[[2L]])) {
    # Lattice plot formula
    form <- stats::as.formula(
      paste("y ~", names(X)[1L], "|", paste(names(X)[2L:nx], collapse = "*"))
    )
    # Produce a paneled dotplot
    lattice::dotplot(form, data = X, xlab = names(X)[1L], ...)
  } else if (is.numeric(X[[1L]]) && is.factor(X[[2L]])) {
    # Lattice plot formula
    form <- stats::as.formula(
      paste("y ~", names(X)[1L], "|", paste(names(X)[2L:nx], collapse = "*"))
    )
    # Produce a paneled lineplot
    lattice::xyplot(form, data = X, type = "l", ...)
  } else {
    # Lattice plot formula
    form <- stats::as.formula(
      paste("y ~", paste(names(X)[1L:2L], collapse = "*"), "|",
            paste(names(X)[3L:nx], collapse = "*"))
    )
    # Draw a three-dimensional surface
    if (level.plot) {
      # Draw a false color level plot
      lattice::levelplot(form, data = X, col.regions = col.regions,
                         contour = contour, ...)
    } else {
      # Draw a wireframe plot
      lattice::wireframe(form, data = X, ...)
    }
  }
}
|