File: mc.R

package info (click to toggle)
r-cran-rsample 0.0.8-1
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 1,696 kB
  • sloc: sh: 13; makefile: 2
file content (127 lines) | stat: -rw-r--r-- 4,510 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#' Monte Carlo Cross-Validation
#'
#' One resample of Monte Carlo cross-validation takes a random sample (without
#'  replacement) of the original data set to be used for analysis. All other
#'  data points are added to the assessment set.
#' @details The `strata` argument causes the random sampling to be conducted
#'  *within the stratification variable*. This can help ensure that the
#'  proportions of the stratification variable in the analysis data are
#'  equivalent to those in the original data set. (Strata below 10% of the
#'  total are pooled together.)
#' @inheritParams vfold_cv
#' @param prop The proportion of data to be retained for modeling/analysis.
#' @param times The number of times to repeat the sampling.
#' @param strata A variable that is used to conduct stratified sampling to
#'  create the resamples. This could be a single character value or a variable
#'  name that corresponds to a variable that exists in the data frame.
#' @param breaks A single number giving the number of bins desired to stratify
#'  a numeric stratification variable.
#' @export
#' @return A tibble with classes `mc_cv`, `rset`, `tbl_df`, `tbl`, and
#'  `data.frame`. The results include a column for the data split objects and a
#'  column called `id` that has a character string with the resample identifier.
#' @examples
#' mc_cv(mtcars, times = 2)
#' mc_cv(mtcars, prop = .5, times = 2)
#'
#' library(purrr)
#' data(wa_churn, package = "modeldata")
#'
#' set.seed(13)
#' resample1 <- mc_cv(wa_churn, times = 3, prop = .5)
#' map_dbl(resample1$splits,
#'         function(x) {
#'           dat <- as.data.frame(x)$churn
#'           mean(dat == "Yes")
#'         })
#'
#' set.seed(13)
#' resample2 <- mc_cv(wa_churn, strata = "churn", times = 3, prop = .5)
#' map_dbl(resample2$splits,
#'         function(x) {
#'           dat <- as.data.frame(x)$churn
#'           mean(dat == "Yes")
#'         })
#'
#' set.seed(13)
#' resample3 <- mc_cv(wa_churn, strata = "tenure", breaks = 6, times = 3, prop = .5)
#' map_dbl(resample3$splits,
#'         function(x) {
#'           dat <- as.data.frame(x)$churn
#'           mean(dat == "Yes")
#'         })
#' @export
mc_cv <- function(data, prop = 3/4, times = 25, strata = NULL, breaks = 4, ...) {

  # Resolve `strata` (bare name or string) to a column name of `data`; an
  # empty selection is treated as "no stratification".
  if (!missing(strata)) {
    strata <- tidyselect::vars_select(names(data), !!enquo(strata))
    if (length(strata) == 0) {
      strata <- NULL
    }
  }

  strata_check(strata, names(data))

  # Validate `prop` before doing arithmetic on it. Without this guard a
  # non-numeric, NULL or NA value fails inside `1 - prop` (or in the `if`
  # condition further down) with a message that never mentions `prop`.
  if (!is.numeric(prop) || length(prop) != 1L || is.na(prop) ||
      prop >= 1 || prop <= 0) {
    stop("`prop` must be a number on (0, 1).", call. = FALSE)
  }

  # `mc_splits()` samples the rows that end up in the assessment set, so it
  # is given the complement of the analysis proportion.
  split_objs <-
    mc_splits(data = data,
              prop = 1 - prop,
              times = times,
              strata = strata,
              breaks = breaks)

  ## We remove the holdout indices since it will save space and we can
  ## derive them later when they are needed.

  split_objs$splits <- purrr::map(split_objs$splits, rm_out)

  # Attributes stored on the rset: the user-facing `prop`, the number of
  # resamples, and whether stratification was used.
  mc_att <- list(prop = prop,
                 times = times,
                 strata = !is.null(strata))

  new_rset(splits = split_objs$splits,
           ids = split_objs$id,
           attrib = mc_att,
           subclass = c("mc_cv", "rset"))
}

