File: model-matrix.R

package info (click to toggle)
r-cran-hardhat 1.2.0%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 1,656 kB
  • sloc: sh: 13; makefile: 2
file content (214 lines) | stat: -rw-r--r-- 5,859 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
#' Construct a design matrix
#'
#' `model_matrix()` is a stricter version of [stats::model.matrix()]. Notably,
#' `model_matrix()` will _never_ drop rows, and the result will be a tibble.
#'
#' @param terms A terms object to construct a model matrix with. This is
#' typically the terms object returned from the corresponding call to
#' [model_frame()].
#'
#' @param data A tibble to construct the design matrix with. This is
#' typically the tibble returned from the corresponding call to
#' [model_frame()].
#'
#' @details
#'
#' The following explains the rationale for some of the differences in
#' arguments compared to [stats::model.matrix()]:
#'
#' - `contrasts.arg`: Not allowed. Instead, set the contrasts option,
#' `options("contrasts")`, globally, or assign a contrast to the factor of
#' interest directly using [stats::contrasts()]. See the examples section.
#'
#' - `xlev`: Not allowed because `model.frame()` is never called, so it is
#' unnecessary.
#'
#' - `...`: Not allowed because the default method of `model.matrix()` does
#' not use it, and the `lm` method uses it to pass potential offsets and
#' weights through, which are handled differently in hardhat.
#'
#' @return
#'
#' A tibble containing the design matrix.
#'
#' @examples
#' # ---------------------------------------------------------------------------
#' # Example usage
#'
#' framed <- model_frame(Sepal.Width ~ Species, iris)
#'
#' model_matrix(framed$terms, framed$data)
#'
#' # ---------------------------------------------------------------------------
#' # Missing values never result in dropped rows
#'
#' iris2 <- iris
#' iris2$Species[1] <- NA
#'
#' framed2 <- model_frame(Sepal.Width ~ Species, iris2)
#'
#' model_matrix(framed2$terms, framed2$data)
#'
#' # ---------------------------------------------------------------------------
#' # Contrasts
#'
#' # Default contrasts
#' y <- factor(c("a", "b"))
#' x <- data.frame(y = y)
#' framed <- model_frame(~y, x)
#'
#' # Setting contrasts directly
#' y_with_contrast <- y
#' contrasts(y_with_contrast) <- contr.sum(2)
#' x2 <- data.frame(y = y_with_contrast)
#' framed2 <- model_frame(~y, x2)
#'
#' # Compare!
#' model_matrix(framed$terms, framed$data)
#' model_matrix(framed2$terms, framed2$data)
#'
#' # Also, can set the contrasts globally
#' global_override <- c(unordered = "contr.sum", ordered = "contr.poly")
#'
#' rlang::with_options(
#'   .expr = {
#'     model_matrix(framed$terms, framed$data)
#'   },
#'   contrasts = global_override
#' )
#' @export
model_matrix <- function(terms, data) {
  validate_is_terms(terms)
  data <- check_is_data_like(data)

  # Attach `terms` to the data up front. Without this attribute,
  # model.matrix() would try to run model.frame() on `data` itself; that work
  # has already been done, and repeating it can actually raise an error.
  attr(data, "terms") <- terms

  # Build the design matrix with `na.action` temporarily set to "na.pass"
  design <- with_options(
    model.matrix(object = terms, data = data),
    na.action = "na.pass"
  )

  # Keep only the dimensions and column names, then hand back a tibble
  tibble::as_tibble(strip_model_matrix(design))
}

# Reduce a model matrix to a bare matrix.
#
# Every attribute of `x` is discarded and replaced with just its `dim` and a
# `dimnames` that has no row names and the original column names.
strip_model_matrix <- function(x) {
  kept <- list(
    dim = dim(x),
    dimnames = list(NULL, colnames(x))
  )

  attributes(x) <- kept

  x
}

# Predicate: does `x` inherit from the "terms" class?
is_terms <- function(x) {
  inherits(x, what = "terms")
}

# Validate that `.x` is a terms object.
#
# `.x_nm` is the label forwarded to `validate_is()`. When the caller does not
# supply it, the label is derived from the expression that was passed as `.x`.
validate_is_terms <- function(.x, .x_nm) {
  label <- if (is_missing(.x_nm)) {
    as_label(enexpr(.x))
  } else {
    .x_nm
  }

  validate_is(.x, is_terms, "terms object", label)
}

# ------------------------------------------------------------------------------

model_matrix_one_hot <- function(terms, data) {
  validate_is_terms(terms)
  data <- check_is_data_like(data)

  # Character columns are turned into factors first, so that they also
  # receive the one-hot contrast below
  is_chr <- vapply(data, is.character, logical(1))

  for (position in seq_along(data)[is_chr]) {
    data[[position]] <- factor(data[[position]])
  }

  # Only unordered factors are given the one-hot contrast
  is_target <- vapply(data, is_unordered_factor, logical(1))
  targets <- names(data)[is_target]

  # Attach the `contr_one_hot()` contrasts to each of those factors ahead of
  # time. Otherwise `model.matrix()` would fall back to the default taken
  # from `getOption("contrasts")`.
  for (target in targets) {
    fct <- data[[target]]
    fct_levels <- levels(fct)
    one_hot <- contr_one_hot(fct_levels)

    data[[target]] <- assign_contrasts(fct, length(fct_levels), one_hot)
  }

  model_matrix(terms, data)
}

#' Contrast function for one-hot encodings
#'
#' This contrast function produces a model matrix that has indicator columns for
#' each level of each factor.
#'
#' @param n A character vector of factor levels, or the number of unique
#'   levels.
#' @param contrasts This argument is for backwards compatibility and only the
#'   default of `TRUE` is supported.
#' @param sparse This argument is for backwards compatibility and only the
#'   default of `FALSE` is supported.
#'
#' @return A diagonal matrix that is `n`-by-`n`.
#'
#' @keywords internal
contr_one_hot <- function(n, contrasts = TRUE, sparse = FALSE) {
  # Non-default values of the compatibility arguments only trigger a warning
  if (sparse) {
    warn("`sparse = TRUE` not implemented for `contr_one_hot()`.")
  }

  if (!contrasts) {
    warn("`contrasts = FALSE` not implemented for `contr_one_hot()`.")
  }

  # Anything other than character levels or a numeric count is rejected
  if (!is.character(n) && !is.numeric(n)) {
    abort("`n` must be a character vector or an integer of size 1.")
  }

  if (is.character(n)) {
    # The levels were supplied directly
    lvls <- n
  } else {
    # A count was supplied; the levels are labelled "1", "2", ...
    size <- as.integer(n)

    if (length(size) != 1L) {
      abort("`n` must have length 1 when an integer is provided.")
    }

    lvls <- as.character(seq_len(size))
  }

  # Identity matrix with one row and one column per level
  encoding <- diag(length(lvls))
  dimnames(encoding) <- list(lvls, lvls)

  encoding
}

# Predicate: is `x` a factor that is not an ordered factor?
is_unordered_factor <- function(x) {
  if (!is.factor(x)) {
    return(FALSE)
  }

  !is.ordered(x)
}

# Functional form of `contrasts(x, how_many) <- value`: returns `x` with the
# contrasts in `value` attached, rather than modifying a variable in place.
assign_contrasts <- function(x, how_many, value) {
  stats::`contrasts<-`(x = x, how.many = how_many, value = value)
}