File: case_weights.R

package info (click to toggle)
r-cran-recipes 1.0.4%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 3,636 kB
  • sloc: sh: 37; makefile: 2
file content (282 lines) | stat: -rw-r--r-- 8,214 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
#' Using case weights with recipes
#'
#' Case weights are positive numeric values that may influence how much each
#' data point contributes during preprocessing. There are a variety of
#' situations where case weights can be used.
#'
#' tidymodels packages differentiate _how_ different types of case weights
#' should be used during the entire data analysis process, including
#' preprocessing data, model fitting, performance calculations, etc.
#'
#' The tidymodels packages require users to convert their numeric vectors to a
#' vector class that reflects how these should be used. For example, there are
#' some situations where the weights should not affect operations such as
#' centering and scaling or other preprocessing operations.
#'
#' The types of weights allowed in tidymodels are:
#'
#' * Frequency weights via [hardhat::frequency_weights()]
#' * Importance weights via [hardhat::importance_weights()]
#'
#' More types can be added by request.
#'
#' For recipes, we distinguish between supervised and unsupervised steps.
#' Supervised steps use the outcome in their calculations; these steps use both
#' frequency and importance weights. Unsupervised steps don't use the outcome
#' and only use frequency weights.
#'
#' There are 3 main principles about how case weights are used within recipes.
#' First, the data set that is passed to the `recipe()` function should already
#' have a case weights column in it. This column can be created beforehand using
#' [hardhat::frequency_weights()] or [hardhat::importance_weights()]. Second,
#' there can only be one case weights column in a recipe at any given time.
#' Third, you cannot modify the case weights column with most of the steps or
#' by using the `update_role()` and `add_role()` functions.
#'
#' These principles ensure that you experience minimal surprises when using case
#' weights, as the steps automatically apply case weighted operations when
#' supported. The printing method will additionally show which steps were
#' weighted and which steps ignored the weights because they were of an
#' incompatible type.
#'
#' @name case_weights
#' @seealso [hardhat::frequency_weights()], [hardhat::importance_weights()]
NULL

#' Helpers for steps with case weights
#'
#' These functions can be used to do basic calculations with or without case
#' weights.
#'
#' @param info A data frame from the `info` argument within steps
#' @param .data The training data
#' @param x A numeric vector or a data frame
#' @param wts A vector of case weights
#' @param na_rm A logical value indicating whether `NA`
#'  values should be removed during computations.
#' @param use Used by [correlations()] or [covariances()] to pass argument to
#'   [cor()] or [cov()]. Ignored when `wts` is supplied.
#' @param method Used by [correlations()] or [covariances()] to pass argument to
#'   [cor()] or [cov()]. Ignored when `wts` is supplied.
#' @param unsupervised A logical: is the calling step unsupervised? Unsupervised
#'   steps only use frequency weights.
#' @details
#' [get_case_weights()] is designed for developers of recipe steps, to return
#' the column with the role `"case_weights"` as a vector. It returns `NULL`
#' when no column has that role, and errors when that column is not numeric or
#' when more than one column has that role.
#'
#' For the other functions, rows with missing case weights are removed from
#' calculations.
#'
#' For `averages()`, `medians()`, and `variances()`, missing values in the data
#' (*not* the case weights) only affect the calculations for those rows. With
#' `na_rm = FALSE`, `averages()` and `variances()` return `NA` for any column
#' that contains a missing value. When case weights are supplied,
#' `correlations()` and `covariances()` first remove rows with any missing
#' values (equal to the "complete.obs" strategy in [stats::cor()]); without
#' case weights, missing values are handled according to `use`.
#'
#' `pca_wts()` returns a [stats::prcomp()]-like list (`sdev`, `rotation`,
#' `center`, `scale`) computed from the eigendecomposition of the weighted
#' covariance matrix.
#'
#' `are_weights_used()` is designed for developers of recipe steps and is used
#' inside print method to determine how printing should be done. It returns
#' `NULL` when there are no case weights, `TRUE` for supervised steps, and, for
#' unsupervised steps, whether the weights are frequency weights.
#' @export
#' @name case-weight-helpers
get_case_weights <- function(info, .data) {
  # Rows of `info` whose (non-missing) role marks them as the weight column.
  is_wt_role <- !is.na(info$role) & info$role == "case_weights"
  wt_col <- info$variable[is_wt_role]
  n_wt_cols <- length(wt_col)

  # No case weights in this recipe.
  if (n_wt_cols == 0) {
    return(NULL)
  }
  # More than one weight column is not allowed; this helper always errors.
  if (n_wt_cols > 1) {
    too_many_case_weights(n_wt_cols)
  }

  weights <- .data[[wt_col]]
  if (!is.numeric(weights)) {
    rlang::abort(
      paste0(
        "Column ", wt_col, " has a 'case_weights' role but is not numeric."
      )
    )
  }
  weights
}

