1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415
|
#
# This file is for the low-level reusable utility functions
# that are not supposed to be visible to a user.
#
#
# General helper utilities ----------------------------------------------------
#
# SQL-style NVL shortcut: substitute 'val' for missing values.
#
# x:   NULL, a plain vector (anything is.vector() accepts, lists included),
#      or a closure.
# val: replacement value.
# Returns 'val' when x is NULL, x with its NA elements replaced by 'val' when
# x is a plain vector, and x itself when x is a closure. Any other type of x
# is an error.
NVL <- function(x, val) {
  if (is.null(x)) {
    return(val)
  }
  if (!is.vector(x)) {
    # functions pass through untouched; everything else is unsupported
    if (identical(typeof(x), 'closure')) {
      return(x)
    }
    stop("typeof(x) == ", typeof(x), " is not supported by NVL")
  }
  na_mask <- is.na(x)
  x[na_mask] <- val
  x
}
#
# Low-level functions for boosting --------------------------------------------
#
# Merges booster params with whatever is provided in ... plus runs some checks.
#
# params: a plain list of booster parameters (its class must be exactly "list").
# ...:    additional parameters; none of them may repeat a name from 'params'.
#
# In both sources, '.' in parameter names is converted to '_'. If a name occurs
# more than once (other than 'eval_metric'), a warning is issued and only the
# last occurrence is kept. A non-character 'monotone_constraints' and a list
# 'interaction_constraints' are converted to the string forms built below.
# Returns the merged, normalized list of parameters.
check.booster.params <- function(params, ...) {
  if (!identical(class(params), "list"))
    stop("params must be a list")
  # in R interface, allow for '.' instead of '_' in parameter names
  names(params) <- gsub("\\.", "_", names(params))
  # merge parameters from the params and the dots-expansion
  dot_params <- list(...)
  names(dot_params) <- gsub("\\.", "_", names(dot_params))
  if (length(intersect(names(params),
                       names(dot_params))) > 0)
    stop("Same parameters in 'params' and in the call are not allowed. Please check your 'params' list.")
  params <- c(params, dot_params)
  # providing a parameter multiple times makes sense only for 'eval_metric'
  name_freqs <- table(names(params))
  multi_names <- setdiff(names(name_freqs[name_freqs > 1]), 'eval_metric')
  if (length(multi_names) > 0) {
    warning("The following parameters were provided multiple times:\n\t",
            paste(multi_names, collapse = ', '), "\n Only the last value for each of them will be used.\n")
    # While xgboost internals would choose the last value for a multiple-times parameter,
    # enforce it here in R as well (b/c multi-parameters might be used further in R code,
    # and R takes the 1st value when multiple elements with the same name are present in a list).
    for (n in multi_names) {
      del_idx <- which(n == names(params))
      del_idx <- del_idx[-length(del_idx)]
      # Single brackets are required: `[[<-` with more than one index performs a
      # recursive (nested) assignment and fails when a name occurs 3+ times.
      params[del_idx] <- NULL
    }
  }
  # for multiclass, expect num_class to be set
  if (typeof(params[['objective']]) == "character" &&
      substr(NVL(params[['objective']], 'x'), 1, 6) == 'multi:' &&
      as.numeric(NVL(params[['num_class']], 0)) < 2) {
    stop("'num_class' > 1 parameter must be set for multiclass classification")
  }
  # monotone_constraints parser, e.g. c(1, -1, 0) becomes "(1,-1,0)"
  if (!is.null(params[['monotone_constraints']]) &&
      typeof(params[['monotone_constraints']]) != "character") {
    vec2str <- paste(params[['monotone_constraints']], collapse = ',')
    vec2str <- paste0('(', vec2str, ')')
    params[['monotone_constraints']] <- vec2str
  }
  # interaction constraints parser (convert from list of column indices to string),
  # e.g. list(c(1, 2), 3) becomes "[[1,2],[3]]"
  if (!is.null(params[['interaction_constraints']]) &&
      typeof(params[['interaction_constraints']]) != "character") {
    # check input class
    if (!identical(class(params[['interaction_constraints']]), 'list')) stop('interaction_constraints should be class list')
    if (!all(unique(sapply(params[['interaction_constraints']], class)) %in% c('numeric', 'integer'))) {
      stop('interaction_constraints should be a list of numeric/integer vectors')
    }
    # recast parameter as string; vapply guarantees one string per constraint group
    interaction_constraints <- vapply(params[['interaction_constraints']],
                                      function(x) paste0('[', paste(x, collapse = ','), ']'),
                                      character(1))
    params[['interaction_constraints']] <- paste0('[', paste(interaction_constraints, collapse = ','), ']')
  }
  params
}
# Performs some checks related to custom objective function.
# WARNING: has side-effects and can modify 'params' and 'obj' in its calling frame
#
# env: environment holding 'params' (a list) and 'obj'; defaults to the caller.
# Stops if both params$objective and obj are set, or if obj is not a function.
# A function passed as params$objective is moved into 'obj' and removed from
# 'params'.
check.custom.obj <- function(env = parent.frame()) {
  custom_obj <- env$obj
  param_obj <- env$params[['objective']]
  if (!is.null(custom_obj)) {
    if (!is.null(param_obj))
      stop("Setting objectives in 'params' and 'obj' at the same time is not allowed")
    if (typeof(custom_obj) != 'closure')
      stop("'obj' must be a function")
  }
  # a custom objective function may also arrive through params
  if (typeof(param_obj) == 'closure') {
    env$obj <- param_obj
    env$params$objective <- NULL
  }
}
# Performs some checks related to custom evaluation function.
# WARNING: has side-effects and can modify 'params' and 'feval' in its calling frame
#
# env: environment holding 'params', 'feval', 'maximize', 'early_stopping_rounds'
#      and 'callbacks'; defaults to the caller.
# Stops if both params$eval_metric and feval are set, or if feval is not a
# function. A function passed as params$eval_metric is moved into 'feval'.
# Also stops when a custom feval is combined with early stopping while
# 'maximize' is unset.
check.custom.eval <- function(env = parent.frame()) {
  custom_feval <- env$feval
  param_metric <- env$params[['eval_metric']]
  if (!is.null(custom_feval)) {
    if (!is.null(param_metric))
      stop("Setting evaluation metrics in 'params' and 'feval' at the same time is not allowed")
    if (typeof(custom_feval) != 'closure')
      stop("'feval' must be a function")
  }
  # a custom eval function may also arrive through params
  if (typeof(param_metric) == 'closure') {
    env$feval <- param_metric
    env$params$eval_metric <- NULL
  }
  # early stopping needs to know the direction of a custom metric;
  # has.callbacks() is only consulted when early_stopping_rounds is unset
  needs_direction <- !is.null(env$feval) && is.null(env$maximize)
  if (needs_direction &&
      (!is.null(env$early_stopping_rounds) ||
       has.callbacks(env$callbacks, 'cb.early.stop')))
    stop("Please set 'maximize' to indicate whether the evaluation metric needs to be maximized or not")
}
# Update a booster handle for an iteration with dtrain data.
#
# booster_handle: object whose class is exactly "xgb.Booster.handle".
# dtrain:         an xgb.DMatrix (checked with inherits()).
# iter:           iteration number, passed as integer to the native updater.
# obj:            optional custom objective called as obj(margin_preds, dtrain);
#                 its result must provide 'grad' and 'hess'.
# Returns TRUE once the native update call has been made.
xgb.iter.update <- function(booster_handle, dtrain, iter, obj = NULL) {
  if (!identical(class(booster_handle), "xgb.Booster.handle")) {
    stop("booster_handle must be of xgb.Booster.handle class")
  }
  if (!inherits(dtrain, "xgb.DMatrix")) {
    stop("dtrain must be of xgb.DMatrix class")
  }
  if (!is.null(obj)) {
    # custom objective: boost from gradients computed on raw margin predictions
    margin <- predict(booster_handle, dtrain, outputmargin = TRUE, training = TRUE,
                      ntreelimit = 0)
    grad_hess <- obj(margin, dtrain)
    .Call(XGBoosterBoostOneIter_R, booster_handle, dtrain, grad_hess$grad, grad_hess$hess)
  } else {
    .Call(XGBoosterUpdateOneIter_R, booster_handle, as.integer(iter), dtrain)
  }
  TRUE
}
# Evaluate one iteration.
# Returns a named vector of evaluation metrics whose names follow the
# 'datasetname-metricname' pattern, or NULL when the watchlist is empty.
#
# booster_handle: object whose class is exactly "xgb.Booster.handle".
# watchlist:      named list of datasets; its names become the metric prefixes.
# iter:           iteration number, passed as integer to the native evaluator.
# feval:          optional custom metric called as feval(margin_preds, dataset);
#                 its result is expected to provide 'metric' and 'value'.
xgb.iter.eval <- function(booster_handle, watchlist, iter, feval = NULL) {
  if (!identical(class(booster_handle), "xgb.Booster.handle"))
    stop("class of booster_handle must be xgb.Booster.handle")
  if (length(watchlist) == 0)
    return(NULL)
  set_names <- names(watchlist)
  if (!is.null(feval)) {
    # custom metric on margin predictions made with all trees
    return(sapply(seq_along(watchlist), function(k) {
      dset <- watchlist[[k]]
      margin <- predict(booster_handle, dset, outputmargin = TRUE, ntreelimit = 0)
      scored <- feval(margin, dset)
      metric_value <- scored$value
      names(metric_value) <- paste0(set_names[k], "-", scored$metric)
      metric_value
    }))
  }
  # built-in metrics: after dropping the first token of the native result
  # string, tokens alternate between metric names and metric values
  eval_str <- .Call(XGBoosterEvalOneIter_R, booster_handle, as.integer(iter),
                    watchlist, as.list(set_names))
  tokens <- stri_split_regex(eval_str, '(\\s+|:|\\s+)')[[1]][-1]
  is_value <- seq_along(tokens) %% 2L == 0L
  metrics <- as.numeric(tokens[is_value])
  names(metrics) <- tokens[!is_value]
  metrics
}
#
# Helper functions for cross validation ---------------------------------------
#
# Generates random (stratified if needed) CV folds.
#
# nfold:      number of folds.
# nrows:      number of rows in the data.
# stratified: whether to stratify by 'label' (only honoured when
#             length(label) == nrows).
# label:      label vector of the data.
# params:     booster parameter list; only its 'objective' element is consulted.
# Returns an unnamed list of integer vectors of row indices in 1..nrows.
generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
  objective <- params[['objective']]
  # cannot do it for rank
  if (is.character(objective) &&
      strtrim(objective, 5) == 'rank:') {
    stop("\n\tAutomatic generation of CV-folds is not implemented for ranking!\n",
         "\tConsider providing pre-computed CV-folds through the 'folds=' parameter.\n")
  }
  # shuffle
  rnd_idx <- sample.int(nrows)
  if (stratified &&
      length(label) == length(rnd_idx)) {
    y <- label[rnd_idx]
    # WARNING: some heuristic logic is employed to identify classification setting!
    #  - For classification, need to convert y labels to factor before making the folds,
    #    and then do stratification by factor levels.
    #  - For regression, leave y numeric and do stratification by quantiles.
    if (is.character(objective)) {
      # If 'objective' provided in params, assume that y is a classification label
      # unless objective is reg:squarederror
      if (objective != 'reg:squarederror')
        y <- factor(y)
    } else {
      # If no 'objective' given in params, it means that user either wants to
      # use the default 'reg:squarederror' objective or has provided a custom
      # obj function.  Here, assume classification setting when y has 5 or less
      # unique values:
      if (length(unique(y)) <= 5)
        y <- factor(y)
    }
    # xgb.createFolds() returns positions within the shuffled 'y'. Map them back
    # through 'rnd_idx' so that each fold holds the rows whose labels were used
    # for stratification; otherwise the folds are not stratified at all.
    folds <- lapply(xgb.createFolds(y, nfold), function(pos) rnd_idx[pos])
  } else {
    # make simple non-stratified folds
    kstep <- length(rnd_idx) %/% nfold
    folds <- list()
    for (i in seq_len(nfold - 1)) {
      folds[[i]] <- rnd_idx[seq_len(kstep)]
      rnd_idx <- rnd_idx[-seq_len(kstep)]
    }
    folds[[nfold]] <- rnd_idx
  }
  folds
}
# Creates CV folds stratified by the values of y.
# It was borrowed from caret::createFolds and simplified
# by always returning an unnamed list of fold indices.
#
# y: labels; numeric y is first binned into (at most 5) quantile groups,
#    anything else is treated as categorical.
# k: requested number of folds.
# Returns an unnamed list of integer vectors of positions in y. When
# k >= length(y), every element gets its own fold (so fewer than k folds are
# returned).
xgb.createFolds <- function(y, k = 10) {
  if (is.numeric(y)) {
    ## Group the numeric data based on their magnitudes
    ## and sample within those groups.
    ## When the number of samples is low, we may have
    ## issues further slicing the numeric data into
    ## groups. The number of groups will depend on the
    ## ratio of the number of folds to the sample size.
    ## At most, we will use quantiles. If the sample
    ## is too small, we just do regular unstratified
    ## CV
    cuts <- floor(length(y) / k)
    if (cuts < 2) cuts <- 2
    if (cuts > 5) cuts <- 5
    breaks <- unique(stats::quantile(y, probs = seq(0, 1, length.out = cuts)))
    if (length(breaks) < 2) {
      ## All values are equal (or y is empty). cut() would read a single
      ## break as a *number* of intervals and fail for values below 2, so
      ## put every element into one group instead.
      y <- factor(rep.int(1L, length(y)))
    } else {
      y <- cut(y, breaks, include.lowest = TRUE)
    }
  }
  if (k < length(y)) {
    ## reset levels so that the possible levels and
    ## the levels in the vector are the same
    y <- factor(as.character(y))
    numInClass <- table(y)
    foldVector <- vector(mode = "integer", length(y))
    ## For each class, balance the fold allocation as far
    ## as possible, then resample the remainder.
    ## The final assignment of folds is also randomized.
    for (i in seq_along(numInClass)) {
      ## create a vector of integers from 1:k as many times as possible without
      ## going over the number of samples in the class. Note that if the number
      ## of samples in a class is less than k, nothing is produced here.
      seqVector <- rep(seq_len(k), numInClass[i] %/% k)
      ## add enough random integers to get length(seqVector) == numInClass[i]
      if (numInClass[i] %% k > 0) seqVector <- c(seqVector, sample.int(k, numInClass[i] %% k))
      ## shuffle the integers for fold assignment and assign to this class's data
      ## seqVector[sample.int(length(seqVector))] is used to handle length(seqVector) == 1
      foldVector[y == names(numInClass)[i]] <- seqVector[sample.int(length(seqVector))]
    }
  } else {
    foldVector <- seq_along(y)
  }
  out <- split(seq_along(y), foldVector)
  names(out) <- NULL
  out
}
#
# Deprecation notice utilities ------------------------------------------------
#
# Documentation-only help topic (roxygen2): the roxygen block below is attached
# to a bare NULL and published under the name given by @name; it defines no R
# object. The renamed parameters it refers to are listed in depr_par_lut.
#' Deprecation notices.
#'
#' At this time, some of the parameter names were changed in order to make the code style more uniform.
#' The deprecated parameters would be removed in the next release.
#'
#' To see all the current deprecated and new parameters, check the \code{xgboost:::depr_par_lut} table.
#'
#' A deprecation warning is shown when any of the deprecated parameters is used in a call.
#' An additional warning is shown when there was a partial match to a deprecated parameter
#' (as R is able to partially match parameter names).
#'
#' @name xgboost-deprecated
NULL
# Documentation-only help topic (roxygen2): the roxygen block below is attached
# to a bare NULL and published under the name given by @name; it defines no R
# object. Its @examples write and then remove temporary files in the working
# directory.
#' Do not use \code{\link[base]{saveRDS}} or \code{\link[base]{save}} for long-term archival of
#' models. Instead, use \code{\link{xgb.save}} or \code{\link{xgb.save.raw}}.
#'
#' It is a common practice to use the built-in \code{\link[base]{saveRDS}} function (or
#' \code{\link[base]{save}}) to persist R objects to the disk. While it is possible to persist
#' \code{xgb.Booster} objects using \code{\link[base]{saveRDS}}, it is not advisable to do so if
#' the model is to be accessed in the future. If you train a model with the current version of
#' XGBoost and persist it with \code{\link[base]{saveRDS}}, the model is not guaranteed to be
#' accessible in later releases of XGBoost. To ensure that your model can be accessed in future
#' releases of XGBoost, use \code{\link{xgb.save}} or \code{\link{xgb.save.raw}} instead.
#'
#' @details
#' Use \code{\link{xgb.save}} to save the XGBoost model as a stand-alone file. You may opt into
#' the JSON format by specifying the JSON extension. To read the model back, use
#' \code{\link{xgb.load}}.
#'
#' Use \code{\link{xgb.save.raw}} to save the XGBoost model as a sequence (vector) of raw bytes
#' in a future-proof manner. Future releases of XGBoost will be able to read the raw bytes and
#' re-construct the corresponding model. To read the model back, use \code{\link{xgb.load.raw}}.
#' The \code{\link{xgb.save.raw}} function is useful if you'd like to persist the XGBoost model
#' as part of another R object.
#'
#' Note: Do not use \code{\link{xgb.serialize}} to store models long-term. It persists not only the
#' model but also internal configurations and parameters, and its format is not stable across
#' multiple XGBoost versions. Use \code{\link{xgb.serialize}} only for checkpointing.
#'
#' For more details and explanation about model persistence and archival, consult the page
#' \url{https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html}.
#'
#' @examples
#' data(agaricus.train, package='xgboost')
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 2,
#' eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
#'
#' # Save as a stand-alone file; load it with xgb.load()
#' xgb.save(bst, 'xgb.model')
#' bst2 <- xgb.load('xgb.model')
#'
#' # Save as a stand-alone file (JSON); load it with xgb.load()
#' xgb.save(bst, 'xgb.model.json')
#' bst2 <- xgb.load('xgb.model.json')
#' if (file.exists('xgb.model.json')) file.remove('xgb.model.json')
#'
#' # Save as a raw byte vector; load it with xgb.load.raw()
#' xgb_bytes <- xgb.save.raw(bst)
#' bst2 <- xgb.load.raw(xgb_bytes)
#'
#' # Persist XGBoost model as part of another R object
#' obj <- list(xgb_model_bytes = xgb.save.raw(bst), description = "My first XGBoost model")
#' # Persist the R object. Here, saveRDS() is okay, since it doesn't persist
#' # xgb.Booster directly. What's being persisted is the future-proof byte representation
#' # as given by xgb.save.raw().
#' saveRDS(obj, 'my_object.rds')
#' # Read back the R object
#' obj2 <- readRDS('my_object.rds')
#' # Re-construct xgb.Booster object from the bytes
#' bst2 <- xgb.load.raw(obj2$xgb_model_bytes)
#' if (file.exists('my_object.rds')) file.remove('my_object.rds')
#'
#' @name a-compatibility-note-for-saveRDS-save
NULL
# Lookup table for the deprecated parameters bookkeeping.
# A two-column character matrix: column 'old' holds a deprecated argument name,
# column 'new' holds the name that replaces it (same row).
depr_par_lut <- cbind(
  old = c('print.every.n', 'early.stop.round', 'training.data', 'with.stats',
          'numberOfClusters', 'features.keep', 'plot.height', 'plot.width',
          'n_first_tree', 'dummy'),
  new = c('print_every_n', 'early_stopping_rounds', 'data', 'with_stats',
          'n_clusters', 'features_keep', 'plot_height', 'plot_width',
          'trees', 'DUMMY')
)
# Checks the dot-parameters for deprecated names
# (including partial matching), gives a deprecation warning,
# and sets new parameters to the old parameters' values within its parent frame.
# WARNING: has side-effects
#
# ...: arguments as received by the caller; only those whose names (partially)
#      match a deprecated name in depr_par_lut[, 'old'] are acted upon.
# env: environment in which the new-name variables are assigned.
# For every match, a warning is issued for partial matches, .Deprecated() is
# called, and the value is bound to the new parameter name in 'env'.
check.deprecation <- function(..., env = parent.frame()) {
  pars <- list(...)
  # exact and partial matches
  all_match <- pmatch(names(pars), depr_par_lut[, 1])
  # indices of matched pars' names
  idx_pars <- which(!is.na(all_match))
  if (length(idx_pars) == 0) return()
  # indices of matched LUT rows
  idx_lut <- all_match[idx_pars]
  # which of idx_lut were the exact matches?
  ex_match <- depr_par_lut[idx_lut, 1] %in% names(pars)
  for (i in seq_along(idx_pars)) {
    pars_par <- names(pars)[idx_pars[i]]
    old_par <- depr_par_lut[idx_lut[i], 1]
    new_par <- depr_par_lut[idx_lut[i], 2]
    if (!ex_match[i]) {
      warning("'", pars_par, "' was partially matched to '", old_par, "'")
    }
    .Deprecated(new_par, old = old_par, package = 'xgboost')
    if (new_par != 'NULL') {
      # Bind the value itself. Rebuilding it from its printed text via
      # eval(parse()) breaks strings (read back as symbols), NULL, matrices,
      # and multi-element vectors, and loses floating-point precision.
      assign(new_par, pars[[idx_pars[i]]], envir = env)
    }
  }
}
|