File: importance_sampling.R

package info (click to toggle)
r-cran-loo 2.9.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 4,836 kB
  • sloc: sh: 15; makefile: 2
file content (236 lines) | stat: -rw-r--r-- 7,542 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
#' A parent class for different importance sampling methods.
#'
#' S3 generic; the method that runs is chosen by the class of `log_ratios`.
#'
#' @inheritParams psis
#' @param method The importance sampling method to use. The following methods
#'   are implemented:
#' * [`"psis"`][psis]: Pareto-Smoothed Importance Sampling (PSIS). Default method.
#' * [`"tis"`][tis]: Truncated Importance Sampling (TIS) with truncation at
#'   `sqrt(S)`, where `S` is the number of posterior draws.
#' * [`"sis"`][sis]: Standard Importance Sampling (SIS).
#'
importance_sampling <- function(log_ratios, method, ...) {
  # Name the dispatch object explicitly; it is the first argument either way.
  UseMethod("importance_sampling", log_ratios)
}


#' @rdname importance_sampling
#' @inheritParams psis
#' @export
importance_sampling.array <-
  function(log_ratios, method,
           ...,
           r_eff = 1,
           cores = getOption("mc.cores", 1)) {
    # `cores` is passed through loo_cores() first, before any input checks.
    n_cores <- loo_cores(cores)

    # This method only accepts 3-D arrays and implemented methods.
    stopifnot(length(dim(log_ratios)) == 3)
    assert_importance_sampling_method_is_implemented(method)

    # Validate the array, then convert it to the matrix form that
    # do_importance_sampling() works on.
    checked_ratios <- validate_ll(log_ratios)
    ratio_matrix <- llarray_to_matrix(checked_ratios)

    # One r_eff entry is requested per column of the converted matrix.
    rel_eff <- prepare_psis_r_eff(r_eff, len = ncol(ratio_matrix))

    do_importance_sampling(
      ratio_matrix,
      r_eff = rel_eff,
      cores = n_cores,
      method = method
    )
  }

#' @rdname importance_sampling
#' @inheritParams psis
#' @export
importance_sampling.matrix <-
  function(log_ratios, method,
           ...,
           r_eff = 1,
           cores = getOption("mc.cores", 1)) {
    # `cores` is passed through loo_cores() first, before any input checks.
    n_cores <- loo_cores(cores)
    assert_importance_sampling_method_is_implemented(method)

    # The matrix is validated and used directly; no reshaping is needed here.
    ratio_matrix <- validate_ll(log_ratios)

    # One r_eff entry is requested per column of the matrix.
    rel_eff <- prepare_psis_r_eff(r_eff, len = ncol(ratio_matrix))

    do_importance_sampling(
      ratio_matrix,
      r_eff = rel_eff,
      cores = n_cores,
      method = method
    )
  }

#' @rdname importance_sampling
#' @inheritParams psis
#' @export
importance_sampling.default <-
  function(log_ratios, method, ..., r_eff = 1) {
    # Only dimensionless vectors or 1-D arrays are handled by this fallback.
    stopifnot(is.null(dim(log_ratios)) || length(dim(log_ratios)) == 1)
    assert_importance_sampling_method_is_implemented(method)

    # Reshape the vector into a one-column matrix so the matrix method can
    # process it as a single column.
    n_values <- length(log_ratios)
    dim(log_ratios) <- c(n_values, 1)

    rel_eff <- prepare_psis_r_eff(r_eff, len = 1)

    # A single column is always run with one core.
    importance_sampling.matrix(
      log_ratios,
      method = method,
      r_eff = rel_eff,
      cores = 1
    )
  }


#' @export
dim.importance_sampling <- function(x) {
  # The dimensions are read from the "dims" attribute of the object
  # (set by importance_sampling_object()), not from a regular "dim" attribute.
  stored_dims <- attr(x, which = "dims")
  stored_dims
}


#' Extract importance sampling weights
#'
#' @export
#' @export weights.importance_sampling
#' @method weights importance_sampling
#' @param object An object returned by [psis()], [tis()], or [sis()].
#' @param log Should the weights be returned on the log scale? Defaults to
#'   `TRUE`.
#' @param normalize Should the weights be normalized? Defaults to `TRUE`.
#' @param ... Ignored.
#'
#' @return The `weights()` method returns an object with the same dimensions as
#'   the `log_weights` component of `object`. The `normalize` and `log`
#'   arguments control whether the returned weights are normalized and whether
#'   or not to return them on the log scale.
#'
#' @examples
#' # See the examples at help("psis")
#'
weights.importance_sampling <-
  function(object,
           ...,
           log = TRUE,
           normalize = TRUE) {
    # The stored weights are on the log scale and not yet normalized.
    log_w <- object[["log_weights"]]

    if (normalize) {
      # Normalizing on the log scale: subtract each column's log normalizing
      # constant (the "norm_const_log" attribute) from that column.
      log_norm_const <- attr(object, "norm_const_log")
      log_w <- sweep(
        log_w,
        MARGIN = 2,
        STATS = log_norm_const,
        FUN = "-",
        check.margin = FALSE
      )
    }

    # Leave the log scale only when explicitly requested.
    if (!log) exp(log_w) else log_w
  }

# internal ----------------------------------------------------------------

