File: thresholder.R

package info (click to toggle)
r-cran-caret 7.0-1%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 4,036 kB
  • sloc: ansic: 210; sh: 10; makefile: 2
file content (182 lines) | stat: -rw-r--r-- 6,708 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
#' Generate Data to Choose a Probability Threshold
#' 
#' This function uses the resampling results from a \code{\link{train}}
#'  object to generate performance statistics over a set of probability
#'  thresholds for two-class problems. 
#' 
#' @param x A \code{\link{train}} object where the values of
#'  \code{savePredictions} was either \code{TRUE}, \code{"all"},
#'  or \code{"final"} in \code{\link{trainControl}}. Also, the 
#'  control argument \code{clasProbs} should have been \code{TRUE}.
#' @param threshold A numeric vector of candidate probability thresholds
#'  between [0,1]. If the class probability corresponding to the first
#'  level of the outcome is greater than the threshold, the data point
#'  is classified as that level. 
#' @param final A logical: should only the final tuning parameters
#'   chosen by \code{\link{train}} be used when 
#'   \code{savePredictions = 'all'}?
#' @param statistics A character vector indicating which statistics to
#'   calculate. See details below for possible choices; the default value
#'   \code{"all"} computes all of these.
#' @return A data frame with columns for each of the tuning parameters
#'  from the model along with an additional column called
#'  \code{prob_threshold} for the probability threshold. There are
#'  also columns for summary statistics averaged over resamples with
#'  column names corresponding to the input argument \code{statistics}. 
#' @details The argument \code{statistics} designates the statistics to compute
#'  for each probability threshold. One or more of the following statistics can
#'  be selected:
#'  \itemize{
#'  \item Sensitivity
#'  \item Specificity
#'  \item Pos Pred Value
#'  \item Neg Pred Value
#'  \item Precision
#'  \item Recall
#'  \item F1
#'  \item Prevalence
#'  \item Detection Rate
#'  \item Detection Prevalence
#'  \item Balanced Accuracy
#'  \item Accuracy
#'  \item Kappa
#'  \item J
#'  \item Dist
#' }
#' For a description of these statistics (except the last two), see the
#' documentation of \code{\link{confusionMatrix}}. The last two statistics
#' are Youden's J statistic and the distance to the best possible cutoff (i.e.
#' perfect sensitivity and specificity.
#' @export
#' @importFrom plyr ddply
#' @examples 
#' \dontrun{
#' set.seed(2444)
#' dat <- twoClassSim(500, intercept = -10)
#' table(dat$Class)
#' 
#' ctrl <- trainControl(method = "cv", 
#'                      classProbs = TRUE,
#'                      savePredictions = "all",
#'                      summaryFunction = twoClassSummary)
#' 
#' set.seed(2863)
#' mod <- train(Class ~ ., data = dat, 
#'              method = "rda",
#'              tuneLength = 4,
#'              metric = "ROC",
#'              trControl = ctrl)
#' 
#' resample_stats <- thresholder(mod, 
#'                               threshold = seq(.5, 1, by = 0.05), 
#'                               final = TRUE)
#' 
#' ggplot(resample_stats, aes(x = prob_threshold, y = J)) + 
#'   geom_point()
#' ggplot(resample_stats, aes(x = prob_threshold, y = Dist)) + 
#'   geom_point()
#' ggplot(resample_stats, aes(x = prob_threshold, y = Sensitivity)) + 
#'   geom_point() + 
#'   geom_point(aes(y = Specificity), col = "red")
#' }
thresholder <- function(x, threshold, final = TRUE, statistics = "all") {
  if(!inherits(x, "train"))
    stop("`x` should be an object of class 'train'", 
         call. = FALSE)
  if (!x$control$classProbs)
    stop("`classProbs` must be TRUE in `trainControl`",
         call. = FALSE)
  if (is.null(threshold))
    stop("Please supply probability threshold values.",
         call. = FALSE)
  if (any(threshold > 1 | threshold < 0))
    stop("`threshold` should be on [0,1]", call. = FALSE)
  
  if (is.logical(x$control$savePredictions)) {
    if (!x$control$savePredictions)
      stop("`savePredictions` should be TRUE, 'all', or 'final'")
  } else {
    if (x$control$savePredictions == "none")
      stop("`savePredictions` should be TRUE, 'all', or 'final'")
  }
  if (length(levels(x$pred$obs)) > 2)
    stop("For two class problems only", call. = TRUE)
  
  stat_names <- c("Sensitivity", "Specificity", "Pos Pred Value",
                  "Neg Pred Value", "Precision", "Recall", "F1", "Prevalence",
                  "Detection Rate", "Detection Prevalence", "Balanced Accuracy",
                  "Accuracy", "Kappa", "J", "Dist")
  if (!any(statistics %in% c("all", stat_names)) ||
      ("all" %in% statistics && length(statistics) > 1))
    stop("`statistics` should be either 'all', or one or more of '",
         paste0(stat_names, collapse="', '"), "'.")

  if (length(statistics) == 1 && statistics == "all")
    statistics <- stat_names
  
  disc <- c("pred", "rowIndex", x$levels[-1])
  
  ## Expand the predicted values with the candidate values of
  ## the threshold
  pred_dat <- expand_preds(if (final)
    merge(x$pred, x$bestTune)
    else
      x$pred,
    threshold,
    disc)
  
  param <- c("Resample", names(x$bestTune), "prob_threshold")
  
  ## Based on the threshold, recode the predicted classes
  pred_dat <- ddply(pred_dat, .variables = param, recode)
  
  ## Compute statistics per threshold and tuning parameters
  pred_stats <- ddply(pred_dat, .variables = param, stats)
  
  ## Summarize over resamples
  pred_resamp <- ddply(pred_stats, .variables = param[-1],
                       summ_stats, statistics)
  pred_resamp
}

expand_preds <- function(df, th, excl = NULL) {
  th <- unique(th)
  nth <- length(th)
  ndf <- nrow(df)
  if (!is.null(excl))
    df <- df[, !(names(df) %in% excl), drop = FALSE]
  df <- df[rep(1:nrow(df), times = nth),]
  df$prob_threshold <- rep(th, each = ndf)
  df
}


recode <- function(dat) {
  lvl <- levels(dat$obs)
  dat$pred <- ifelse(dat[, lvl[1]] > dat$prob_threshold,
                     lvl[1], lvl[2])
  dat$pred <- factor(dat$pred, levels = lvl)
  dat
}

stats <- function(dat) {
  tab <- caret::confusionMatrix(dat$pred, dat$obs,
                                positive = levels(dat$obs)[1])
  res <- c(tab$byClass, tab$overall[c("Accuracy", "Kappa")])
  res <- c(res,
           res["Sensitivity"] + res["Specificity"] - 1,
           sqrt((res["Sensitivity"] - 1) ^ 2 + (res["Specificity"] - 1) ^ 2))
  names(res)[-seq_len(length(res) - 2)] <- c("J", "Dist")
  res
}

summ_stats <- function(x, cols) {
  na_cols <- apply(x, 2, function(x) any(is.na(x)))
  na_col_names <- colnames(x)[na_cols]
  relevant_col_names <- intersect(na_col_names, cols)
  if (length(relevant_col_names) > 0)
    warning("The following columns have missing values (NA), which have been ",
            "removed: '", paste0(relevant_col_names, collapse = "', '"),
            "'.\n")
  colMeans(x[, cols, drop = FALSE], na.rm = TRUE)
}