#' Cross-validate a gbm
#'
#' Functions for cross-validating gbm. They are used internally by
#' \code{\link{gbm}} and are not intended to be called directly by the
#' end user.
#'
#' @aliases gbmCrossVal gbmCrossValModelBuild gbmDoFold gbmCrossValErr
#' gbmCrossValPredictions
#' @param cv.folds The number of cross-validation folds.
#' @param nTrain The number of training samples.
#' @param n.cores The number of cores to use.
#' @param class.stratify.cv Whether or not stratified cross-validation samples
#' are used.
#' @param data The data.
#' @param x The model matrix.
#' @param y The response variable.
#' @param offset The offset.
#' @param distribution The type of loss function. See \code{\link{gbm}}.
#' @param w Observation weights.
#' @param var.monotone See \code{\link{gbm}}.
#' @param n.trees The number of trees to fit.
#' @param interaction.depth The degree of allowed interactions. See
#' \code{\link{gbm}}.
#' @param n.minobsinnode See \code{\link{gbm}}.
#' @param shrinkage See \code{\link{gbm}}.
#' @param bag.fraction See \code{\link{gbm}}.
#' @param var.names See \code{\link{gbm}}.
#' @param response.name See \code{\link{gbm}}.
#' @param group Used when \code{distribution = "pairwise"}. See
#' \code{\link{gbm}}.
#' @param i.train Items in the training set.
#' @param cv.models A list containing the models for each fold.
#' @param cv.group A vector indicating the cross-validation fold for each
#' member of the training set.
#' @param best.iter.cv The iteration with lowest cross-validation error.
#' @param X Index (cross-validation fold) on which to subset.
#' @param s Vector of random seeds, one per cross-validation fold.
#' @return A list containing the cross-validation error and predictions.
#' @author Greg Ridgeway \email{gregridgeway@@gmail.com}
#' @seealso \code{\link{gbm}}
#' @references J.H. Friedman (2001). "Greedy Function Approximation: A Gradient
#' Boosting Machine," Annals of Statistics 29(5):1189-1232.
#'
#' L. Breiman (2001). "Random Forests," Machine Learning 45(1):5-32.
#' \url{https://www.stat.berkeley.edu/users/breiman/randomforest2001.pdf}.
#' @keywords models
# Perform gbm cross-validation
#
# This function has far too many arguments, but gbm lacks the
# abstraction needed to reduce them.
#' @rdname gbmCrossVal
#' @export
gbmCrossVal <- function(cv.folds, nTrain, n.cores,
                        class.stratify.cv, data,
                        x, y, offset, distribution, w, var.monotone,
                        n.trees, interaction.depth, n.minobsinnode,
                        shrinkage, bag.fraction,
                        var.names, response.name, group) {
  i.train <- 1:nTrain
  cv.group <- getCVgroup(distribution, class.stratify.cv, y,
                         i.train, cv.folds, group)

  ## build the models
  cv.models <- gbmCrossValModelBuild(cv.folds, cv.group, n.cores,
                                     i.train, x, y, offset,
                                     distribution, w, var.monotone,
                                     n.trees, interaction.depth,
                                     n.minobsinnode, shrinkage,
                                     bag.fraction, var.names,
                                     response.name, group)

  ## get the errors
  cv.error <- gbmCrossValErr(cv.models, cv.folds, cv.group, nTrain, n.trees)
  best.iter.cv <- which.min(cv.error)

  ## get the predictions
  predictions <- gbmCrossValPredictions(cv.models, cv.folds, cv.group,
                                        best.iter.cv, distribution,
                                        data[i.train, ], y)
  list(error = cv.error, predictions = predictions)
}
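# A minimal usage sketch (illustrative only, not run; in practice gbm()
# supplies these arguments). The list form of `distribution` and the zero
# offset mirror how this file uses them internally; all data values below
# are made up.
if (FALSE) {
  N <- 100
  x <- data.frame(x1 = runif(N), x2 = runif(N))
  y <- rnorm(N)
  cv <- gbmCrossVal(cv.folds = 5, nTrain = N, n.cores = 1,
                    class.stratify.cv = FALSE,
                    data = cbind(y = y, x),
                    x = x, y = y, offset = rep(0, N),
                    distribution = list(name = "gaussian"),
                    w = rep(1, N), var.monotone = NULL,
                    n.trees = 100, interaction.depth = 1,
                    n.minobsinnode = 10, shrinkage = 0.1,
                    bag.fraction = 0.5, var.names = c("x1", "x2"),
                    response.name = "y", group = NULL)
  best.iter <- which.min(cv$error)  # tree count with lowest CV error
}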
# Get the gbm cross-validation error
#' @rdname gbmCrossVal
#' @export
gbmCrossValErr <- function(cv.models, cv.folds, cv.group, nTrain, n.trees) {
  in.group <- tabulate(cv.group, nbins = cv.folds)
  cv.error <- vapply(1:cv.folds,
                     function(index) {
                       model <- cv.models[[index]]
                       model$valid.error * in.group[[index]]
                     }, double(n.trees))
  ## cv.error is now an (n.trees x cv.folds) matrix; summing across folds
  ## and dividing by nTrain gives the size-weighted average error per tree,
  ## an n.trees vector
  rowSums(cv.error) / nTrain
}
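# A small numeric illustration of the aggregation above (made-up values,
# not run): with two folds of sizes 60 and 40 out of nTrain = 100, each
# tree's CV error is the size-weighted average of the per-fold validation
# errors, e.g. (0.50 * 60 + 0.30 * 40) / 100 = 0.42 for the first tree.
if (FALSE) {
  in.group <- c(60, 40)
  fold.err <- cbind(c(0.50, 0.45), c(0.30, 0.25))   # n.trees x cv.folds
  rowSums(sweep(fold.err, 2, in.group, `*`)) / 100  # 0.42 0.37
}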
#' @rdname gbmCrossVal
#' @export
gbmCrossValPredictions <- function(cv.models, cv.folds, cv.group,
                                   best.iter.cv, distribution, data, y) {
  # Get the predictions for GBM cross-validation. This function is not as
  # nice as it could be (i.e., leakage of y)

  # Test that cv.group and data match
  if (nrow(data) != length(cv.group)) {
    stop("Mismatch between `data` and `cv.group`.")
  }

  # This is a little complicated due to the multinomial distribution
  num.cols <- if (distribution$name == "multinomial") {
    nlevels(factor(y))
  } else {
    1
  }

  # Initialize results matrix
  res <- matrix(nrow = nrow(data), ncol = num.cols)

  # There's no real reason to do this as anything other than a for loop;
  # see the standalone sketch after this function
  for (ind in 1:cv.folds) {
    # Rows held out of fold `ind`'s training data
    flag <- cv.group == ind
    model <- cv.models[[ind]]
    # Subset to the model's predictors only; this also handles coxph, whose
    # response spans more than one column
    my.data <- data[flag, model$var.names, drop = FALSE]
    predictions <- predict(model, newdata = my.data,
                           n.trees = best.iter.cv)  # FIXME
    predictions <- matrix(predictions, ncol = num.cols)
    res[flag, ] <- predictions
  }

  # Collapse to a vector except in the multinomial case, which keeps one
  # column per class
  if (distribution$name != "multinomial") {
    res <- as.numeric(res)
  }

  # Return the result
  res
}
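# Sketch of the fold-assembly logic above (toy stand-in for predict(); not
# run): every row of the result is filled exactly once, by the one model
# that did not see that row during fitting.
if (FALSE) {
  cv.group <- c(1, 2, 1, 2)
  res <- rep(NA_real_, length(cv.group))
  for (ind in 1:2) {
    res[cv.group == ind] <- ind * 10  # stand-in for predict() on fold ind
  }
  res  # 10 20 10 20
}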
# Build the models for gbm cross-validation
#
# This function has far too many arguments.
#' @rdname gbmCrossVal
#' @export
gbmCrossValModelBuild <- function(cv.folds, cv.group, n.cores, i.train, x, y,
                                  offset, distribution, w, var.monotone,
                                  n.trees, interaction.depth, n.minobsinnode,
                                  shrinkage, bag.fraction, var.names,
                                  response.name, group) {
  # Set random seeds, one per fold; cap the range at +/-(2^31 - 1) so that
  # as.integer() cannot overflow to NA
  seeds <- as.integer(runif(cv.folds, -(2^31 - 1), 2^31 - 1))

  # Perform cross-validation model builds
  if (!is.null(n.cores) && n.cores == 1) {
    # Run the folds serially; no cluster overhead
    lapply(1:cv.folds, FUN = gbmDoFold, i.train, x,
           y, offset, distribution, w, var.monotone, n.trees,
           interaction.depth, n.minobsinnode, shrinkage,
           bag.fraction, cv.group, var.names, response.name, group,
           seeds)
  } else {
    # Set up cluster and add finalizer
    cluster <- gbmCluster(n.cores)
    on.exit(parallel::stopCluster(cluster))
    parallel::parLapply(cl = cluster, X = 1:cv.folds, fun = gbmDoFold,
                        i.train, x, y, offset, distribution, w, var.monotone,
                        n.trees, interaction.depth, n.minobsinnode, shrinkage,
                        bag.fraction, cv.group, var.names, response.name,
                        group, seeds)
  }
}
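# The parallel branch above relies on parLapply()'s convention that extra
# arguments after `fun` are passed along to every call. A standalone sketch
# of that pattern, independent of gbm's internals (not run):
if (FALSE) {
  cl <- parallel::makeCluster(2)
  out <- parallel::parLapply(cl, 1:4, function(X, k) X * k, 10)
  parallel::stopCluster(cl)
  unlist(out)  # 10 20 30 40
}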
#' @rdname gbmCrossVal
#' @export
gbmDoFold <- function(X, i.train, x, y, offset, distribution, w, var.monotone,
                      n.trees, interaction.depth, n.minobsinnode, shrinkage,
                      bag.fraction, cv.group, var.names, response.name, group,
                      s) {
  # Fit the specified cross-validation fold---a self-contained function for
  # passing to individual cores.

  # Load required packages for core
  library(gbm, quietly = TRUE)

  # Print CV information
  cat("CV:", X, "\n")

  # Setup: order() on a logical vector is stable and sorts FALSE before TRUE,
  # so out-of-fold observations come first and fold X occupies the tail, past
  # nTrain, where gbm.fit treats rows as validation data
  set.seed(s[[X]])
  i <- order(cv.group == X)
  x <- x[i.train, , drop = FALSE][i, , drop = FALSE]
  y <- y[i.train][i]
  offset <- offset[i.train][i]
  w <- w[i.train][i]  # keep weights aligned with the permuted rows
  nTrain <- length(which(cv.group != X))
  group <- group[i.train][i]

  # Return a fitted GBM
  gbm.fit(x = x, y = y, offset = offset, distribution = distribution,
          w = w, var.monotone = var.monotone, n.trees = n.trees,
          interaction.depth = interaction.depth,
          n.minobsinnode = n.minobsinnode,
          shrinkage = shrinkage, bag.fraction = bag.fraction,
          nTrain = nTrain, keep.data = FALSE, verbose = FALSE,
          response.name = response.name, group = group)
}
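# Standalone check of the reordering trick used in gbmDoFold() (not run):
# order() on a logical vector sorts FALSE before TRUE and is stable within
# ties, so the held-out fold's rows land at the end.
if (FALSE) {
  cv.group <- c(1, 2, 1, 3, 2, 3)
  order(cv.group == 2)  # 1 3 4 6 2 5: folds 1 and 3 first, fold 2 last
}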