File: loo_compare.R

package info (click to toggle)
r-cran-loo 2.9.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 4,836 kB
  • sloc: sh: 15; makefile: 2
file content (352 lines) | stat: -rw-r--r-- 12,188 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
#' Model comparison
#'
#' @description Compare fitted models based on [ELPD][loo-glossary].
#'
#'   By default the print method shows only the most important information. Use
#'   `print(..., simplify=FALSE)` to print a more detailed summary.
#'
#' @export
#' @param x An object of class `"loo"` or a list of such objects. If a list is
#'   used then the list names will be used as the model names in the output. See
#'   **Examples**.
#' @param ... Additional objects of class `"loo"`, if not passed in as a single
#'   list.
#'
#' @return A matrix with class `"compare.loo"` that has its own
#'   print method. See the **Details** section.
#'
#' @details
#'   When comparing two fitted models, we can estimate the difference in their
#'   expected predictive accuracy by the difference in
#'   [`elpd_loo`][loo-glossary] or `elpd_waic` (or multiplied by \eqn{-2}, if
#'   desired, to be on the deviance scale).
#'
#'   When using `loo_compare()`, the returned matrix will have one row per model
#'   and several columns of estimates. The values in the
#'   [`elpd_diff`][loo-glossary] and [`se_diff`][loo-glossary] columns of the
#'   returned matrix are computed by making pairwise comparisons between each
#'   model and the model with the largest ELPD (the model in the first row). For
#'   this reason the `elpd_diff` column will always have the value `0` in the
#'   first row (i.e., the difference between the preferred model and itself) and
#'   negative values in subsequent rows for the remaining models.
#'
#'   To compute the standard error of the difference in [ELPD][loo-glossary] ---
#'   which should not be expected to equal the difference of the standard errors
#'   --- we use a paired estimate to take advantage of the fact that the same
#'   set of \eqn{N} data points was used to fit both models. These calculations
#'   should be most useful when \eqn{N} is large, because then non-normality of
#'   the distribution is not such an issue when estimating the uncertainty in
#'   these sums. These standard errors, for all their flaws, should give a
#'   better sense of uncertainty than what is obtained using the current
#'   standard approach of comparing differences of deviances to a Chi-squared
#'   distribution, a practice derived for Gaussian linear models or
#'   asymptotically, and which only applies to nested models in any case.
#'   Sivula et al. (2022) discuss the conditions when the normal
#'   approximation used for SE and `se_diff` is good.
#'
#'   If more than \eqn{11} models are compared, we internally recompute the model
#'   differences using the median model by ELPD as the baseline model. We then
#'   estimate whether the differences in predictive performance are potentially
#'   due to chance as described by McLatchie and Vehtari (2023). This will flag
#'   a warning if it is deemed that there is a risk of over-fitting due to the
#'   selection process. In that case users are recommended to avoid model
#'   selection based on LOO-CV, and instead to favor model averaging/stacking or
#'   projection predictive inference.
#' 
#' @seealso
#' * The [FAQ page](https://mc-stan.org/loo/articles/online-only/faq.html) on
#'   the __loo__ website for answers to frequently asked questions.
#' @template loo-and-compare-references
#'
#' @examples
#' # very artificial example, just for demonstration!
#' LL <- example_loglik_array()
#' loo1 <- loo(LL)     # should be worst model when compared
#' loo2 <- loo(LL + 1) # should be second best model when compared
#' loo3 <- loo(LL + 2) # should be best model when compared
#'
#' comp <- loo_compare(loo1, loo2, loo3)
#' print(comp, digits = 2)
#'
#' # show more details with simplify=FALSE
#' # (will be the same for all models in this artificial example)
#' print(comp, simplify = FALSE, digits = 3)
#'
#' # can use a list of objects with custom names
#' # will use apple, banana, and cherry, as the names in the output
#' loo_compare(list("apple" = loo1, "banana" = loo2, "cherry" = loo3))
#'
#' \dontrun{
#' # works for waic (and kfold) too
#' loo_compare(waic(LL), waic(LL - 10))
#' }
#'
loo_compare <- function(x, ...) {
  # S3 generic: the method is selected from the class of `x`.
  UseMethod("loo_compare")
}