# Derive the analysis indices as the complement of the assessment indices.
#
# `ind` holds the row indices assigned to the assessment set and `n` is the
# total number of rows. Every index in `seq_len(n)` that is not in `ind` goes
# to the analysis set. Returns a list with elements `analysis` and
# `assessment`.
mc_complement <- function(ind, n) {
  # `seq_len()` rather than `1:n`, which would yield c(1, 0) when `n` is 0.
  list(analysis = setdiff(seq_len(n), ind),
       assessment = ind)
}


# Create the split objects and resample ids for Monte Carlo cross-validation.
#
# Note that `prop` here is the fraction of rows that gets *sampled*: the
# sampled indices are handed to `mc_complement()` as the assessment set and
# the remaining rows form the analysis set. `times` is the number of
# resamples, `strata` is NULL or the name of a column of `data`, and `breaks`
# is forwarded to `make_strata()`.
#
# Returns a list with `splits` (objects built by `make_splits()` with class
# "mc_split") and `id` (labels from `names0()` with the prefix "Resample").
mc_splits <- function(data, prop = 3/4, times = 25, strata = NULL, breaks = 4) {
  # Scalar `||` short-circuits, so the comparisons only run on a single,
  # non-missing number; NULL, NA and vector inputs get the intended message
  # instead of an error about the `if` condition itself.
  if (!is.numeric(prop) || length(prop) != 1L || is.na(prop) ||
      prop >= 1 || prop <= 0) {
    stop("`prop` must be a number on (0, 1).", call. = FALSE)
  }

  n <- nrow(data)
  if (is.null(strata)) {
    # One call to `sample(n, size)` per resample.
    indices <- purrr::map(rep(n, times), sample, size = floor(n * prop))
  } else {
    # Pair every row index with its stratum label.
    stratas <- tibble::tibble(idx = seq_len(n),
                              strata = make_strata(getElement(data, strata),
                                                   breaks = breaks))
    # Sample inside each stratum separately; `strat_sample()` tags every
    # sampled index with its resample number in `rs_id`.
    stratas <- split_unnamed(stratas, stratas$strata)
    stratas <-
      purrr::map_df(stratas, strat_sample, prop = prop, times = times)
    # Regroup the sampled indices by resample number.
    indices <- split_unnamed(stratas$idx, stratas$rs_id)
  }
  indices <- lapply(indices, mc_complement, n = n)
  split_objs <-
    purrr::map(indices, make_splits, data = data, class = "mc_split")
  list(splits = split_objs,
       id = names0(length(split_objs), "Resample"))
}

# Draw `times` random samples of row indices from a single stratum.
#
# `x` is the data for one stratum and is expected to carry an `idx` column;
# `prop` is the fraction of its rows drawn in each sample; extra arguments in
# `...` are forwarded to `sample()`. The result holds the sampled `idx` values
# (sorted by position within each draw) plus an `rs_id` column giving the
# number of the draw each row belongs to.
strat_sample <- function(x, prop, times, ...) {
  stratum_size <- nrow(x)
  n_take <- floor(stratum_size * prop)

  # One draw of row positions per resample.
  draws <- purrr::map(rep(stratum_size, times), sample, size = n_take, ...)

  # Look up the `idx` values for each draw and stack the draws row-wise.
  pick_idx <- function(ind, x) x[sort(ind), "idx"]
  sampled <- purrr::map_df(draws, pick_idx, x = x)

  # Every draw contributes `n_take` consecutive rows.
  sampled$rs_id <- rep(1:times, each = n_take)
  sampled
}

#' @export
print.mc_cv <- function(x, ...) {
  # Header line: a "#" followed by whatever `pretty()` returns for `x`.
  header <- pretty(x)
  cat("#", header, "\n")

  # Strip the resampling classes so the print method of the remaining
  # classes handles the rest of the output.
  all_classes <- class(x)
  class(x) <- all_classes[is.na(match(all_classes, c("mc_cv", "rset")))]
  print(x, ...)
}