File: make_groups.R

package info (click to toggle)
r-cran-rsample 1.1.1%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 1,872 kB
  • sloc: sh: 13; makefile: 2
file content (349 lines) | stat: -rw-r--r-- 11,323 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
#' Make groupings for grouped rsplits
#'
#' This function powers grouped resampling by splitting the data based upon
#' a grouping variable and returning the assessment set indices for each
#' split.
#'
#' @inheritParams vfold_cv
#' @param group A variable in `data` (single character or name) used for
#'  grouping observations with the same value to either the analysis or
#'  assessment set within a fold.
#' @param balance If `v` is less than the number of unique groups, how should
#'  groups be combined into folds? Should be one of
#'  `"groups"`, `"observations"`, `"prop"`.
#' @param ... Arguments passed to balance functions.
#'
#' @details
#' Not all `balance` options are accepted -- or make sense -- for all resampling
#'  functions. For instance, `balance = "prop"` assigns groups to folds at
#'  random, meaning that any given observation is not guaranteed to be in one
#'  (and only one) assessment set. That means `balance = "prop"` can't
#'  be used with [group_vfold_cv()], and so isn't an option available for that
#'  function.
#'
#' Similarly, [group_mc_cv()] and its derivatives don't assign data to one (and
#'  only one) assessment set, but rather allow each observation to be in an
#'  assessment set zero-or-more times. As a result, those functions don't have
#'  a `balance` argument, and under the hood always specify `balance = "prop"`
#'  when they call [make_groups()].
#'
#' @return The row indices of `data`, split by their assigned fold with
#'  `split_unnamed()`.
#'
#' @keywords internal
make_groups <- function(data,
                        group,
                        v,
                        balance = c("groups", "observations", "prop"),
                        strata = NULL,
                        ...) {
  rlang::check_dots_used(call = rlang::caller_env())
  balance <- rlang::arg_match(balance, error_call = rlang::caller_env())

  # `seq_len()` rather than `1:nrow(data)`: for zero-row data, `1:0` would
  # produce the bogus indices c(1, 0) instead of an empty vector.
  data_ind <- tibble(
    ..index = seq_len(nrow(data)),
    ..group = group
  )
  data_ind$..group <- as.character(data_ind$..group)

  # Each balance function returns a list with `data_ind` (row index + group)
  # and `keys` (group -> fold lookup table).
  res <- switch(
    balance,
    "groups" = balance_groups(
      data_ind = data_ind,
      v = v,
      strata = strata,
      ...
    ),
    "observations" = balance_observations(
      data_ind = data_ind,
      v = v,
      strata = strata,
      ...
    ),
    "prop" = balance_prop(
      data_ind = data_ind,
      v = v,
      strata = strata,
      ...
    )
  )

  data_ind <- res$data_ind
  keys <- res$keys

  # Coerce both join columns to character so the join below never has to
  # reconcile differing column types.
  data_ind$..group <- as.character(data_ind$..group)
  keys$..group <- as.character(keys$..group)

  # Attach the fold ID to every row, then restore the original row order:
  data_ind <- data_ind %>%
    full_join(keys, by = "..group") %>%
    arrange(..index)
  split_unnamed(data_ind$..index, data_ind$..folds)
}

# Balance folds by number of groups.
#
# Hands off to the stratified implementation when `strata` is supplied and to
# the plain implementation otherwise; `...` is forwarded unchanged.
balance_groups <- function(data_ind, v, strata = NULL, ...) {
  has_strata <- !is.null(strata)
  if (has_strata) {
    return(balance_groups_strata(data_ind, v, strata, ...))
  }
  balance_groups_normal(data_ind, v, ...)
}

# Assign every distinct group to one of `v` folds, without stratification.
#
# Returns a list with the untouched `data_ind` and a `keys` data frame that
# maps each unique `..group` value to a `..folds` ID.
balance_groups_normal <- function(data_ind, v, ...) {
  rlang::check_dots_empty()

  group_ids <- unique(data_ind$..group)
  # Cycle through 1..v until every group has a fold ID, so fold sizes (in
  # groups) differ by at most one; the shuffle decides which group gets which.
  fold_ids <- rep(seq_len(v), length.out = length(group_ids))
  fold_key <- data.frame(
    ..group = group_ids,
    ..folds = sample(fold_ids)
  )

  list(data_ind = data_ind, keys = fold_key)
}

