File: lables.R

package info (click to toggle)
r-cran-rsample 0.0.8-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm, bullseye
  • size: 1,696 kB
  • sloc: sh: 13; makefile: 2
file content (275 lines) | stat: -rw-r--r-- 7,058 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
#' Find Labels from rset Object
#'
#' Produce a vector of resampling labels (e.g. "Fold1") from
#'  an `rset` object. Currently, `nested_cv`
#'  is not supported.
#'
#' @param object An `rset` object
#' @param make_factor A logical for whether the results should be
#'  a character or a factor.
#' @param ... Not currently used.
#' @return A single character or factor vector.
#' @export
#' @examples
#' labels(vfold_cv(mtcars))
labels.rset <- function(object, make_factor = FALSE, ...) {
  # Nested resampling objects are rejected outright.
  if (inherits(object, "nested_cv")) {
    stop("`labels` not implemented for nested resampling",
         call. = FALSE)
  }
  # Choose the coercion first, then apply it to the `id` column.
  convert <- if (make_factor) as.factor else as.character
  convert(object$id)
}

#' @rdname labels.rset
#' @export
labels.vfold_cv <- function(object, make_factor = FALSE, ...) {
  # Nested resampling objects are rejected outright.
  if (inherits(object, "nested_cv")) {
    stop("`labels` not implemented for nested resampling",
         call. = FALSE)
  }
  # When the "repeats" attribute exceeds one, both id columns are joined
  # with a dot; otherwise the single `id` column is used as-is.
  ids <- if (attr(object, "repeats") > 1) {
    paste(object$id, object$id2, sep = ".")
  } else {
    as.character(object$id)
  }
  if (make_factor) as.factor(ids) else ids
}

#' Find Labels from rsplit Object
#'
#' Produce a tibble of identification variables so that single
#'  splits can be linked to a particular resample.
#'
#' @param object An `rsplit` object
#' @param ... Not currently used.
#' @return A tibble.
#' @seealso add_resample_id
#' @export
#' @examples
#' cv_splits <- vfold_cv(mtcars)
#' labels(cv_splits$splits[[1]])
labels.rsplit <- function(object, ...) {
  # A split with no stored `id` element yields an empty tibble.
  if (!("id" %in% names(object))) {
    return(tibble())
  }
  object$id
}

## The `pretty` methods below are useful when you need a textual
## description of the resampling procedure. Note that the result can
## have more than one element (in the case of nested resampling).


#' Short Descriptions of rsets
#'
#' Produce a character vector describing the resampling method.
#'
#' @param x An `rset` object
#' @param ... Not currently used.
#' @return A character vector.
#' @export pretty.vfold_cv
#' @export
#' @method pretty vfold_cv
#' @keywords internal
pretty.vfold_cv <- function(x, ...) {
  info <- attributes(x)
  # Start from the fold count, then append the optional qualifiers in order.
  label <- paste0(info$v, "-fold cross-validation")
  if (info$repeats > 1) {
    label <- paste(label, "repeated", info$repeats, "times")
  }
  if (info$strata) {
    label <- paste(label, "using stratification")
  }
  label
}

#' @export pretty.loo_cv
#' @export
#' @method pretty loo_cv
#' @rdname pretty.vfold_cv
pretty.loo_cv <- function(x, ...) {
  # Fixed description; nothing is read from `x`.
  "Leave-one-out cross-validation"
}

#' @export pretty.apparent
#' @export
#' @method pretty apparent
#' @rdname pretty.vfold_cv
pretty.apparent <- function(x, ...) {
  # Fixed description; nothing is read from `x`.
  "Apparent sampling"
}

#' @export pretty.rolling_origin
#' @export
#' @method pretty rolling_origin
#' @rdname pretty.vfold_cv
pretty.rolling_origin <- function(x, ...) {
  # Fixed description; nothing is read from `x`.
  "Rolling origin forecast resampling"
}

#' @export pretty.sliding_window
#' @export
#' @method pretty sliding_window
#' @rdname pretty.vfold_cv
pretty.sliding_window <- function(x, ...) {
  # Fixed description; nothing is read from `x`.
  "Sliding window resampling"
}

#' @export pretty.sliding_index
#' @export
#' @method pretty sliding_index
#' @rdname pretty.vfold_cv
pretty.sliding_index <- function(x, ...) {
  # Fixed description; nothing is read from `x`.
  "Sliding index resampling"
}

#' @export pretty.sliding_period
#' @export
#' @method pretty sliding_period
#' @rdname pretty.vfold_cv
pretty.sliding_period <- function(x, ...) {
  # Fixed description; nothing is read from `x`.
  "Sliding period resampling"
}

#' @export pretty.mc_cv
#' @export
#' @method pretty mc_cv
#' @rdname pretty.vfold_cv
pretty.mc_cv <- function(x, ...) {
  info <- attributes(x)
  # Both shares are rounded to two significant digits for display.
  share <- signif(info$prop, 2)
  remainder <- signif(1 - info$prop, 2)
  label <- paste0(
    "Monte Carlo cross-validation (", share, "/", remainder,
    ") with ", info$times, " resamples "
  )
  if (info$strata) {
    label <- paste(label, "using stratification")
  }
  label
}