#' Validate selected importance sampling method
#'
#' Stops with an informative error when `x` is not one of the values returned
#' by `implemented_is_methods()`; otherwise does nothing.
#'
#' @noRd
#' @keywords internal
#' @param x The importance sampling method name to check.
#' @return `NULL`, invisibly, when `x` is an implemented method.
assert_importance_sampling_method_is_implemented <- function(x){
  implemented <- implemented_is_methods()
  if (!x %in% implemented) {
    # The character vector of method names must be pasted here, not the
    # function object: pasting the function fails with a coercion error and
    # hides the intended message from the user.
    stop("Importance sampling method '",
         x,
         "' is not implemented. Implemented methods: '",
         paste0(implemented, collapse = "', '"),
         "'")
  }
}

#' Currently implemented importance sampling methods
#'
#' @noRd
#' @return Character vector with the names of the implemented methods.
implemented_is_methods <- function() c("psis", "tis", "sis")


#' Structure the object returned by the importance_sampling methods
#'
#' @noRd
#' @param unnormalized_log_weights Smoothed and possibly truncated log weights,
#'   but unnormalized. Must be a matrix.
#' @param pareto_k Vector of GPD k estimates.
#' @param tail_len Vector of tail lengths used to fit GPD.
#' @param r_eff Vector of relative MCMC ESS (n_eff) for `exp(log lik)`
#' @template is_method
#' @return A list with elements `log_weights` and `diagnostics` (`pareto_k`,
#'   `n_eff`, `r_eff`) and attributes `norm_const_log`, `tail_len`, `r_eff`,
#'   `dims` and `method`. Its class is
#'   `c(<method>, "importance_sampling", "list")` when `method` holds a single
#'   distinct value, and `c("importance_sampling", "list")` otherwise.
#'
importance_sampling_object <-
  function(unnormalized_log_weights,
           pareto_k,
           tail_len,
           r_eff,
           method) {
    stopifnot(is.matrix(unnormalized_log_weights))
    methods <- unique(method)
    stopifnot(all(methods %in% implemented_is_methods()))

    # The method-specific class (and a scalar `method` attribute) is only used
    # when every entry of `method` names the same method.
    classes <- c("importance_sampling", "list")
    if (length(methods) == 1) {
      method <- methods
      classes <- c(tolower(method), classes)
    }

    # n_eff is a placeholder here; it is filled in below once the normalized
    # weights can be extracted from the assembled object.
    diagnostics <- list(pareto_k = pareto_k, n_eff = NULL, r_eff = r_eff)

    out <- structure(
      list(
        log_weights = unnormalized_log_weights,
        diagnostics = diagnostics
      ),
      # attributes
      norm_const_log = matrixStats::colLogSumExps(unnormalized_log_weights),
      tail_len = tail_len,
      r_eff = r_eff,
      dims = dim(unnormalized_log_weights),
      method = method,
      class = classes
    )

    # psis_n_eff needs normalized weights that are not on the log scale
    normalized_w <- weights(out, log = FALSE, normalize = TRUE)
    out$diagnostics[["n_eff"]] <- psis_n_eff(normalized_w, r_eff)
    out
  }

#' Do importance sampling given matrix of log weights
#'
#' Applies the per-column routine for the chosen method to every column of
#' `log_ratios`, serially or in parallel depending on `cores`, and assembles
#' the results with `importance_sampling_object()`.
#'
#' @noRd
#' @param log_ratios Matrix of log ratios (`-loglik`)
#' @param r_eff Vector of relative effective sample sizes
#' @param cores User's integer `cores` argument
#' @param method Name of the importance sampling method: `"psis"`, `"tis"` or
#'   `"sis"`.
#' @return The object built by `importance_sampling_object()`.
#'
do_importance_sampling <- function(log_ratios, r_eff, cores, method) {
  stopifnot(cores == as.integer(cores))
  assert_importance_sampling_method_is_implemented(method)

  n_obs <- ncol(log_ratios)
  n_draws <- nrow(log_ratios)
  k_threshold <- ps_khat_threshold(n_draws)
  tail_len <- n_pareto(r_eff, n_draws)

  # Pick the per-column routine; tail length warnings are only raised for PSIS.
  is_fun <-
    if (method == "psis") {
      throw_tail_length_warnings(tail_len)
      do_psis_i
    } else if (method == "tis") {
      do_tis_i
    } else if (method == "sis") {
      do_sis_i
    } else {
      stop("Incorrect IS method.")
    }

  # Work done for column i; shared by the serial and both parallel code paths.
  run_column <- function(i) {
    is_fun(log_ratios_i = log_ratios[, i], tail_len_i = tail_len[i])
  }
  column_ids <- seq_len(n_obs)

  if (cores == 1) {
    lw_list <- lapply(column_ids, run_column)
  } else if (!os_is_windows()) {
    lw_list <- parallel::mclapply(
      X = column_ids,
      FUN = run_column,
      mc.cores = cores
    )
  } else {
    # On Windows a PSOCK cluster is used instead of mclapply; it is stopped
    # when this function exits.
    cl <- parallel::makePSOCKcluster(cores)
    on.exit(parallel::stopCluster(cl))
    lw_list <- parallel::parLapply(
      cl = cl,
      X = column_ids,
      fun = run_column
    )
  }

  log_weights <- psis_apply(lw_list, "log_weights", fun_val = numeric(n_draws))
  pareto_k <- psis_apply(lw_list, "pareto_k")
  throw_pareto_warnings(pareto_k, k_threshold)

  # `method` is repeated once per pareto_k entry so it conforms to the other
  # per-observation attributes.
  importance_sampling_object(
    unnormalized_log_weights = log_weights,
    pareto_k = pareto_k,
    tail_len = tail_len,
    r_eff = r_eff,
    method = rep(method, length(pareto_k))
  )
}