File: initial_split.R

package info (click to toggle)
r-cran-rsample 1.2.1%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,932 kB
  • sloc: sh: 13; makefile: 2
file content (160 lines) | stat: -rw-r--r-- 4,289 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#' Simple Training/Test Set Splitting
#'
#' `initial_split` creates a single binary split of the data into a training
#'  set and testing set. `initial_time_split` does the same, but takes the
#'  _first_ `prop` samples for training, instead of a random selection.
#'  `group_initial_split` creates splits of the data based
#'  on some grouping variable, so that all data in a "group" is assigned to
#'  the same split.
#'  `training` and `testing` are used to extract the resulting data.
#' @template strata_details
#' @inheritParams vfold_cv
#' @inheritParams make_strata
#' @param prop The proportion of data to be retained for modeling/analysis.
#' @export
#' @return An `rsplit` object that can be used with the `training` and `testing`
#'  functions to extract the data in each split.
#' @examplesIf rlang::is_installed("modeldata")
#' set.seed(1353)
#' car_split <- initial_split(mtcars)
#' train_data <- training(car_split)
#' test_data <- testing(car_split)
#'
#' data(drinks, package = "modeldata")
#' drinks_split <- initial_time_split(drinks)
#' train_data <- training(drinks_split)
#' test_data <- testing(drinks_split)
#' c(max(train_data$date), min(test_data$date)) # no lag
#'
#' # With 12 period lag
#' drinks_lag_split <- initial_time_split(drinks, lag = 12)
#' train_data <- training(drinks_lag_split)
#' test_data <- testing(drinks_lag_split)
#' c(max(train_data$date), min(test_data$date)) # 12 period lag
#'
#' set.seed(1353)
#' car_split <- group_initial_split(mtcars, cyl)
#' train_data <- training(car_split)
#' test_data <- testing(car_split)
#'
#' @export
#'
initial_split <- function(data, prop = 3 / 4,
                          strata = NULL, breaks = 4, pool = 0.1, ...) {
  check_dots_empty()

  # Request exactly one resample from `mc_cv()`; `strata` is forwarded with
  # embracing so the caller's unquoted column reaches `mc_cv()` unevaluated.
  resamples <- mc_cv(
    data = data, prop = prop, strata = {{ strata }},
    breaks = breaks, pool = pool, times = 1
  )

  # Pull out the lone split and tag it, keeping its existing classes after
  # the new one.
  first_split <- resamples$splits[[1]]
  class(first_split) <- c("initial_split", class(first_split))
  first_split
}

#' @rdname initial_split
#' @param lag A value to include a lag between the assessment
#'  and analysis set. This is useful if lagged predictors will be used
#'  during training and testing.
#' @export
initial_time_split <- function(data, prop = 3 / 4, lag = 0, ...) {
  check_dots_empty()

  # Validate with scalar, short-circuiting `&&` so that the range comparisons
  # only run once `prop` is known to be a single non-missing number. With the
  # elementwise `|`, a NULL, NA or length > 1 `prop` reached `if ()` as a
  # zero-length, NA or vector condition and failed with a base R error rather
  # than the message below.
  prop_ok <- is.numeric(prop) && length(prop) == 1L && !is.na(prop) &&
    prop > 0 && prop < 1
  if (!prop_ok) {
    rlang::abort("`prop` must be a number on (0, 1).")
  }

  # Same reasoning for `lag`; `is.finite()` also rejects NA/NaN/Inf, for which
  # `lag %% 1 == 0` would evaluate to NA.
  lag_ok <- is.numeric(lag) && length(lag) == 1L && is.finite(lag) &&
    lag %% 1 == 0
  if (!lag_ok) {
    rlang::abort("`lag` must be a whole number.")
  }

  # The first `n_train` rows (in their existing order) form the training set.
  n_train <- floor(nrow(data) * prop)

  if (lag > n_train) {
    rlang::abort("`lag` must be less than or equal to the number of training observations.")
  }

  # `seq_len()` yields an empty index when `n_train` is 0, whereas `1:n_train`
  # would produce c(1, 0). The second index vector starts `lag` rows before the
  # end of the training rows, so the two sets overlap by `lag` rows.
  split <- rsplit(data, seq_len(n_train), (n_train + 1 - lag):nrow(data))
  splits <- list(split)
  ids <- "Resample1"
  rset <- new_rset(splits, ids)

  # Return the single split taken back out of the rset, tagged with the
  # time-split classes ahead of its existing ones.
  res <- rset$splits[[1]]
  class(res) <- c("initial_time_split", "initial_split", class(res))
  res
}

#' @rdname initial_split
#' @export
#' @param x An `rsplit` object produced by `initial_split()` or
#'  `initial_time_split()`.
training <- function(x, ...) {
  # S3 generic: dispatch to the `training` method for the class of `x`.
  UseMethod("training")
}

#' @export
#' @rdname initial_split
training.default <- function(x, ...) {
  # Fallback for classes without a `training` method: always raises an error.
  # The local must stay named `cls` because the cli template below
  # interpolates it by name.
  cls <- class(x)
  cli::cli_abort("No method for objects of class{?es}: {cls}")
}

#' @rdname initial_split
#' @export
training.rsplit <- function(x, ...) {
  # Delegate to `analysis()`, passing `x` and any extra arguments through
  # unchanged; its result is returned as-is.
  analysis(x, ...)
}

#' @rdname initial_split
#' @export
testing <- function(x, ...) {
  # S3 generic: dispatch to the `testing` method for the class of `x`.
  UseMethod("testing")
}

#' @export
#' @rdname initial_split
testing.default <- function(x, ...) {
  # Fallback for classes without a `testing` method: always raises an error.
  # The local must stay named `cls` because the cli template below
  # interpolates it by name.
  cls <- class(x)
  cli::cli_abort("No method for objects of class{?es}: {cls}")
}

#' @rdname initial_split
#' @export
testing.rsplit <- function(x, ...) {
  # Delegate to `assessment()`, passing `x` and any extra arguments through
  # unchanged; its result is returned as-is.
  assessment(x, ...)
}

#' @inheritParams make_groups
#' @rdname initial_split
#' @export
group_initial_split <- function(data, group, prop = 3 / 4, ...,
                                strata = NULL, pool = 0.1) {
  check_dots_empty()

  # `strata` and `pool` are forwarded only when the caller actually supplied
  # `strata`; otherwise `group_mc_cv()` is called without them so that it sees
  # `strata` as missing too.
  resamples <- if (missing(strata)) {
    group_mc_cv(
      data = data,
      group = {{ group }},
      prop = prop,
      times = 1
    )
  } else {
    group_mc_cv(
      data = data,
      group = {{ group }},
      prop = prop,
      times = 1,
      strata = {{ strata }},
      pool = pool
    )
  }

  # Pull out the lone split and tag it, keeping its existing classes last.
  first_split <- resamples$splits[[1]]
  class(first_split) <- c(
    "group_initial_split", "initial_split", class(first_split)
  )
  first_split
}