#' @rdname loo_compare
#' @export
loo_compare.default <- function(x, ...) {
  # Collect the objects to compare into a single list. Either `x` is itself a
  # "loo" object (and any further objects arrive through `...`), or `x` is
  # already a non-empty list and `...` must be empty.
  if (!is.loo(x)) {
    if (!is.list(x) || !length(x)) {
      stop("'x' must be a list if not a 'loo' object.")
    }
    if (length(list(...))) {
      stop("If 'x' is a list then '...' should not be specified.")
    }
    loos <- x
  } else {
    loos <- c(list(x), list(...))
  }

  # Objects computed with subsampling are handed to their own routine.
  uses_subsampling <- vapply(loos, inherits, logical(1), what = "psis_loo_ss")
  if (any(uses_subsampling)) {
    return(loo_compare.psis_loo_ss_list(loos))
  }

  loo_compare_checks(loos)

  est_table <- loo_compare_matrix(loos)
  ord <- loo_compare_order(loos)
  model_labels <- rownames(est_table)

  # Pointwise differences of every model (in rank order) against the
  # top-ranked model: one column per model.
  pointwise_diffs <- mapply(FUN = elpd_diffs, loos[ord[1]], loos[ord])
  result <- cbind(
    elpd_diff = apply(pointwise_diffs, 2, sum),
    se_diff = apply(pointwise_diffs, 2, se_elpd_diff),
    est_table
  )
  rownames(result) <- model_labels

  # Order-statistic based check; may emit a warning.
  loo_order_stat_check(loos, ord)

  structure(result, class = c("compare.loo", class(result)))
}

#' @rdname loo_compare
#' @export
#' @param digits For the print method only, the number of digits to use when
#'   printing.
#' @param simplify For the print method only, should only the essential columns
#'   of the summary matrix be printed? The entire matrix is always returned, but
#'   by default only the most important columns are printed.
print.compare.loo <- function(x, ..., digits = 1, simplify = TRUE) {
  shown <- x
  # Columns are only dropped when there are at least two of them and the
  # caller asked for the simplified view.
  if (NCOL(shown) >= 2 && simplify) {
    shown <- if (inherits(shown, "old_compare.loo")) {
      # Objects carrying the "old_compare.loo" class: keep columns by name
      # pattern.
      shown[, grepl("^elpd_|^se_diff|^p_|^waic$|^looic$", colnames(shown))]
    } else {
      shown[, c("elpd_diff", "se_diff")]
    }
  }
  print(.fr(shown, digits), quote = FALSE)
  # Return the untouched input so the call can be chained.
  invisible(x)
}



# internal ----------------------------------------------------------------

#' Pointwise elpd differences between two models
#'
#' The elpd column(s) are located by name (`^elpd`) in the pointwise matrix of
#' `loo_a`; the same column positions are then used for `loo_b`.
#' @noRd
#' @param loo_a,loo_b Two `"loo"` objects.
#' @return The pointwise elpd values of `loo_b` minus those of `loo_a`.
elpd_diffs <- function(loo_a, loo_b) {
  elpd_col <- grep("^elpd", colnames(loo_a$pointwise))
  loo_b$pointwise[, elpd_col] - loo_a$pointwise[, elpd_col]
}

#' Standard error of a summed elpd difference
#' @noRd
#' @param diffs Vector of pointwise elpd differences
#' @return `sd(diffs)` scaled by the square root of the number of points.
se_elpd_diff <- function(diffs) {
  # `elpd_diff` is a sum over the pointwise components, so the standard error
  # of the sum is the component standard deviation times sqrt(n).
  n_points <- length(diffs)
  sd(diffs) * sqrt(n_points)
}


