1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236
|
#' Generic entry point shared by the importance sampling methods.
#'
#' Dispatches on the class of `log_ratios`. Methods exist in this file for
#' 3-D arrays, matrices, and (via the default method) plain vectors.
#'
#' @inheritParams psis
#' @param method The importance sampling method to use. This generic declares
#'   no default for `method`, so callers must supply it. The following methods
#'   are implemented:
#' * [`"psis"`][psis]: Pareto-Smoothed Importance Sampling (PSIS).
#' * [`"tis"`][tis]: Truncated Importance Sampling (TIS) with truncation at
#'   `sqrt(S)`, where `S` is the number of posterior draws.
#' * [`"sis"`][sis]: Standard Importance Sampling (SIS).
#'
importance_sampling <- function(log_ratios, method, ...) {
  # Name the dispatch object explicitly; it is the first argument, which is
  # also what UseMethod() would pick by default.
  UseMethod("importance_sampling", log_ratios)
}
#' @rdname importance_sampling
#' @inheritParams psis
#' @export
importance_sampling.array <-
  function(log_ratios, method,
           ...,
           r_eff = 1,
           cores = getOption("mc.cores", 1)) {
    # Checks run in a fixed order: core count, array rank, then method name.
    cores <- loo_cores(cores)
    stopifnot(length(dim(log_ratios)) == 3)
    assert_importance_sampling_method_is_implemented(method)

    # Validate the 3-D input, then convert it to the matrix form that
    # do_importance_sampling() works on (both helpers are defined elsewhere).
    lr_checked <- validate_ll(log_ratios)
    lr_matrix <- llarray_to_matrix(lr_checked)

    # `r_eff` is prepared against the number of matrix columns.
    r_eff_checked <- prepare_psis_r_eff(r_eff, len = ncol(lr_matrix))

    do_importance_sampling(
      log_ratios = lr_matrix,
      r_eff = r_eff_checked,
      cores = cores,
      method = method
    )
  }
#' @rdname importance_sampling
#' @inheritParams psis
#' @export
importance_sampling.matrix <-
  function(log_ratios, method,
           ...,
           r_eff = 1,
           cores = getOption("mc.cores", 1)) {
    # Same order of checks as the array method, minus the rank check.
    cores <- loo_cores(cores)
    assert_importance_sampling_method_is_implemented(method)

    lr_matrix <- validate_ll(log_ratios)

    # `r_eff` is prepared against the number of matrix columns.
    r_eff_checked <- prepare_psis_r_eff(r_eff, len = ncol(lr_matrix))

    do_importance_sampling(
      log_ratios = lr_matrix,
      r_eff = r_eff_checked,
      cores = cores,
      method = method
    )
  }
#' @rdname importance_sampling
#' @inheritParams psis
#' @export
importance_sampling.default <-
  function(log_ratios, method, ..., r_eff = 1) {
    # Only dimensionless vectors or 1-D arrays are accepted here.
    stopifnot(is.null(dim(log_ratios)) || length(dim(log_ratios)) == 1)
    assert_importance_sampling_method_is_implemented(method)

    # Give the vector a single-column matrix shape so the matrix method
    # can handle it.
    lr_column <- log_ratios
    dim(lr_column) <- c(length(lr_column), 1)

    # Prepared eagerly (not inside the call below) so any condition it
    # signals is raised before the matrix method starts its own checks.
    r_eff_checked <- prepare_psis_r_eff(r_eff, len = 1)

    importance_sampling.matrix(
      lr_column,
      method = method,
      r_eff = r_eff_checked,
      cores = 1
    )
  }
#' @export
dim.importance_sampling <- function(x) {
  # The dimensions are carried in the object's "dims" attribute rather
  # than in the usual "dim" attribute.
  stored_dims <- attr(x, which = "dims")
  stored_dims
}
#' Extract importance sampling weights
#'
#' @export
#' @export weights.importance_sampling
#' @method weights importance_sampling
#' @param object An object returned by [psis()], [tis()], or [sis()].
#' @param log Should the weights be returned on the log scale? Defaults to
#'   `TRUE`.
#' @param normalize Should the weights be normalized? Defaults to `TRUE`.
#' @param ... Ignored.
#'
#' @return The `weights()` method returns an object with the same dimensions
#'   as the `log_weights` component of `object`. With `normalize = TRUE` the
#'   log normalizing constant stored on `object` is subtracted column by
#'   column; with `log = FALSE` the result is exponentiated before being
#'   returned.
#'
#' @examples
#' # See the examples at help("psis")
#'
weights.importance_sampling <-
  function(object,
           ...,
           log = TRUE,
           normalize = TRUE) {
    # Stored weights are on the log scale and not yet normalized.
    lw <- object[["log_weights"]]

    if (normalize) {
      # One log normalizing constant per column; sweep()'s default FUN is
      # "-", so this subtracts it from every entry of that column.
      log_norm <- attr(object, "norm_const_log")
      lw <- sweep(lw, MARGIN = 2, STATS = log_norm, check.margin = FALSE)
    }

    if (!log) {
      return(exp(lw))
    }
    lw
  }