# ------------------------------------------------------------------------------

# Signal the error raised when `n` (> 1) columns carry the "case_weights" role.
too_many_case_weights <- function(n) {
  msg <- paste0(
    "There should only be a single column with the role 'case_weights'. ",
    "In these data, there are ", n, " columns."
  )
  rlang::abort(msg)
}

# ------------------------------------------------------------------------------

# Internal workhorse for the case-weight helpers.
#
# Drops rows that have a missing value in any column of `x` or a missing case
# weight, then computes the requested statistic. `NULL` weights are treated as
# unit weights. `statistic` must be one of "mean", "var", "cor", "cov", "pca",
# or "median" (validated with rlang::arg_match()).
wt_calcs <- function(x, wts, statistic = "mean") {
  statistic <- rlang::arg_match(
    statistic,
    c("mean", "var", "cor", "cov", "pca", "median")
  )
  if (!is.data.frame(x)) {
    x <- data.frame(x)
  }

  if (is.null(wts)) {
    wts <- rep(1L, nrow(x))
  }

  complete <- stats::complete.cases(x) & !is.na(wts)
  wts <- wts[complete]
  x <- x[complete, , drop = FALSE]

  if (statistic == "median") {
    # The median does not need the weighted covariance, so skip
    # stats::cov.wt(); it errors when no rows remain or all weights are zero.
    # Take the first column by position: `x$x` only worked when the column
    # happened to be named "x" and is subject to `$` partial matching.
    return(weighted_median_impl(x[[1]], wts))
  }

  res <- stats::cov.wt(x, wt = wts, cor = statistic == "cor")

  switch(
    statistic,
    mean = unname(res[["center"]]),
    var = unname(diag(res[["cov"]])),
    pca = cov2pca(res[["cov"]]),
    cov = res[["cov"]],
    cor = res[["cor"]]
  )
}

#' @export
#' @rdname case-weight-helpers
averages <- function(x, wts = NULL, na_rm = TRUE) {
  # No columns: return an empty numeric vector.
  if (NCOL(x) == 0) {
    return(vapply(x, mean, FUN.VALUE = c(mean = 0), na.rm = TRUE))
  }

  # Both branches drop missing values; `na_rm = FALSE` is applied below.
  res <- if (is.null(wts)) {
    colMeans(x, na.rm = TRUE)
  } else {
    dbl_wts <- as.double(wts)
    purrr::map_dbl(x, function(col) wt_calcs(col, dbl_wts))
  }

  if (!na_rm) {
    has_missing <- purrr::map_lgl(x, function(col) any(is.na(col)))
    res[has_missing] <- NA
  }
  res
}

#' @export
#' @rdname case-weight-helpers
medians <- function(x, wts = NULL) {
  # No columns: return an empty numeric vector.
  if (NCOL(x) == 0) {
    return(vapply(x, median, FUN.VALUE = c(median = 0), na.rm = TRUE))
  }

  # Weighted path: per-column weighted median via wt_calcs().
  if (!is.null(wts)) {
    dbl_wts <- as.double(wts)
    return(
      purrr::map_dbl(
        x,
        function(col) wt_calcs(col, dbl_wts, statistic = "median")
      )
    )
  }

  # Unweighted path: column-wise median ignoring missing values.
  apply(x, MARGIN = 2, FUN = median, na.rm = TRUE)
}

# Weighted median of `x`: sort the values, accumulate each value's share of the
# total weight, and return the first value whose cumulative share strictly
# exceeds 0.5. With equal weights and an even number of values this is the
# upper of the two middle values (stats::median() would average them).
weighted_median_impl <- function(x, wts) {
  ord <- order(x)
  sorted_x <- x[ord]
  sorted_wts <- wts[ord]

  cum_share <- cumsum(sorted_wts) / sum(sorted_wts)
  sorted_x[min(which(cum_share > 0.5))]
}