#' Validate a list of `"loo"` objects before they are compared
#' @noRd
#' @param loos List of `"loo"` objects.
#' @return Nothing, just possibly throws errors/warnings.
loo_compare_checks <- function(loos) {
  ## hard failures ---------------------------------------------------------
  n_models <- length(loos)
  if (n_models <= 1L) {
    stop("'loo_compare' requires at least two models.", call.=FALSE)
  }

  is_loo_obj <- sapply(loos, is.loo)
  if (!all(is_loo_obj)) {
    stop("All inputs should have class 'loo'.", call.=FALSE)
  }

  # Every model must have the same number of rows in its pointwise matrix.
  Ns <- vapply(loos, function(obj) nrow(obj$pointwise), integer(1))
  if (!all(Ns == Ns[1L])) {
    per_model <- paste0("'", find_model_names(loos), "' (", Ns, ")")
    stop(
      "All models must have the same number of observations, but models have inconsistent observation counts: ",
      paste(per_model, collapse = ", "),
      call. = FALSE
    )
  }

  ## soft failures ---------------------------------------------------------

  # Each 'yhash' attribute is compared with the first model's; a set of
  # attributes that are all NULL counts as matching.
  yhash <- lapply(loos, attr, which = "yhash")
  reference_hash <- yhash[[1]]
  hash_matches <- sapply(yhash, function(h) isTRUE(all.equal(h, reference_hash)))
  if (!all(hash_matches)) {
    warning("Not all models have the same y variable. ('yhash' attributes do not match)",
            call. = FALSE)
  }

  kfold_flags <- sapply(loos, is.kfold)
  if (all(kfold_flags)) {
    # Only kfold objects: their 'K' attributes should agree.
    Ks <- unlist(lapply(loos, attr, which = "K"))
    if (!all(Ks == Ks[1])) {
      warning("Not all kfold objects have the same K value. ",
              "For a more accurate comparison use the same number of folds. ",
              call. = FALSE)
    }
  } else if (any(kfold_flags) && any(sapply(loos, is.psis_loo))) {
    # A mixture of kfold and psis_loo objects.
    warning("Comparing LOO-CV to K-fold-CV. ",
            "For a more accurate comparison use the same number of folds ",
            "or loo for all models compared.",
            call. = FALSE)
  }
}


#' Find the model names associated with `"loo"` objects
#'
#' For each element the first available of the following is used: a non-empty
#' list name, a `"model_name"` attribute, a `model_name` element, and finally
#' the fallback `"model<j>"` where `<j>` is the element's position.
#'
#' @export
#' @keywords internal
#' @param x List of `"loo"` objects.
#' @return Character vector of model names the same length as `x`.
#'
find_model_names <- function(x) {
  stopifnot(is.list(x))

  list_names <- names(x)
  attr_names <- lapply(x, "attr", "model_name", exact = TRUE)
  elem_names <- lapply(x, "[[", "model_name")

  # Start from the positional fallback and overwrite wherever a better name
  # is available.
  resolved <- paste0("model", seq_along(x))
  for (j in seq_along(x)) {
    if (isTRUE(nzchar(list_names[j]))) {
      resolved[j] <- list_names[j]
      next
    }
    if (length(attr_names[[j]])) {
      resolved[j] <- attr_names[[j]]
      next
    }
    if (length(elem_names[[j]])) {
      resolved[j] <- elem_names[[j]]
    }
  }
  resolved
}


#' Build the matrix of estimates shown by `loo_compare()`
#'
#' One row per model (sorted with `loo_compare_order()`), one column per
#' estimate and per standard error (the latter prefixed with `se_`).
#' @keywords internal
#' @noRd
#' @param loos List of `"loo"` objects.
loo_compare_matrix <- function(loos){
  # Flatten each model's estimates table into one named vector: the first
  # column's values keep the row names, the second column's get "se_" names.
  est_by_model <- sapply(loos, function(obj) {
    est <- obj$estimates
    labels <- rownames(est)
    setNames(c(est), nm = c(labels, paste0("se_", labels)))
  })
  colnames(est_by_model) <- find_model_names(loos)

  # Models become rows, arranged in comparison order.
  ranked <- t(est_by_model)[loo_compare_order(loos), ]

  # Column layout: elpd columns, then p_ columns, then waic/looic and their
  # standard errors. Columns matching none of the patterns are dropped.
  patts <- c("elpd", "p_", "^waic$|^looic$", "^se_waic$|^se_looic$")
  keep <- unlist(lapply(patts, function(p) grep(p, colnames(ranked))),
                 use.names = FALSE)
  ranked[, keep]
}