# Assign groups to `v` folds so that each fold receives a similar number of
# groups from every stratum.
#
# `strata` is stored as a `..strata` column on `data_ind`. Returns a list with
# `data_ind` (including that column) and a `keys` data frame mapping
# `..group` to `..folds`.
balance_groups_strata <- function(data_ind, v, strata, ...) {
  rlang::check_dots_empty()

  data_ind$..strata <- strata
  # Create a table that's all the unique group x strata combinations:
  keys <- vctrs::vec_unique(data_ind[c("..group", "..strata")])
  # Create as many fold IDs as there are group x strata,
  # in repeating order (1, 2, ..., n, 1, 2, ..., n)
  folds <- rep(seq_len(v), length.out = nrow(keys))

  # Split the folds based on how many groups are within each strata
  # So if the first strata in sort is 3, and v is 2, that strata gets a
  # c(1, 2, 1) for fold IDs
  #
  # This means that, if nrow(keys) %% v == 0, each fold should have
  # the same number of groups from each strata
  #
  # We randomize "keys" here so that the function is stochastic even for
  # strata with only one group:
  unique_strata <- unique(keys$..strata)
  keys_order <- sample.int(length(unique_strata))

  # Re-order the keys data.frame based on the reshuffled strata variable:
  keys <- keys[
    order(match(keys$..strata, unique_strata[keys_order])),
  ]

  # And split both folds and keys with the reordered strata vector:
  folds <- split_unnamed(folds, keys$..strata)
  keys <- split_unnamed(keys, keys$..strata)

  # Randomly assign fold IDs to each group within each strata
  keys <- purrr::map2(
    keys,
    folds,
    function(x, y) {
      # Permute by position instead of calling `sample(y)`: when `y` is a
      # single number such as 3, `sample(3)` returns a permutation of 1:3
      # (three values), which cannot be assigned to a one-row data frame.
      # For longer `y` this draws exactly the same random numbers.
      x$..folds <- y[sample.int(length(y))]
      x
    }
  )

  keys <- dplyr::bind_rows(keys)
  keys <- keys[c("..group", "..folds")]
  list(
    data_ind = data_ind,
    keys = keys
  )
}

# Combine groups into `v` folds so that each fold holds roughly the same
# number of observations.
#
# Returns whatever `collapse_groups()` builds from the per-group fold
# assignments.
balance_observations <- function(data_ind, v, strata = NULL, ...) {
  rlang::check_dots_empty()
  # Every fold aims to hold this share of the observations:
  target_per_fold <- 1 / v

  # This is the core difference between stratification and not:
  #
  # Without stratification, data_ind is broken into v groups,
  # which are roughly balanced based on the number of observations
  #
  # With strata, data_ind is split up by strata, and then each _split_
  # is broken into v groups (which are then combined with the other strata);
  # the balancing for each fold is done separately inside each strata "split"
  data_splits <- if (is.null(strata)) {
    list(data_ind)
  } else {
    split_unnamed(data_ind, strata)
  }

  # One frequency table (with fold assignments) per split, row-bound together:
  freq_table <- purrr::map_dfr(
    data_splits,
    balance_observations_helper,
    v = v,
    target_per_fold = target_per_fold
  )

  collapse_groups(freq_table, data_ind, v)
}