#' @export
#' @rdname case-weight-helpers
variances <- function(x, wts = NULL, na_rm = TRUE) {
  # No columns: return an empty numeric vector. The template uses
  # stats::var (not `sd`) so it matches what this helper computes.
  if (NCOL(x) == 0) {
    return(vapply(x, stats::var, c(var = 0), na.rm = na_rm))
  }
  if (is.null(wts)) {
    # stats::var() honours `na_rm` directly.
    res <- purrr::map_dbl(x, ~ stats::var(.x, na.rm = na_rm))
  } else {
    wts <- as.double(wts)
    # wt_calcs() always drops incomplete rows, so `na_rm = FALSE` is applied
    # afterwards by blanking out columns that contain a missing value.
    res <- purrr::map_dbl(x, ~ wt_calcs(.x, wts, statistic = "var"))
    if (!na_rm) {
      res[purrr::map_lgl(x, ~ any(is.na(.x)))] <- NA
    }
  }
  res
}

#' @export
#' @rdname case-weight-helpers
correlations <- function(x, wts = NULL, use = "everything", method = "pearson") {
  # `use` and `method` only apply to the unweighted stats::cor() path.
  if (!is.null(wts)) {
    dbl_wts <- as.double(wts)
    return(wt_calcs(x, dbl_wts, statistic = "cor"))
  }
  stats::cor(x, use = use, method = method)
}

#' @export
#' @rdname case-weight-helpers
covariances <- function(x, wts = NULL, use = "everything", method = "pearson") {
  # `use` and `method` only apply to the unweighted stats::cov() path.
  if (!is.null(wts)) {
    dbl_wts <- as.double(wts)
    return(wt_calcs(x, dbl_wts, statistic = "cov"))
  }
  stats::cov(x, use = use, method = method)
}


#' @export
#' @rdname case-weight-helpers
pca_wts <- function(x, wts = NULL) {
  # Only coerce weights that were supplied: `as.double(NULL)` is `numeric(0)`,
  # which is not NULL, so wt_calcs() would not substitute unit weights. Every
  # row would then be dropped and stats::cov.wt() would error.
  if (!is.null(wts)) {
    wts <- as.double(wts)
  }
  res <- wt_calcs(x, wts, statistic = "pca")
  # Mirror the prcomp() fields that downstream code may inspect.
  res$center <- FALSE
  res$scale <- FALSE
  # colnames() works for both data frames and matrices (names() is NULL for a
  # matrix).
  rownames(res$rotation) <- colnames(x)
  res
}

# Convert a covariance matrix into prcomp()-style components via eigen():
# `sdev` holds the square roots of the eigenvalues and `rotation` holds the
# eigenvectors (as columns).
cov2pca <- function(cv_mat) {
  res <- eigen(cv_mat)

  # Eigenvalues of a (near) singular covariance matrix can come back as tiny
  # negative numbers from floating point round-off. Clamp them at zero so
  # `sdev` is 0 rather than NaN (with a warning).
  list(sdev = sqrt(pmax(res$values, 0)), rotation = res$vectors)
}

# Weighted frequency table of `x` via hardhat::weighted_table(). Without
# weights every observation counts once. Non-factor input is converted with
# factor() first.
weighted_table <- function(x, wts = NULL) {
  if (is.null(wts)) {
    wts <- rep(1, length(x))
  }

  fct <- if (is.factor(x)) x else factor(x)
  hardhat::weighted_table(fct, weights = wts)
}

# Can these case weights be used by an unsupervised step? Only frequency
# weights qualify. Errors if `wts` is not a hardhat case-weights vector.
is_unsupervised_weights <- function(wts) {
  if (!hardhat::is_case_weights(wts)) {
    rlang::abort("Must be a case_weights variable")
  }

  hardhat::is_frequency_weights(wts)
}

#' @export
#' @rdname case-weight-helpers
are_weights_used <- function(wts, unsupervised = FALSE) {
  # No case weights at all: nothing to report.
  if (is.null(wts)) {
    return(NULL)
  }

  # Supervised steps use any weight type. Unsupervised steps only use
  # frequency weights.
  if (unsupervised) {
    is_unsupervised_weights(wts)
  } else {
    TRUE
  }
}