#' @export pretty.validation_split
#' @export
#' @method pretty validation_split
#' @rdname pretty.vfold_cv
pretty.validation_split <- function(x, ...) {
  info <- attributes(x)
  # Both shares are rounded to two significant digits for display.
  share <- signif(info$prop, 2)
  remainder <- signif(1 - info$prop, 2)
  label <- paste0("Validation Set Split (", share, "/", remainder, ") ")
  if (info$strata) {
    label <- paste(label, "using stratification")
  }
  label
}

#' @export pretty.nested_cv
#' @export
#' @method pretty nested_cv
#' @rdname pretty.vfold_cv
pretty.nested_cv <- function(x, ...) {
  info <- attributes(x)

  # Render a stored (non-call) specification as backtick-quoted text.
  quote_spec <- function(spec) paste0("`", deparse(spec), "`")

  if (is_call(info$outside)) {
    # Remove "nested_cv" from the class vector so that `pretty()` dispatches
    # on the remaining classes of `x`.
    class(x) <- class(x)[class(x) != "nested_cv"]
    outer_label <- pretty(x)
  } else {
    outer_label <- quote_spec(info$outside)
  }

  if (is_call(info$inside)) {
    # The first inner resample is used to describe the inner scheme.
    inner_label <- pretty(x$inner_resamples[[1]])
  } else {
    inner_label <- quote_spec(info$inside)
  }

  c(
    "Nested resampling:",
    paste(" outer:", outer_label),
    paste(" inner:", inner_label)
  )
}

#' @export pretty.bootstraps
#' @export
#' @method pretty bootstraps
#' @rdname pretty.vfold_cv
pretty.bootstraps <- function(x, ...) {
  info <- attributes(x)
  label <- "Bootstrap sampling"
  # Each flag, when set, appends its own trailing phrase.
  if (info$strata) {
    label <- paste(label, "using stratification")
  }
  if (info$apparent) {
    label <- paste(label, "with apparent sample")
  }
  label
}


#' @export pretty.group_vfold_cv
#' @export
#' @method pretty group_vfold_cv
#' @rdname pretty.vfold_cv
pretty.group_vfold_cv <- function(x, ...) {
  # Only the `v` attribute is needed for the description.
  folds <- attributes(x)$v
  paste0("Group ", folds, "-fold cross-validation")
}

#' @export pretty.manual_rset
#' @export
#' @method pretty manual_rset
#' @rdname pretty.vfold_cv
pretty.manual_rset <- function(x, ...) {
  # Fixed description; nothing is read from `x`.
  description <- "Manual resampling"
  description
}


#' Augment a data set with resampling identifiers
#'
#' For a data set, `add_resample_id()` will add at least one new column that
#'  identifies which resample that the data came from. In most cases, a single
#'  column is added but for some resampling methods, two or more are added.
#' @param .data A data frame
#' @param split A single `rsplit` object.
#' @param dots A single logical: should the id columns be prefixed with a "."
#'  to avoid name conflicts with `.data`?
#' @return An updated data frame.
#' @examples
#' library(dplyr)
#'
#' set.seed(363)
#' car_folds <- vfold_cv(mtcars, repeats = 3)
#'
#' analysis(car_folds$splits[[1]]) %>%
#'   add_resample_id(car_folds$splits[[1]]) %>%
#'   head()
#'
#' car_bt <- bootstraps(mtcars)
#'
#' analysis(car_bt$splits[[1]]) %>%
#'   add_resample_id(car_bt$splits[[1]]) %>%
#'   head()
#' @seealso labels.rsplit
#' @export
add_resample_id <- function(.data, split, dots = FALSE) {
  # `dots` is used as an `if ()` condition below, so it must be exactly one
  # non-missing logical; a zero-length or `NA` value would otherwise fail
  # there with an unrelated error.
  if (!is.logical(dots) || length(dots) != 1 || is.na(dots)) {
    stop("`dots` should be a single logical.", call. = FALSE)
  }
  if (!inherits(.data, "data.frame")) {
    stop("`.data` should be a data frame.", call. = FALSE)
  }
  if (!inherits(split, "rsplit")) {
    stop("`split` should be a single 'rsplit' object.", call. = FALSE)
  }
  labs <- labels(split)

  # The labels must be a one-row tibble so that `cbind()` can repeat them
  # across every row of `.data`; anything else is rejected here instead of
  # failing (or misbehaving) inside `cbind()`.
  if (!tibble::is_tibble(labs) || nrow(labs) != 1) {
    stop("`split` should be a single 'rsplit' object.", call. = FALSE)
  }

  if (dots) {
    colnames(labs) <- paste0(".", colnames(labs))
  }

  cbind(.data, labs)
}