#' Rank `"loo"` objects by their elpd estimate
#' @noRd
#' @keywords internal
#' @param loos List of `"loo"` objects.
#' @return Indices into `loos`, sorted by decreasing elpd estimate.
loo_compare_order <- function(loos){
  # Flatten each model's estimates table into one named vector: the first
  # column's values keep the row names, the second column's get "se_" names.
  est_by_model <- sapply(loos, function(obj) {
    est <- obj$estimates
    labels <- rownames(est)
    setNames(c(est), nm = c(labels, paste0("se_", labels)))
  })
  colnames(est_by_model) <- find_model_names(loos)

  # The row whose name starts with "elpd" holds the point estimates.
  elpd_row <- grep("^elpd", rownames(est_by_model))
  order(est_by_model[elpd_row, ], decreasing = TRUE)
}

#' Perform checks on `"loo"` objects __after__ comparison
#'
#' Recomputes the elpd differences against the model in the middle of the
#' ranking and compares the largest difference with an order-statistic
#' heuristic. A warning is raised when the largest difference does not exceed
#' that heuristic.
#' @noRd
#' @keywords internal
#' @param loos List of `"loo"` objects.
#' @param ord Ordering of `loos` (indices), as used for the comparison.
#' @return Nothing, just possibly throws errors/warnings.
loo_order_stat_check <- function(loos, ord) {

  ## breaks

  if (length(loos) <= 11L) {
    # the check is only run when more than 11 models are compared;
    # otherwise break from the function
    return(NULL)
  }

  ## warnings

  # compute the elpd differences from the median model
  baseline_idx <- middle_idx(ord)
  diffs <- mapply(FUN = elpd_diffs, loos[ord[baseline_idx]], loos[ord])
  elpd_diff <- apply(diffs, 2, sum)

  # estimate the standard deviation of the upper-half-normal:
  # root mean square of the differences at or above their median
  diff_median <- stats::median(elpd_diff)
  elpd_diff_trunc <- elpd_diff[elpd_diff >= diff_median]
  n_models <- sum(!is.na(elpd_diff_trunc))
  candidate_sd <- sqrt(1 / n_models * sum(elpd_diff_trunc^2, na.rm = TRUE))

  # estimate expected best diff under null hypothesis
  K <- length(loos) - 1
  order_stat <- order_stat_heuristic(K, candidate_sd)

  if (max(elpd_diff) <= order_stat) {
    # flag warning if we suspect no model is theoretically better than the
    # baseline. `warning()` pastes its message pieces together without a
    # separator, so the first piece must end with a space.
    warning("Difference in performance potentially due to chance. ",
            "See McLatchie and Vehtari (2023) for details.",
            call. = FALSE)
  }
}

#' Middle position of a vector
#'
#' Half the length of `vec`, rounded down.
#' @noRd
#' @keywords internal
#' @param vec A vector.
#' @return Index value (a whole number stored as a double).
middle_idx <- function(vec) {
  n <- length(vec)
  n %/% 2
}

#' Heuristic for the maximum order statistic of K Gaussians
#'
#' Evaluated as the `1 - 1 / (2 * K)` quantile of a Gaussian with mean zero
#' and standard deviation `c`.
#' @noRd
#' @keywords internal
#' @param K Number of Gaussians.
#' @param c Scaling of the order statistic.
#' @return Numeric heuristic value for the expected maximum of K samples from a
#' Gaussian with mean zero and scale `"c"`
order_stat_heuristic <- function(K, c) {
  upper_prob <- 1 - 1 / (K * 2)
  stats::qnorm(p = upper_prob, mean = 0, sd = c)
}