# Greedily assign the groups in one data split to `v` folds, balancing the
# number of observations per fold.
#
# Returns the group frequency table (columns `key` and `count`, as used by
# the callers of this table) with an added `assignment` column holding each
# group's fold.
balance_observations_helper <- function(data_split, v, target_per_fold) {

  n_obs <- nrow(data_split)
  # Create a frequency table counting how many of each group are in the data:
  freq_table <- vec_count(data_split$..group, sort = "location")
  # Randomly shuffle that table, then assign the first few rows to folds
  # (to ensure that each fold gets at least one group assigned):
  freq_table <- freq_table[sample.int(nrow(freq_table)), ]
  freq_table$assignment <- NA
  # Assign the first `v` rows to folds, so that each fold has _some_ data:
  freq_table$assignment[seq_len(v)] <- seq_len(v)

  # Each run of this loop assigns one "NA" assignment to a fold,
  # so we won't get caught in an endless loop here
  while (any(is.na(freq_table$assignment))) {
    # Get the index of the next row to be assigned, and its count:
    next_row <- which(is.na(freq_table$assignment))[[1]]
    next_size <- freq_table$count[[next_row]]

    # Keep only the groups that already have a fold. Filter on `assignment`
    # alone rather than `na.omit()`-ing the whole table: a group whose key is
    # itself NA would otherwise be dropped from the running totals even
    # after it has been assigned.
    assigned <- freq_table[!is.na(freq_table$assignment), ]

    # Calculate which fold to assign this new row into:
    group_breakdown <- assigned %>%
      # Group by fold assignments and count data in each fold:
      dplyr::group_by(.data$assignment) %>%
      dplyr::summarise(count = sum(.data$count), .groups = "drop") %>%
      # Calculate...:
      dplyr::mutate(
        # The proportion of data in each fold so far,
        prop = .data$count / n_obs,
        # The amount off from the target proportion so far,
        pre_error = abs(.data$prop - target_per_fold),
        # The amount off from the target proportion if we add this new group,
        if_added_count = .data$count + next_size,
        if_added_prop = .data$if_added_count / n_obs,
        post_error = abs(.data$if_added_prop - target_per_fold),
        # And how much better or worse adding this new group would make things
        # (negative values mean the fold ends up closer to the target):
        improvement = .data$post_error - .data$pre_error
      )

    # Assign the group in question to the best fold and move on to the next one:
    most_improved <- which.min(group_breakdown$improvement)
    freq_table$assignment[[next_row]] <-
      group_breakdown$assignment[[most_improved]]
  }
  freq_table
}

# Build `v` resamples that each hold roughly `prop` of the observations,
# sampling whole groups at a time.
#
# `prop` and `replace` are validated by `check_prop()`; the result is whatever
# `collapse_groups()` builds from the sampled groups.
balance_prop <- function(prop, data_ind, v, replace = FALSE, strata = NULL, ...) {
  rlang::check_dots_empty()
  check_prop(prop, replace)

  # Stratified vs. unstratified sampling differ only in what gets sampled:
  #
  # - no strata: the full set of groups is sampled `v` times, each time
  #   aiming for `prop` of the observations;
  # - strata: `data_ind` is first split by stratum, and the same sampling is
  #   run separately inside each stratum's split.
  if (is.null(strata)) {
    strata_splits <- list(data_ind)
  } else {
    strata_splits <- split_unnamed(data_ind, strata)
  }

  sampled_groups <- purrr::map_dfr(
    strata_splits,
    balance_prop_helper,
    prop = prop,
    v = v,
    replace = replace
  )

  collapse_groups(sampled_groups, data_ind, v)
}

# Draw `v` samples of groups from one data split, each covering as close to
# `prop` of the split's observations as whole groups allow.
#
# Returns the sampled rows of the group frequency table, row-bound across
# iterations, with an `assignment` column giving the iteration number.
balance_prop_helper <- function(prop, data_ind, v, replace) {

  freq_table <- vec_count(data_ind$..group, sort = "location")
  # Both of these are constant across iterations, so compute them once:
  n_groups <- nrow(freq_table)
  total_count <- sum(freq_table$count)

  # Calculate how many groups to sample each iteration
  # If sampling with replacement,
  # set `n` to the number of resamples we'd need
  # if we somehow got the smallest group every time.
  # If sampling without replacement, just reshuffle all the groups.
  n <- n_groups
  if (replace) n <- n * prop * total_count / min(freq_table$count)
  n <- ceiling(n)

  purrr::map_dfr(
    seq_len(v),
    function(x) {
      row_idx <- sample.int(n_groups, n, replace = replace)
      work_table <- freq_table[row_idx, ]
      cumulative_proportion <- cumsum(work_table$count) / total_count

      # First sampled group that pushes the running total past `prop`.
      # If nothing goes past it (the sample lands exactly on the target,
      # e.g. a single group with `prop = 1` and `replace = TRUE`), keep the
      # whole sample instead of failing with "subscript out of bounds".
      over_target <- which(cumulative_proportion > prop)
      crosses_target <- if (length(over_target) > 0) {
        over_target[[1]]
      } else {
        length(cumulative_proportion)
      }

      # Stop either at the crossing group or one before it, whichever leaves
      # the cumulative proportion closer to `prop`. When the crossing group is
      # the first one, index 0 is silently dropped, so at least one group is
      # always kept.
      is_closest <- cumulative_proportion[c(crosses_target, crosses_target - 1)]
      is_closest <- which.min(abs(is_closest - prop)) - 1
      crosses_target <- crosses_target - is_closest
      out <- work_table[seq_len(crosses_target), ]
      out$assignment <- x
      out
    }
  )
}

