#' Generate Data to Choose a Probability Threshold
#'
#' This function uses the resampling results from a \code{\link{train}}
#' object to generate performance statistics over a set of probability
#' thresholds for two-class problems.
#'
#' @param x A \code{\link{train}} object where the value of
#' \code{savePredictions} was either \code{TRUE}, \code{"all"},
#' or \code{"final"} in \code{\link{trainControl}}. Also, the
#' control argument \code{classProbs} should have been \code{TRUE}.
#' @param threshold A numeric vector of candidate probability thresholds
#' on [0, 1]. If the class probability corresponding to the first
#' level of the outcome is greater than the threshold, the data point
#' is classified as that level.
#' @param final A logical: should only the final tuning parameters
#' chosen by \code{\link{train}} be used when
#' \code{savePredictions = 'all'}?
#' @param statistics A character vector indicating which statistics to
#' calculate. See details below for possible choices; the default value
#' \code{"all"} computes all of these.
#' @return A data frame with columns for each of the tuning parameters
#' from the model along with an additional column called
#' \code{prob_threshold} for the probability threshold. There are
#' also columns for summary statistics averaged over resamples with
#' column names corresponding to the input argument \code{statistics}.
#' @details The argument \code{statistics} designates the statistics to compute
#' for each probability threshold. One or more of the following statistics can
#' be selected:
#' \itemize{
#' \item Sensitivity
#' \item Specificity
#' \item Pos Pred Value
#' \item Neg Pred Value
#' \item Precision
#' \item Recall
#' \item F1
#' \item Prevalence
#' \item Detection Rate
#' \item Detection Prevalence
#' \item Balanced Accuracy
#' \item Accuracy
#' \item Kappa
#' \item J
#' \item Dist
#' }
#' For a description of these statistics (except the last two), see the
#' documentation of \code{\link{confusionMatrix}}. The last two statistics
#' are Youden's J statistic (\code{J = Sensitivity + Specificity - 1}) and
#' the distance to the best possible cutoff, i.e. perfect sensitivity and
#' specificity (\code{Dist = sqrt((1 - Sensitivity)^2 + (1 - Specificity)^2)}).
#' @export
#' @importFrom plyr ddply
#' @examples
#' \dontrun{
#' set.seed(2444)
#' dat <- twoClassSim(500, intercept = -10)
#' table(dat$Class)
#'
#' ctrl <- trainControl(method = "cv",
#'                      classProbs = TRUE,
#'                      savePredictions = "all",
#'                      summaryFunction = twoClassSummary)
#'
#' set.seed(2863)
#' mod <- train(Class ~ ., data = dat,
#'              method = "rda",
#'              tuneLength = 4,
#'              metric = "ROC",
#'              trControl = ctrl)
#'
#' resample_stats <- thresholder(mod,
#'                               threshold = seq(.5, 1, by = 0.05),
#'                               final = TRUE)
#'
#' ggplot(resample_stats, aes(x = prob_threshold, y = J)) +
#'   geom_point()
#' ggplot(resample_stats, aes(x = prob_threshold, y = Dist)) +
#'   geom_point()
#' ggplot(resample_stats, aes(x = prob_threshold, y = Sensitivity)) +
#'   geom_point() +
#'   geom_point(aes(y = Specificity), col = "red")
#' }
thresholder <- function(x, threshold, final = TRUE, statistics = "all") {
  if (!inherits(x, "train"))
    stop("`x` should be an object of class 'train'", call. = FALSE)
  if (!x$control$classProbs)
    stop("`classProbs` must be TRUE in `trainControl`", call. = FALSE)
  if (is.null(threshold))
    stop("Please supply probability threshold values.", call. = FALSE)
  if (any(threshold > 1 | threshold < 0))
    stop("`threshold` should be on [0, 1]", call. = FALSE)
  if (is.logical(x$control$savePredictions)) {
    if (!x$control$savePredictions)
      stop("`savePredictions` should be TRUE, 'all', or 'final'",
           call. = FALSE)
  } else {
    if (x$control$savePredictions == "none")
      stop("`savePredictions` should be TRUE, 'all', or 'final'",
           call. = FALSE)
  }
  if (length(levels(x$pred$obs)) > 2)
    stop("For two class problems only", call. = FALSE)
  stat_names <- c("Sensitivity", "Specificity", "Pos Pred Value",
                  "Neg Pred Value", "Precision", "Recall", "F1", "Prevalence",
                  "Detection Rate", "Detection Prevalence", "Balanced Accuracy",
                  "Accuracy", "Kappa", "J", "Dist")
  if (!all(statistics %in% c("all", stat_names)) ||
      ("all" %in% statistics && length(statistics) > 1))
    stop("`statistics` should be either 'all', or one or more of '",
         paste0(stat_names, collapse = "', '"), "'.")
  if (length(statistics) == 1 && statistics == "all")
    statistics <- stat_names
  ## Columns to drop before expansion: the original hard predictions, the
  ## row indices, and the (redundant) non-first-level probability column.
  disc <- c("pred", "rowIndex", x$levels[-1])
  ## Expand the predicted values with the candidate values of the
  ## threshold. Merging with `bestTune` filters the saved predictions
  ## down to the final tuning parameter combination.
  pred_dat <- expand_preds(if (final) merge(x$pred, x$bestTune) else x$pred,
                           threshold,
                           disc)
  param <- c("Resample", names(x$bestTune), "prob_threshold")
  ## Based on the threshold, recode the predicted classes
  pred_dat <- ddply(pred_dat, .variables = param, recode)
  ## Compute statistics per threshold and tuning parameters
  pred_stats <- ddply(pred_dat, .variables = param, stats)
  ## Summarize over resamples
  pred_resamp <- ddply(pred_stats, .variables = param[-1],
                       summ_stats, statistics)
  pred_resamp
}
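
## A minimal sketch (hypothetical data, not part of caret) of the
## merge-as-filter idiom used above: merging the saved predictions with
## `bestTune` keeps only the rows whose tuning parameter values match
## the final model.
# preds <- data.frame(pred = c("A", "B", "A", "B"),
#                     obs  = c("A", "B", "B", "B"),
#                     mtry = c(1, 1, 2, 2))
# best <- data.frame(mtry = 2)
# merge(preds, best)  # keeps only the two rows where mtry == 2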
## Replicate the prediction data frame once per candidate threshold and
## attach the threshold as a new `prob_threshold` column. Columns named
## in `excl` are dropped first.
expand_preds <- function(df, th, excl = NULL) {
  th <- unique(th)
  nth <- length(th)
  ndf <- nrow(df)
  if (!is.null(excl))
    df <- df[, !(names(df) %in% excl), drop = FALSE]
  df <- df[rep(1:ndf, times = nth), , drop = FALSE]
  df$prob_threshold <- rep(th, each = ndf)
  df
}
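
## For illustration (hypothetical input): two rows expanded over three
## thresholds give six rows, with each threshold repeated in a block:
# expand_preds(data.frame(obs = c("A", "B")), th = c(.3, .5, .7))
## yields obs = A, B, A, B, A, B with prob_threshold =
## 0.3, 0.3, 0.5, 0.5, 0.7, 0.7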
## Recode the hard class predictions: a sample is assigned the first
## factor level only when its probability for that level exceeds the
## threshold.
recode <- function(dat) {
  lvl <- levels(dat$obs)
  dat$pred <- ifelse(dat[, lvl[1]] > dat$prob_threshold,
                     lvl[1], lvl[2])
  dat$pred <- factor(dat$pred, levels = lvl)
  dat
}
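
## For illustration (hypothetical input): with first level "A" and a
## threshold of 0.7, Pr(A) = 0.6 is no longer enough to predict "A":
# d <- data.frame(obs = factor(c("A", "B"), levels = c("A", "B")),
#                 A = c(0.6, 0.8), prob_threshold = 0.7)
# recode(d)$pred  # [1] B A; Levels: A B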
## Compute the confusion matrix statistics for a single resample and
## threshold combination, then append Youden's J statistic and the
## distance to the ideal cutoff at (sensitivity, specificity) = (1, 1).
stats <- function(dat) {
  tab <- caret::confusionMatrix(dat$pred, dat$obs,
                                positive = levels(dat$obs)[1])
  res <- c(tab$byClass, tab$overall[c("Accuracy", "Kappa")])
  res <- c(res,
           res["Sensitivity"] + res["Specificity"] - 1,
           sqrt((res["Sensitivity"] - 1) ^ 2 + (res["Specificity"] - 1) ^ 2))
  ## The last two elements just appended get the names "J" and "Dist"
  names(res)[-seq_len(length(res) - 2)] <- c("J", "Dist")
  res
}
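
## Worked example (made-up values): with Sensitivity = 0.9 and
## Specificity = 0.8,
##   J    = 0.9 + 0.8 - 1                   = 0.70
##   Dist = sqrt((0.9 - 1)^2 + (0.8 - 1)^2) = sqrt(0.05) ~ 0.224
## A perfect cutoff has J = 1 and Dist = 0.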
## Average the requested statistics over resamples, warning when any of
## the selected columns contain missing values (which are dropped).
summ_stats <- function(x, cols) {
  na_cols <- apply(x, 2, function(col) any(is.na(col)))
  na_col_names <- colnames(x)[na_cols]
  relevant_col_names <- intersect(na_col_names, cols)
  if (length(relevant_col_names) > 0)
    warning("The following columns have missing values (NA), which have been ",
            "removed: '", paste0(relevant_col_names, collapse = "', '"),
            "'.\n")
  colMeans(x[, cols, drop = FALSE], na.rm = TRUE)
}
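
## For illustration (hypothetical resample results): an NA in a selected
## column triggers the warning and is dropped from that column's mean:
# folds <- data.frame(Sensitivity = c(0.90, 0.80), F1 = c(NA, 0.85))
# summ_stats(folds, cols = c("Sensitivity", "F1"))
## Sensitivity = 0.85, F1 = 0.85 (plus a warning naming 'F1')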