File: cv_varsel.R

#' Cross-validated variable selection (varsel)
#'
#' Perform cross-validation for the projective variable selection for
#' generalized linear models as well as generalized linear and additive
#' multilevel models.
#'
#' @name cv_varsel
#'
#' @param object Same as in \link[=varsel]{varsel}.
#' @param method Same as in \link[=varsel]{varsel}.
#' @param ndraws Number of posterior draws used for selection. Ignored if
#'   nclusters is provided or if method='L1'.
#' @param nclusters Number of clusters used for selection. Default is 1 and
#'   ignored if method='L1' (L1 search always uses one cluster).
#' @param ndraws_pred Number of samples used for prediction (after selection).
#'   Ignored if nclusters_pred is given.
#' @param nclusters_pred Number of clusters used for prediction (after
#'   selection). Default is 5.
#' @param cv_search Same as in \link[=varsel]{varsel}.
#' @param nterms_max Same as in \link[=varsel]{varsel}.
#' @param intercept Same as in \link[=varsel]{varsel}.
#' @param penalty Same as in \link[=varsel]{varsel}.
#' @param verbose Whether to print out some information during the validation.
#'   Default is TRUE.
#' @param cv_method The cross-validation method, either 'LOO' or 'kfold'.
#'   Default is 'LOO'.
#' @param nloo Number of observations used to compute the LOO validation
#'   (anything between 1 and the total number of observations). Smaller values
#'   lead to faster computation but higher uncertainty (larger error bars) in
#'   the accuracy estimation. Default is to use all observations, but for
#'   faster experimentation, one can set this to a small value such as 100.
#'   Only applicable if \code{cv_method = 'LOO'}.
#' @param K Number of folds in the K-fold cross validation. Default is 5 for
#'   genuine reference models and 10 for datafits (that is, for penalized
#'   maximum likelihood estimation).
#' @param lambda_min_ratio Same as in \link[=varsel]{varsel}.
#' @param nlambda Same as in \link[=varsel]{varsel}.
#' @param thresh Same as in \link[=varsel]{varsel}.
#' @param regul Amount of regularization in the projection. Usually there is no
#'   need for regularization, but sometimes for some models the projection can
#'   be ill-behaved and we need to add some regularization to avoid numerical
#'   problems.
#' @param validate_search Whether to cross-validate also the selection process,
#'   that is, whether to perform the selection separately for each fold.
#'   Default is TRUE, and we strongly recommend not setting this to FALSE,
#'   because that is known to bias the accuracy estimates for the selected
#'   submodels. However, setting this to FALSE can sometimes be useful because
#'   comparing the results to the case where this parameter is TRUE gives an
#'   idea of how strongly the feature selection is (over)fitted to the data
#'   (the difference corresponds to the search degrees of freedom, or the
#'   effective number of parameters introduced by the selection process).
#' @param seed Random seed used in the subsampling LOO. By default uses a fixed
#'   seed.
#' @param search_terms User-defined list of terms to consider for selection.
#' @param ... Additional arguments to be passed to the
#'   \code{get_refmodel}-function.
#'
#' @return An object of type \code{vsel} that contains information about the
#'   feature selection. The fields are not meant to be accessed directly by
#'   the user but instead via the helper functions (see the vignettes, or type
#'   ?projpred to see the main functions in the package).
#'
#' @examples
#' \donttest{
#' if (requireNamespace('rstanarm', quietly=TRUE)) {
#'   ### Usage with stanreg objects
#'   n <- 30
#'   d <- 5
#'   x <- matrix(rnorm(n*d), nrow=n)
#'   y <- x[,1] + 0.5*rnorm(n)
#'   data <- data.frame(x,y)
#'   fit <- rstanarm::stan_glm(y ~ X1 + X2 + X3 + X4 + X5, gaussian(),
#'      data=data, chains=2, iter=500)
#'   cvs <- cv_varsel(fit)
#'   plot(cvs)
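#'
#'   ## a hedged sketch of two documented variants: K-fold CV instead of
#'   ## PSIS-LOO (this refits the reference model K times), and a faster but
#'   ## optimistically biased run that skips the fold-wise search
#'   cvs_kfold <- cv_varsel(fit, cv_method = 'kfold', K = 3)
#'   cvs_fast <- cv_varsel(fit, validate_search = FALSE)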
#' }
#' }
#'
#' @export
cv_varsel <- function(object, ...) {
  UseMethod("cv_varsel")
}

#' @rdname cv_varsel
#' @export
cv_varsel.default <- function(object, ...) {
  refmodel <- get_refmodel(object)
  return(cv_varsel(refmodel, ...))
}

#' @rdname cv_varsel
#' @export
cv_varsel.refmodel <- function(object, method = NULL, cv_method = NULL,
                               ndraws = NULL, nclusters = NULL,
                               ndraws_pred = NULL, nclusters_pred = NULL,
                               cv_search = TRUE, nterms_max = NULL,
                               intercept = NULL, penalty = NULL, verbose = TRUE,
                               nloo = NULL, K = NULL, lambda_min_ratio = 1e-5,
                               nlambda = 150, thresh = 1e-6, regul = 1e-4,
                               validate_search = TRUE, seed = NULL,
                               search_terms = NULL, ...) {
  refmodel <- object
  ## resolve the arguments similar to varsel
  args <- parse_args_varsel(
    refmodel = refmodel, method = method, cv_search = cv_search,
    intercept = intercept, nterms_max = nterms_max, nclusters = nclusters,
    ndraws = ndraws,
    nclusters_pred = nclusters_pred,
    ndraws_pred = ndraws_pred, search_terms = search_terms
  )
  method <- args$method
  cv_search <- args$cv_search
  intercept <- args$intercept
  nterms_max <- args$nterms_max
  nclusters <- args$nclusters
  ndraws <- args$ndraws
  nclusters_pred <- args$nclusters_pred
  ndraws_pred <- args$ndraws_pred
  search_terms <- args$search_terms

  ## arguments specific to this function
  args <- parse_args_cv_varsel(
    refmodel, cv_method, K, nclusters,
    nclusters_pred
  )
  cv_method <- args$cv_method
  K <- args$K
  nclusters <- args$nclusters
  nclusters_pred <- args$nclusters_pred

  ## search options
  opt <- nlist(lambda_min_ratio, nlambda, thresh, regul)
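  ## (nlist(x, y) is shorthand for list(x = x, y = y), i.e., a list whose
  ## elements are named after the variables passed in)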

  if (cv_method == "loo") {
    if (!(is.null(K))) warning("K provided, but cv_method is LOO.")
    sel_cv <- loo_varsel(
      refmodel = refmodel, method = method, nterms_max = nterms_max,
      ndraws = ndraws, nclusters = nclusters,
      ndraws_pred = ndraws_pred,
      nclusters_pred = nclusters_pred,
      cv_search = cv_search, intercept = intercept, penalty = penalty,
      verbose = verbose, opt = opt, nloo = nloo,
      validate_search = validate_search, seed = seed,
      search_terms = search_terms
    )
  } else if (cv_method == "kfold") {
    sel_cv <- kfold_varsel(
      refmodel = refmodel, method = method, nterms_max = nterms_max,
      ndraws = ndraws, nclusters = nclusters,
      ndraws_pred = ndraws_pred,
      nclusters_pred = nclusters_pred,
      cv_search = cv_search, intercept = intercept,
      penalty = penalty, verbose = verbose, opt = opt, K = K,
      seed = seed, search_terms = search_terms
    )
  } else {
    stop(sprintf("Unknown cross-validation method: %s.", cv_method))
  }

  ## run the selection using the full dataset
  if (verbose) {
    print("Performing the selection using all the data..")
  }
  sel <- varsel(refmodel,
    method = method, ndraws = ndraws, nclusters = nclusters,
    ndraws_pred = ndraws_pred, nclusters_pred = nclusters_pred,
    cv_search = cv_search, nterms_max = nterms_max - 1,
    intercept = intercept, penalty = penalty, verbose = verbose,
    lambda_min_ratio = lambda_min_ratio, nlambda = nlambda, regul = regul,
    search_terms = search_terms
  )

  ## find out how many of the cross-validated iterations select
  ## the same variables as the selection with all the data.
  solution_terms_cv_ch <- sapply(
    seq_len(NROW(sel_cv$solution_terms_cv)),
    function(i) {
      if (!is.character(sel_cv$solution_terms_cv[i, ])) {
        unlist(search_terms)[sel_cv$solution_terms_cv[i, ]]
      } else {
        sel_cv$solution_terms_cv[i, ]
      }
    }
  )

  ## these weights might be non-constant in case of subsampling LOO
  w <- sel_cv$summaries$sub[[1]]$w
  sel_solution_terms <- sel$solution_terms
  ## if weights are not set, then all validation folds have equal weight
  vars <- unlist(sel_solution_terms)
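  ## pct_solution_terms_cv has one row per submodel size: a 'size' column plus
  ## one column per term of the full-data solution, giving the (weighted)
  ## proportion of folds in which that term was selected among the first
  ## 'size' terms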
  pct_solution_terms_cv <- t(sapply(
    seq_along(sel_solution_terms),
    function(size) {
      c(
        size = size,
        sapply(vars, function(var) {
          sum((solution_terms_cv_ch[seq_len(size), , drop = FALSE] == var) * w,
            na.rm = TRUE
          )
        })
      )
    }
  ))

  ## create the object to be returned
  vs <- nlist(refmodel,
    search_path = sel$search_path, d_test = sel_cv$d_test,
    summaries = sel_cv$summaries, family = sel$family, kl = sel$kl,
    solution_terms = sel$solution_terms, pct_solution_terms_cv,
    nterms_max = nterms_max,
    nterms_all = count_terms_in_formula(refmodel$formula)
  )
  class(vs) <- "vsel"
  vs$suggested_size <- suggest_size(vs, warnings = FALSE)
  if (verbose) {
    print("Done.")
  }
  return(vs)
}

#
# Auxiliary function for parsing the input arguments specific to cv_varsel.
# This is similar in spirit to parse_args_varsel, that is, it keeps the main
# function from becoming too long and complicated to maintain.
#
# @param refmodel Reference model as extracted by get_refmodel
# @param cv_method The cross-validation method, either 'LOO' or 'kfold'.
#   Default is 'LOO'.
# @param K Number of folds in the K-fold cross validation. Default is 5 for
#   genuine reference models and 10 for datafits (that is, for penalized
#   maximum likelihood estimation).
parse_args_cv_varsel <- function(refmodel, cv_method = NULL, K = NULL,
                                 nclusters = NULL,
                                 nclusters_pred = NULL) {
  if (is.null(cv_method)) {
    if (inherits(refmodel, "datafit")) {
      ## only data given, no actual reference model
      cv_method <- "kfold"
    } else {
      cv_method <- "loo"
    }
  }

  if (!is.null(K)) {
    if (length(K) > 1 || !(is.numeric(K)) || !(K == round(K))) {
      stop("K must be a single integer value")
    }
    if (K < 2) {
      stop("K must be at least 2")
    }
    if (K > NROW(refmodel$y)) {
      stop("K cannot exceed n")
    }
  }

  if (tolower(cv_method) == "kfold") {
    if (inherits(refmodel, "datafit")) {
      ## a datafit has no posterior draws, so use a single cluster
      if (is.null(K)) {
        K <- 10
      }
      nclusters_pred <- 1
      nclusters <- 1
    } else if (is.null(K)) {
      K <- 5
    }
  }

  cv_method <- tolower(cv_method)
  return(nlist(cv_method, K, nclusters, nclusters_pred))
}
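
## Illustrative sketch (not run; 'refmodel_datafit' is a hypothetical object):
## for a datafit with no user-supplied arguments, the defaults resolve to
##   parse_args_cv_varsel(refmodel_datafit)
##   ## -> list(cv_method = "kfold", K = 10, nclusters = 1, nclusters_pred = 1)
## whereas a genuine reference model defaults to cv_method = "loo", with K
## left as NULL.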

loo_varsel <- function(refmodel, method, nterms_max, ndraws,
                       nclusters, ndraws_pred, nclusters_pred, cv_search,
                       intercept, penalty, verbose, opt, nloo = NULL,
                       validate_search = TRUE, seed = NULL,
                       search_terms = NULL) {
  ##
  ## Perform the validation of the search process using PSIS-LOO.
  ## validate_search indicates whether the selection is performed separately
  ## for each fold (that is, for each left-out data point).
  ##

  family <- refmodel$family
  mu <- refmodel$mu
  dis <- refmodel$dis
  ## the clustering/subsampling used for selection
  p_sel <- .get_refdist(refmodel,
    ndraws = ndraws,
    nclusters = nclusters
  )
  cl_sel <- p_sel$cl # clustering information

  ## the clustering/subsampling used for prediction
  p_pred <- .get_refdist(refmodel,
    ndraws = ndraws_pred,
    nclusters = nclusters_pred
  )
  cl_pred <- p_pred$cl

  ## fetch the log-likelihood for the reference model to obtain the LOO weights
  if (is.null(refmodel$loglik)) {
    ## case where log-likelihood not available, i.e., the reference model is not
    ## a genuine model => cannot compute LOO
    stop(
      "LOO can be performed only if the reference model is a genuine ",
      "probabilistic model for which the log-likelihood can be evaluated."
    )
  } else {
    ## log-likelihood available
    loglik <- refmodel$loglik
  }
  ## TODO: should take r_eff:s into account
  psisloo <- loo::psis(-loglik, cores = 1, r_eff = rep(1, ncol(loglik)))
  lw <- weights(psisloo)
  pareto_k <- loo::pareto_k_values(psisloo)
  n <- length(pareto_k)
  ## by default (nloo = NULL) use all observations;
  ## note that min(NULL, n) evaluates to n
  nloo <- min(nloo, n)

  if (nloo < 1) {
    stop("nloo must be at least 1")
  }

  ## compute loo summaries for the reference model
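  ## (lw holds column-wise normalized *log* PSIS weights, so exp(lw[, i]) sums
  ## to one over the posterior draws; the LOO lppd of point i is then
  ## log(sum_s exp(loglik[s, i] + lw[s, i])), and the LOO predictive mean is
  ## the correspondingly weighted average of mu[i, ])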
  loo_ref <- apply(loglik + lw, 2, log_sum_exp)
  mu_ref <- rep(0, n)
  for (i in 1:n) {
    mu_ref[i] <- mu[i, ] %*% exp(lw[, i])
  }

  ## decide which points form the validation set based on the k-values
  validset <- .loo_subsample(n, nloo, pareto_k, seed)
  inds <- validset$inds

  ## initialize matrices where to store the results
  solution_terms_mat <- matrix(nrow = n, ncol = nterms_max - 1)
  loo_sub <- matrix(nrow = n, ncol = nterms_max)
  mu_sub <- matrix(nrow = n, ncol = nterms_max)

  if (verbose) {
    print("Computing LOOs...")
    pb <- utils::txtProgressBar(min = 0, max = nloo, style = 3, initial = 0)
  }

  if (!validate_search) {
    ## perform selection only once using all the data (not separately for each
    ## fold), and perform the projection then for each submodel size
    search_path <- select(
      method = method, p_sel = p_sel, refmodel = refmodel, family = family,
      intercept = intercept, nterms_max = nterms_max, penalty = penalty,
      verbose = FALSE, opt = opt, search_terms = search_terms
    )
    solution_terms <- search_path$solution_terms
  }

  for (run_index in seq_along(inds)) {

    ## observation index
    i <- inds[run_index]

    ## reweight the clusters/samples according to the psis-loo weights
    p_sel <- .get_p_clust(
      family = family, mu = mu, dis = dis, wsample = exp(lw[, i]), cl = cl_sel
    )
    p_pred <- .get_p_clust(
      family = family, mu = mu, dis = dis, wsample = exp(lw[, i]), cl = cl_pred
    )

    if (validate_search) {
      ## perform selection with the reweighted clusters/samples
      search_path <- select(
        method = method, p_sel = p_sel, refmodel = refmodel,
        family = family, intercept = intercept, nterms_max = nterms_max,
        penalty = penalty, verbose = FALSE, opt = opt,
        search_terms = search_terms
      )
      solution_terms <- search_path$solution_terms
    }

    ## project onto the selected models and compute the prediction accuracy for
    ## the left-out point
    submodels <- .get_submodels(
      search_path = search_path, nterms = c(0, seq_along(solution_terms)),
      family = family, p_ref = p_pred, refmodel = refmodel,
      intercept = intercept, regul = opt$regul, cv_search = cv_search
    )
    summaries_sub <- .get_sub_summaries(
      submodels = submodels, test_points = c(i), refmodel = refmodel,
      family = family
    )
    for (k in seq_along(summaries_sub)) {
      loo_sub[i, k] <- summaries_sub[[k]]$lppd
      mu_sub[i, k] <- summaries_sub[[k]]$mu
    }

    ## we are always doing group selection;
    ## `match` gives the indices (in search_terms) of the variables in the
    ## order they enter the solution path
    solution_terms_mat[i, ] <- match(solution_terms, search_terms)

    if (verbose) {
      utils::setTxtProgressBar(pb, run_index)
    }
  }

  if (verbose) {
    ## close the progress bar object
    close(pb)
  }

  ## put all the results together in the form required by cv_varsel
  summ_sub <- lapply(seq_len(nterms_max), function(k) {
    list(lppd = loo_sub[, k], mu = mu_sub[, k], w = validset$w)
  })
  summ_ref <- list(lppd = loo_ref, mu = mu_ref)
  summaries <- list(sub = summ_sub, ref = summ_ref)

  d_test <- list(
    y = refmodel$y, type = "loo",
    test_points = seq_along(refmodel$y),
    weights = refmodel$wobs,
    data = NULL
  )

  return(nlist(solution_terms_cv = solution_terms_mat, summaries, d_test))
}

kfold_varsel <- function(refmodel, method, nterms_max, ndraws,
                         nclusters, ndraws_pred,
                         nclusters_pred, cv_search, intercept, penalty,
                         verbose, opt, K, seed = NULL, search_terms = NULL) {
  ## fetch the k_fold list (or compute it now if not already computed)
  k_fold <- .get_kfold(refmodel, K, verbose, seed)

  ## check that k_fold has the correct form
  ## .validate_kfold(refmodel, k_fold, refmodel$nobs)

  K <- length(k_fold)
  family <- refmodel$family

  ## extract variables from each fit-object (samples, x, y, etc.)
  ## to a list of size K
  refmodels_cv <- lapply(k_fold, function(fold) fold$refmodel)

  # List of size K with test data for each fold
  d_test_cv <- lapply(k_fold, function(fold) {
    list(
      newdata = refmodel$fetch_data(obs = fold$omitted),
      y = refmodel$y[fold$omitted],
      weights = refmodel$wobs[fold$omitted],
      offset = refmodel$offset[fold$omitted],
      omitted = fold$omitted
    )
  })

  ## List of K elements, each containing d_train, p_pred, etc. corresponding
  ## to each fold.
  make_list_cv <- function(refmodel, d_test, msg) {
    nclusters_pred <- min(
      refmodel$nclusters_pred,
      nclusters_pred
    )
    p_sel <- .get_refdist(refmodel, ndraws, nclusters)
    p_pred <- .get_refdist(refmodel, ndraws_pred, nclusters_pred)
    newdata <- d_test$newdata
    pred <- matrix(
      as.numeric(refmodel$ref_predfun(refmodel$fit, newdata = newdata)),
      NROW(newdata), NCOL(refmodel$y)
    )
    mu_test <- family$linkinv(pred)
    nlist(refmodel, p_sel, p_pred, mu_test,
      dis = refmodel$dis, w_test = refmodel$wsample, d_test, msg
    )
  }

  msgs <- paste0(method, " search for fold ", 1:K, "/", K, ".")
  list_cv <- mapply(make_list_cv, refmodels_cv, d_test_cv, msgs,
    SIMPLIFY = FALSE
  )

  ## Perform the selection for each of the K folds
  if (verbose) {
    print("Performing selection for each fold..")
    pb <- utils::txtProgressBar(min = 0, max = K, style = 3, initial = 0)
  }
  search_path_cv <- lapply(seq_along(list_cv), function(fold_index) {
    fold <- list_cv[[fold_index]]
    family <- fold$refmodel$family
    out <- select(
      method = method, p_sel = fold$p_sel, refmodel = fold$refmodel,
      family = family, intercept = intercept, nterms_max = nterms_max,
      penalty = penalty, verbose = verbose, opt = opt,
      search_terms = search_terms
    )
    if (verbose) {
      utils::setTxtProgressBar(pb, fold_index)
    }
    out
  })

  solution_terms_cv <- t(sapply(search_path_cv, function(e) e$solution_terms))
  if (verbose) {
    close(pb)
  }

  ## Construct submodel projections for each fold
  if (verbose && cv_search) {
    print("Computing projections..")
    pb <- utils::txtProgressBar(min = 0, max = K, style = 3, initial = 0)
  }

  get_submodels_cv <- function(search_path, fold_index) {
    fold <- list_cv[[fold_index]]
    family <- fold$refmodel$family
    solution_terms <- search_path$solution_terms
    p_sub <- .get_submodels(
      search_path = search_path, nterms = c(0, seq_along(solution_terms)),
      family = family, p_ref = fold$p_pred, refmodel = fold$refmodel,
      intercept = intercept, regul = opt$regul, cv_search = FALSE
    )
    if (verbose && cv_search) {
      utils::setTxtProgressBar(pb, fold_index)
    }
    return(p_sub)
  }

  p_sub_cv <- mapply(get_submodels_cv, search_path_cv, seq_along(list_cv),
    SIMPLIFY = FALSE
  )
  if (verbose && cv_search) {
    close(pb)
  }

  ## Helper function to extract and combine mu and lppd from K lists, each
  ## containing n/K of the elements, into one list with n elements
  hf <- function(x) as.list(do.call(rbind, x))
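  ## (e.g., rbinding K data frames with n/K rows each gives one n-row data
  ## frame, and as.list() then splits it into one list element per column)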

  ## Manipulate the structure of the list so that instead of a list of K
  ## sub_summaries, each containing n/K mu:s and lppd:s, we have a single
  ## sub_summary list that contains all n mu:s and lppd:s.
  get_summaries_submodel_cv <- function(p_sub, fold) {
    omitted <- fold$d_test$omitted
    fold_summaries <- .get_sub_summaries(
      submodels = p_sub, test_points = omitted, refmodel = refmodel,
      family = family
    )
    lapply(fold_summaries, data.frame)
  }
  sub_cv_summaries <- mapply(get_summaries_submodel_cv, p_sub_cv, list_cv)
  sub <- apply(sub_cv_summaries, 1, hf)
  sub <- lapply(sub, function(summ) {
    summ$w <- rep(1, NROW(solution_terms_cv))
    summ$w <- summ$w / sum(summ$w)
    summ
  })

  ref <- hf(lapply(list_cv, function(fold) {
    data.frame(.weighted_summary_means(
      y_test = fold$d_test, family = family, wsample = fold$d_test$w,
      mu = fold$mu_test, dis = fold$refmodel$dis
    ))
  }))

  ## Combine also the K separate test data sets into one list
  ## with n y values and weights.
  d_cv <- hf(lapply(d_test_cv, function(fold) {
    data.frame(
      y = fold$y, weights = fold$weights,
      test_points = fold$omitted
    )
  }))

  return(nlist(solution_terms_cv,
    summaries = list(sub = sub, ref = ref),
    d_test = c(d_cv, type = "kfold")
  ))
}


.get_kfold <- function(refmodel, K, verbose, seed) {
  ## Fetch the k_fold list or compute it now if not already computed. This
  ## function returns a list of length K, where each element is a list with
  ## fields 'refmodel' (an object of type refmodel computed by init_refmodel)
  ## and an index vector 'omitted' that denotes which of the data points were
  ## left out for the corresponding fold.

  if (is.null(refmodel$cvfits)) {
    if (!is.null(refmodel$cvfun)) {
      # cv-function provided, so perform the cross-validation now. In case
      # refmodel is a datafit, cvfun will return an empty list, which leads
      # to normal cross-validation for the submodels although we don't have
      # an actual reference model
      if (verbose && !("datafit" %in% class(refmodel))) {
        print("Performing cross-validation for the reference model..")
      }
      nobs <- NROW(refmodel$y)
      folds <- cvfolds(nobs, K = K, seed = seed)
      cvfits <- refmodel$cvfun(folds)
      cvfits <- lapply(seq_along(cvfits), function(k) {
        # add the 'omitted' indices for the cvfits
        cvfit <- cvfits[[k]]
        cvfit$omitted <- which(folds == k)
        cvfit
      })
    } else {
      ## genuine probabilistic model but no K-fold fits nor cvfun provided, so
      ## raise an error
      stop(
        "For a generic reference model, you must provide either cvfits or ",
        "cvfun for K-fold cross-validation. See function init_refmodel."
      )
    }
  } else {
    cvfits <- refmodel$cvfits
    K <- attr(cvfits, "K")
    folds <- attr(cvfits, "folds")
    cvfits <- lapply(seq_len(K), function(k) {
      cvfit <- cvfits$fits[[k]]
      obs <- seq_len(NROW(cvfits$data))
      cvfit$omitted <- obs[folds != k]
      cvfit
    })
  }

  train <- seq_along(refmodel$y)

  k_fold <- lapply(cvfits, .init_kfold_refmodel, refmodel, train)

  return(k_fold)
}

.init_kfold_refmodel <- function(cvfit, refmodel, train) {
  fold <- setdiff(
    train,
    cvfit$omitted
  )
  default_data <- refmodel$fetch_data(obs = fold)
  fetch_fold <- function(data = NULL, obs = NULL, newdata = NULL) {
    refmodel$fetch_data(obs = fold, newdata = newdata)
  }
  ref_predfun <- function(fit, newdata = default_data) {
    refmodel$ref_predfun(fit, newdata = newdata)
  }
  proj_predfun <- function(fit, newdata = default_data, weights = NULL) {
    refmodel$proj_predfun(fit, newdata = newdata, weights = weights)
  }
  if (!inherits(cvfit, "brmsfit") && !inherits(cvfit, "stanreg")) {
    fit <- NULL
  } else {
    fit <- cvfit
  }
  extract_model_data <- function(object, newdata = fetch_fold(), ...) {
    refmodel$extract_model_data(object = object, newdata = newdata)
  }
  k_refmodel <- init_refmodel(
    object = fit, data = fetch_fold(),
    formula = refmodel$formula, family = refmodel$family,
    ref_predfun = ref_predfun, div_minimizer = refmodel$div_minimizer,
    proj_predfun = proj_predfun, folds = seq_along(fold),
    extract_model_data = extract_model_data
  )
  k_refmodel$fetch_data <- fetch_fold
  k_refmodel$nclusters_pred <- min(NCOL(k_refmodel$mu), 5)
  return(nlist(refmodel = k_refmodel, omitted = cvfit$omitted))
}

.loo_subsample <- function(n, nloo, pareto_k, seed) {
  ## decide which points to go through in the validation (i.e., which points
  ## belong to the semi random subsample of validation points)
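  ## (illustrative: with n = 100 and nloo = 30, up to floor(30/3) = 10 points
  ## are drawn from each Pareto-k stratum ('bad': k > 0.7, 'ok': 0.5 < k <=
  ## 0.7, 'good': k <= 0.5), and any shortfall is filled by sampling from the
  ## remaining points)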

  ## set random seed but ensure the old RNG state is restored on exit
  if (exists(".Random.seed")) {
    rng_state_old <- .Random.seed
    on.exit(assign(".Random.seed", rng_state_old, envir = .GlobalEnv))
  }
  set.seed(seed)

  resample <- function(x, ...) x[sample.int(length(x), ...)]

  if (nloo < n) {
    bad <- which(pareto_k > 0.7)
    ok <- which(pareto_k <= 0.7 & pareto_k > 0.5)
    good <- which(pareto_k <= 0.5)
    inds <- resample(bad, min(length(bad), floor(nloo / 3)))
    inds <- c(inds, resample(ok, min(length(ok), floor(nloo / 3))))
    inds <- c(inds, resample(good, min(length(good), floor(nloo / 3))))
    if (length(inds) < nloo) {
      ## not enough points selected, so choose randomly among the rest
      inds <- c(inds, resample(setdiff(1:n, inds), nloo - length(inds)))
    }

    ## assign the weights corresponding to this stratification: each selected
    ## point gets weight (stratum size) / (number sampled from that stratum),
    ## i.e., an inverse inclusion probability (for example, the 'bad' values
    ## are likely to be overrepresented in the sample)
    w <- rep(0, n)
    w[inds[inds %in% bad]] <- length(bad) / sum(inds %in% bad)
    w[inds[inds %in% ok]] <- length(ok) / sum(inds %in% ok)
    w[inds[inds %in% good]] <- length(good) / sum(inds %in% good)
  } else {
    ## all points used
    inds <- seq_len(n)
    w <- rep(1, n)
  }

  ## ensure weights are normalized
  w <- w / sum(w)

  return(nlist(inds, w))
}