# internal ----------------------------------------------------------------
#' Validate selected importance sampling method
#'
#' Stops with an error naming the implemented methods unless `x` is one of
#' the values returned by `implemented_is_methods()`.
#'
#' @noRd
#' @keywords internal
#' @param x The requested importance sampling method.
#' @return Called for its side effect; returns `NULL` invisibly when `x` is
#'   an implemented method.
assert_importance_sampling_method_is_implemented <- function(x) {
  implemented <- implemented_is_methods()
  if (!x %in% implemented) {
    # Paste the helper's *result*. Pasting the function object itself fails
    # with a coercion error that hides this message from the user.
    stop("Importance sampling method '",
         x,
         "' is not implemented. Implemented methods: '",
         paste0(implemented, collapse = "', '"),
         "'")
  }
}
# Names of the importance sampling methods this file knows how to run.
implemented_is_methods <- function() {
  c("psis", "tis", "sis")
}
#' Structure the object returned by the importance_sampling methods
#'
#' @noRd
#' @param unnormalized_log_weights Matrix of smoothed and possibly truncated
#'   log weights, not yet normalized.
#' @param pareto_k Vector of GPD k estimates.
#' @param tail_len Vector of tail lengths used to fit GPD.
#' @param r_eff Vector of relative MCMC ESS (n_eff) for `exp(log lik)`
#' @template is_method
#' @return A list with components `log_weights` and `diagnostics`
#'   (`pareto_k`, `n_eff`, `r_eff`) and attributes `norm_const_log`,
#'   `tail_len`, `r_eff`, `dims` and `method`. Its class is
#'   `c(<method>, "importance_sampling", "list")` when every entry of
#'   `method` is the same, and `c("importance_sampling", "list")` otherwise.
#'
importance_sampling_object <-
  function(unnormalized_log_weights,
           pareto_k,
           tail_len,
           r_eff,
           method) {
    stopifnot(is.matrix(unnormalized_log_weights))
    methods <- unique(method)
    stopifnot(all(methods %in% implemented_is_methods()))

    # A single distinct method becomes both the stored `method` attribute and
    # the leading class; a mix keeps the full vector and only the base classes.
    base_classes <- c("importance_sampling", "list")
    if (length(methods) == 1) {
      method <- methods
      classes <- c(tolower(method), base_classes)
    } else {
      classes <- base_classes
    }

    # `n_eff` is a placeholder here; it is filled in below once the object
    # exists and weights() can be called on it.
    diagnostics <- list(pareto_k = pareto_k, n_eff = NULL, r_eff = r_eff)

    out <- structure(
      list(
        log_weights = unnormalized_log_weights,
        diagnostics = diagnostics
      ),
      # attributes
      norm_const_log = matrixStats::colLogSumExps(unnormalized_log_weights),
      tail_len = tail_len,
      r_eff = r_eff,
      dims = dim(unnormalized_log_weights),
      method = method,
      class = classes
    )

    # psis_n_eff() is given the normalized weights on the natural scale.
    out$diagnostics[["n_eff"]] <-
      psis_n_eff(weights(out, normalize = TRUE, log = FALSE), r_eff)
    out
  }
#' Do importance sampling given matrix of log weights
#'
#' @noRd
#' @param log_ratios Matrix of log ratios (`-loglik`). Each column is
#'   processed separately; the number of rows is used as the number of draws.
#' @param r_eff Vector of relative effective sample sizes
#' @param cores User's integer `cores` argument
#' @param method One of the implemented importance sampling methods
#'   (`"psis"`, `"tis"` or `"sis"`).
#' @return The object built by `importance_sampling_object()` from the
#'   per-column results.
#'
do_importance_sampling <- function(log_ratios, r_eff, cores, method) {
  stopifnot(cores == as.integer(cores))
  assert_importance_sampling_method_is_implemented(method)

  n_obs <- ncol(log_ratios)
  n_draws <- nrow(log_ratios)
  k_threshold <- ps_khat_threshold(n_draws)
  tail_len <- n_pareto(r_eff, n_draws)

  # Pick the per-column worker. Tail-length warnings are raised for PSIS only.
  if (method == "psis") {
    is_worker <- do_psis_i
    throw_tail_length_warnings(tail_len)
  } else if (method == "tis") {
    is_worker <- do_tis_i
  } else if (method == "sis") {
    is_worker <- do_sis_i
  } else {
    stop("Incorrect IS method.")
  }

  # One closure shared by the serial and both parallel code paths.
  per_obs <- function(i) {
    is_worker(log_ratios_i = log_ratios[, i], tail_len_i = tail_len[i])
  }
  obs_index <- seq_len(n_obs)

  if (cores == 1) {
    lw_list <- lapply(obs_index, per_obs)
  } else if (!os_is_windows()) {
    lw_list <- parallel::mclapply(
      X = obs_index,
      FUN = per_obs,
      mc.cores = cores
    )
  } else {
    # On Windows a PSOCK cluster is used instead of mclapply(); it is shut
    # down when this function exits.
    cl <- parallel::makePSOCKcluster(cores)
    on.exit(parallel::stopCluster(cl), add = TRUE)
    lw_list <- parallel::parLapply(cl = cl, X = obs_index, fun = per_obs)
  }

  log_weights <-
    psis_apply(lw_list, "log_weights", fun_val = numeric(n_draws))
  pareto_k <- psis_apply(lw_list, "pareto_k")
  throw_pareto_warnings(pareto_k, k_threshold)

  importance_sampling_object(
    unnormalized_log_weights = log_weights,
    pareto_k = pareto_k,
    tail_len = tail_len,
    r_eff = r_eff,
    # Repeated so `method` has one entry per observation, like the other
    # per-observation quantities.
    method = rep(method, length(pareto_k))
  )
}
|