File: groups.R

package info (click to toggle)
r-cran-rsample 0.0.8-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm, bullseye
  • size: 1,696 kB
  • sloc: sh: 13; makefile: 2
file content (117 lines) | stat: -rw-r--r-- 3,788 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#' Group V-Fold Cross-Validation
#'
#' Group V-fold cross-validation creates splits of the data based
#'  on some grouping variable (which may have more than a single row
#'  associated with it). The function can create as many splits as
#'  there are unique values of the grouping variable or it can
#'  create a smaller set of splits where more than one value is left
#'  out at a time.
#'
#' @param data A data frame.
#' @param group This could be a single character value or a variable
#'  name that corresponds to a variable that exists in the data frame.
#' @param v The number of partitions of the data set. If left
#'  `NULL`, `v` will be set to the number of unique values
#'  in the group.
#' @param ... Not currently used.
#' @export
#' @return A tibble with classes `group_vfold_cv`,
#'  `rset`, `tbl_df`, `tbl`, and `data.frame`.
#'  The results include a column for the data split objects and an
#'  identification variable.
#' @examples
#' set.seed(3527)
#' test_data <- data.frame(id = sort(sample(1:20, size = 80, replace = TRUE)))
#' test_data$dat <- runif(nrow(test_data))
#'
#' set.seed(5144)
#' split_by_id <- group_vfold_cv(test_data, group = "id")
#'
#' get_id_left_out <- function(x)
#'   unique(assessment(x)$id)
#'
#' library(purrr)
#' table(map_int(split_by_id$splits, get_id_left_out))
#'
#' set.seed(5144)
#' split_by_some_id <- group_vfold_cv(test_data, group = "id", v = 7)
#' held_out <- map(split_by_some_id$splits, get_id_left_out)
#' table(unlist(held_out))
#' # number held out per resample:
#' map_int(held_out, length)
#' @export
group_vfold_cv <- function(data, group = NULL, v = NULL, ...) {

  # `group` may be given as a bare column name or as a string, so it is
  # resolved against the column names of `data`. An empty selection is
  # treated the same as not supplying `group` at all.
  if (!missing(group)) {
    selected <- tidyselect::vars_select(names(data), !!enquo(group))
    group <- if (length(selected) == 0) NULL else selected
  }

  # After resolution, `group` has to be exactly one column name.
  is_single_name <-
    !is.null(group) && is.character(group) && length(group) == 1
  if (!is_single_name) {
    stop(
      "`group` should be a single character value for the column ",
      "that will be used for splitting.",
      call. = FALSE
    )
  }

  has_column <- any(names(data) == group)
  if (!has_column) {
    stop("`group` should be a column in `data`.", call. = FALSE)
  }

  resamples <- group_vfold_splits(data = data, group = group, v = v)

  # The holdout indices are removed from each split to save space; they can
  # be derived again later when they are needed.
  resamples$splits <- map(resamples$splits, rm_out)

  # When `v` was not supplied, report the number of splits actually made.
  if (is.null(v)) {
    v <- length(resamples$splits)
  }

  # Every column whose name starts with "id" is an identification variable.
  id_columns <- grepl("^id", names(resamples))

  new_rset(
    splits = resamples$splits,
    ids = resamples[, id_columns],
    attrib = list(v = v, group = group),
    subclass = c("group_vfold_cv", "rset")
  )
}

# Build the split objects for group V-fold cross-validation.
#
# Every unique value of the grouping column is randomly assigned to one of
# `v` folds, and that fold label is joined back onto the rows, so all rows
# sharing a group value carry the same fold label.
#
# Arguments:
#   data   A data frame.
#   group  A single string naming the grouping column in `data`.
#   v      The number of folds. `NULL` means one fold per unique group value.
#
# Returns a tibble with a list column `splits` (one element per fold, built
# by `make_splits()` with class "group_vfold_split") and a column `id`
# produced by `names0()`.
#
# Errors when the grouping column has no values, when `v` is not a single
# positive whole number, or when `v` exceeds the number of unique groups.
group_vfold_splits <- function(data, group, v = NULL) {
  group_values <- getElement(data, group)
  uni_groups <- unique(group_values)
  max_v <- length(uni_groups)

  # Without any group value there is nothing to partition; fail clearly
  # instead of letting `data.frame()` raise an unrelated length error.
  if (max_v < 1) {
    stop(
      "`group` should identify a column of `data` with at least one value.",
      call. = FALSE
    )
  }

  if (is.null(v)) {
    v <- max_v
  } else {
    # `1:v` counts downwards for `v < 1` and truncates fractional values,
    # which would silently create a different number of folds than requested.
    v_is_valid <- is.numeric(v) && length(v) == 1 && !is.na(v) &&
      v >= 1 && v == round(v)
    if (!v_is_valid) {
      stop("`v` should be a single positive whole number.", call. = FALSE)
    }
    # `v` equal to the number of groups is allowed (one group per fold).
    if (v > max_v) {
      stop("`v` should be less than or equal to ", max_v, call. = FALSE)
    }
  }

  data_ind <- data.frame(
    ..index = seq_len(nrow(data)),
    ..group = group_values
  )
  keys <- data.frame(..group = uni_groups)

  # Recycle the fold labels over the groups, then shuffle them so that the
  # group-to-fold assignment is random but as balanced as possible.
  n <- nrow(keys)
  keys$..folds <- sample(rep(seq_len(v), length.out = n))

  # Attach each row's fold label and restore the original row order.
  data_ind <- data_ind %>%
    full_join(keys, by = "..group") %>%
    arrange(..index)

  # Row indices per fold; `vfold_complement()` (defined elsewhere) turns each
  # set of indices into the structure that `make_splits()` expects.
  indices <- split_unnamed(data_ind$..index, data_ind$..folds)
  indices <- lapply(indices, vfold_complement, n = nrow(data))
  split_objs <-
    purrr::map(indices,
               make_splits,
               data = data,
               class = "group_vfold_split")
  tibble::tibble(splits = split_objs,
                 id = names0(length(split_objs), "Resample"))
}

#' @export
print.group_vfold_cv <- function(x, ...) {
  # Header line: a summary of the resampling scheme from `pretty()`.
  cat("#", pretty(x), "\n")

  # Drop the resampling classes so that printing is delegated to the print
  # method of whatever classes remain.
  rset_classes <- c("group_vfold_cv", "rset")
  current <- class(x)
  class(x) <- current[is.na(match(current, rset_classes))]
  print(x, ...)
}