# Validate `prop` for proportional group sampling.
#
# `prop` must be a single, non-missing number greater than 0 and at most 1
# when sampling with replacement, or strictly less than 1 without
# replacement. Anything else aborts with an error attributed to the caller.
check_prop <- function(prop, replace) {
  # Require a scalar, non-NA number up front so the comparisons below always
  # yield a single TRUE/FALSE. Otherwise an NA or length > 1 `prop` would fail
  # inside `&&`/`if` with a generic base-R error instead of the message below.
  acceptable_prop <- is.numeric(prop) && length(prop) == 1 && !is.na(prop)
  acceptable_prop <- acceptable_prop &&
    ((prop <= 1 && replace) || (prop < 1 && !replace))
  acceptable_prop <- acceptable_prop && prop > 0
  if (!acceptable_prop) {
    rlang::abort(
      "`prop` must be a number between 0 and 1.",
      call = rlang::caller_env()
    )
  }
}


# Replace each row's original group with the fold-level assignment found in
# `freq_table`, then build the lookup table from assignment to fold ID.
#
# Returns a list with the relabelled `data_ind` (columns `..index` and
# `..group`) and a `keys` data frame (`..group`, `..folds`).
collapse_groups <- function(freq_table, data_ind, v) {
  relabelled <- dplyr::left_join(
    data_ind,
    freq_table,
    by = c("..group" = "key")
  )
  relabelled$..group <- relabelled$assignment

  # Rows whose group never received an assignment carry an NA `..group`.
  # Left in place, that NA would be treated as one more group and be handed
  # a fold in `seq_len(v)`, corrupting the fold assignments -- so those rows
  # are dropped here.
  relabelled <- stats::na.omit(relabelled[c("..index", "..group")])

  assignment_ids <- unique(relabelled$..group)
  shuffled_folds <- sample(
    rep(seq_len(v), length.out = length(assignment_ids))
  )

  list(
    data_ind = relabelled,
    keys = data.frame(
      ..group = assignment_ids,
      ..folds = shuffled_folds
    )
  )
}

# Resolve `group` to the name of a single column of `data`.
#
# `group` is captured with tidy evaluation and matched against
# `names(data)`. Aborts (with the error attributed to `call`) unless exactly
# one existing column is selected; otherwise returns the selection produced by
# `tidyselect::vars_select()`.
validate_group <- function(group, data, call = rlang::caller_env()) {
  if (missing(group)) {
    # Treat an omitted `group` like an empty selection so it reaches the
    # informative abort below, rather than failing on `is.null(group)` with
    # R's generic "argument is missing, with no default" error.
    group <- NULL
  } else {
    group <- tidyselect::vars_select(names(data), !!enquo(group))
    if (length(group) == 0) {
      group <- NULL
    }
  }

  if (is.null(group) || !is.character(group) || length(group) != 1) {
    rlang::abort(
      "`group` should be a single character value for the column that will be used for splitting.",
      call = call
    )
  }
  if (!any(names(data) == group)) {
    rlang::abort("`group` should be a column in `data`.", call = call)
  }

  group
}