#' Scatterplots of MCMC draws
#'
#' Scatterplots, hexagonal heatmaps, and pairs plots from MCMC draws. See the
#' **Plot Descriptions** section, below, for details.
#'
#' @name MCMC-scatterplots
#' @family MCMC
#'
#' @template args-mcmc-x
#' @template args-pars
#' @template args-regex_pars
#' @template args-transformations
#' @param ... Currently ignored.
#' @param size,alpha For `mcmc_scatter()`, passed to
#' [ggplot2::geom_point()] to control the appearance of the points.
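#' @param bins,binwidth For `mcmc_hex()`, arguments passed to
#' [ggplot2::geom_hex()]. `bins` is the number of bins in both the vertical
#' and horizontal directions. `binwidth` is an optional numeric vector of
#' *length two* that overrides the default binwidth in both the vertical and
#' horizontal directions.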
#'
#' @param np Optionally, a data frame of NUTS sampler parameters, either created
#' by [nuts_params()] or in the same form as the object returned by
#' [nuts_params()]. The colors, shapes, and sizes of the
#' superimposed points can be customized using the `np_style` argument.
#' @param np_style If `np` is specified, `np_style` can be a call to
#' the `scatter_style_np()` helper function (for `mcmc_scatter()`) or
#' the `pairs_style_np()` helper function (for `mcmc_pairs()`) to specify
#' arguments controlling the appearance of superimposed points representing
#' NUTS diagnostic information. (Note: for `pairs_style_np()` the
#' `size` arguments are interpreted as scaling factors).
#'
#' @return `mcmc_scatter()` and `mcmc_hex()` return a ggplot object that
#' can be further customized using the **ggplot2** package.
#'
#' `mcmc_pairs()` returns many ggplot objects organized into a grid via
#' [bayesplot_grid()].
#'
#' @section Plot Descriptions:
#' \describe{
#' \item{`mcmc_scatter()`}{
#' Bivariate scatterplot of posterior draws. If using a very large number of
#' posterior draws then `mcmc_hex()` may be preferable to avoid
#' overplotting. For models fit using [NUTS] the `np` and `np_style`
#' arguments can be used to add additional information to the plot (in
#' this case the approximate location of divergences).
#' For more on why the scatter plot with divergences is a useful
#' diagnostic tool see [Gabry et al. (2019)](https://github.com/jgabry/bayes-vis-paper#readme).
#' }
#' \item{`mcmc_hex()`}{
#' Hexagonal heatmap of 2-D bin counts. This plot is useful in cases where
#' the posterior sample size is large enough that `mcmc_scatter()` suffers
#' from overplotting.
#' }
#' \item{`mcmc_pairs()`}{
#' A square plot matrix with univariate marginal distributions along the
#' diagonal (as histograms or kernel density plots) and bivariate
#' distributions off the diagonal (as scatterplots or hex heatmaps).
#'
#' For the off-diagonal plots, the default is to split the chains so that
#' (roughly) half are displayed above the diagonal and half are below (all
#' chains are always merged together for the plots along the diagonal). Other
#' possibilities are available by setting the `condition` argument.
#'
#' Additionally, extra diagnostic information for models fit using
#' [NUTS] can be added to the pairs plot using the `lp`,
#' `np`, and `np_style` arguments. If `np` is specified (and
#' `condition` is *not* `"divergent__"`), then points (red, by
#' default) will be superimposed onto the off-diagonal plots indicating which
#' (if any) iterations encountered a divergent transition. Also, if both
#' `np` and `max_treedepth` are specified then points (yellow, by
#' default) will be superimposed to indicate a transition that hit the
#' maximum treedepth rather than terminated its evolution normally. The
#' `np_style` argument can be used with the `pairs_style_np()`
#' convenience function to change the appearance of these overlaid points.
#' See the **Examples** section.
#' }
#' }
#'
#' @template reference-vis-paper
#'
#' @examples
#' library("ggplot2")
#'
#' # some parameter draws to use for demonstration
#' x <- example_mcmc_draws(params = 6)
#' dimnames(x)
#'
#' # scatterplot of alpha vs log(sigma)
#' color_scheme_set("teal")
#' (p <- mcmc_scatter(x, pars = c("alpha", "sigma"),
#' transform = list(sigma = "log")))
#' p +
#' labs(
#' title = "Insert your own headline-grabbing title",
#' subtitle = "with a provocative subtitle",
#' caption = "and a controversial caption",
#' x = expression(alpha),
#' y = expression(log(sigma))
#' )
#'
#' # add ellipse
#' p + stat_ellipse(level = 0.9, color = "gray20", size = 1)
#'
#' # add contour
#' color_scheme_set("red")
#' p2 <- mcmc_scatter(x, pars = c("alpha", "sigma"), size = 3.5, alpha = 0.25)
#' p2 + stat_density_2d(color = "black", size = .5)
#'
#' # can also add lines/smooths
#' color_scheme_set("pink")
#' (p3 <- mcmc_scatter(x, pars = c("alpha", "beta[3]"), alpha = 0.25, size = 3))
#' p3 + geom_smooth(method = "lm", se = FALSE, color = "gray20",
#' size = .75, linetype = 2)
#'
#' \donttest{
#' if (requireNamespace("hexbin", quietly = TRUE)) {
#' # hexagonal heatmap
#' color_scheme_set("brightblue")
#' (p <- mcmc_hex(x, pars = c("sigma", "alpha"), transform = list(sigma = "log")))
#' p + plot_bg(fill = "gray95")
#' p + plot_bg(fill = "gray95") + panel_bg(fill = "gray70")
#' }
#' }
NULL
#' @rdname MCMC-scatterplots
#' @export
mcmc_scatter <- function(x,
pars = character(),
regex_pars = character(),
transformations = list(),
...,
size = 2.5,
alpha = 0.8,
np = NULL,
np_style = scatter_style_np()) {
check_ignored_arguments(...)
.mcmc_scatter(
x,
pars = pars,
regex_pars = regex_pars,
transformations = transformations,
size = size,
alpha = alpha,
hex = FALSE,
binwidth = NULL,
np = np,
np_style = np_style
)
}
#' @rdname MCMC-scatterplots
#' @export
#'
mcmc_hex <- function(x,
pars = character(),
regex_pars = character(),
transformations = list(),
...,
bins = 30,
binwidth = NULL) {
suggested_package("scales")
suggested_package("hexbin")
check_ignored_arguments(...)
.mcmc_scatter(
x,
pars = pars,
regex_pars = regex_pars,
transformations = transformations,
hex = TRUE,
bins = bins,
binwidth = binwidth,
size = NULL,
alpha = NULL
)
}
#' @rdname MCMC-scatterplots
#' @export
#' @param condition For `mcmc_pairs()`, a call to the `pairs_condition()`
#' helper function, which is used to specify a criterion for determining which
#' chains (or iterations) are shown in the plots above the diagonal and which
#' are shown in the plots below the diagonal. The histograms (or density
#' plots) along the diagonal are always made using all chains and iterations,
#' but the scatterplots (or hex plots) above and below the diagonal show
#' different combinations of chains/iterations depending on `condition`.
#' The default is a call to `pairs_condition()` with none of its arguments
#' specified. In this case half of the chains (or roughly half if there are an
#' odd number) will be used in the plots above the diagonal and the rest in
#' the plots below the diagonal. The `chains`, `draws`, and
#' `nuts` arguments to `pairs_condition()`, which are documented
#' below, can be used to change this default.
#'
#' @param lp For `mcmc_pairs()`, a molten data frame of draws of the
#' log-posterior or, more commonly, of a quantity equal to the log-posterior
#' up to a constant. `lp` should either be created via
#' [log_posterior()] or be an object with the same form as the
#' object returned by [log_posterior()].
#' @param max_treedepth For `mcmc_pairs()`, an integer representing the
#' maximum treedepth allowed when fitting the model (if fit using NUTS). This
#' is only needed for detecting which transitions (if any) hit the maximum
#' treedepth.
#' @param diag_fun,off_diag_fun For `mcmc_pairs()`, the plotting function to
#' use for the plots along the diagonal and for the off-diagonal plots,
#' respectively. Currently `diag_fun` can be `"hist"` for histogram
#' or `"dens"` for density, and `off_diag_fun` can be
#' `"scatter"` for scatterplot or `"hex"` for a hexagonal heatmap.
#' @param diag_args,off_diag_args For `mcmc_pairs()`, optional named lists of
#' arguments to pass to the functions implied by the `diag_fun` and
#' `off_diag_fun` arguments, respectively. For example, if
#' `off_diag_fun` is `"scatter"` then `off_diag_args` could
#' include optional arguments to `mcmc_scatter()` like `size` and
#' `alpha`.
#' @param grid_args,save_gg_objects For `mcmc_pairs()`, arguments to pass to
#' [bayesplot_grid()]. For example, since `mcmc_pairs()` returns
#' more than a single ggplot object, using `ggtitle()` afterwards
#' will not work. But you can still add a title to the plot using
#' `grid_args = list(top="My title")`.
#'
#' @examples
#' \donttest{
#' color_scheme_set("purple")
#'
#' # pairs plots
#' # the default condition = pairs_condition() splits the chains between the
#' # upper and lower panels
#' mcmc_pairs(x, pars = "alpha", regex_pars = "beta\\[[1,4]\\]",
#' off_diag_args = list(size = 1, alpha = 0.5))
#'
#' # change to density plots instead of histograms and hex plots instead of
#' # scatterplots
#' mcmc_pairs(x, pars = "alpha", regex_pars = "beta\\[[1,4]\\]",
#' diag_fun = "dens", off_diag_fun = "hex")
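#'
#' # a sketch added for illustration (the title text is arbitrary): ggtitle()
#' # does not apply to the grid returned by mcmc_pairs(), so the title is
#' # passed through grid_args instead, as described for that argument
#' mcmc_pairs(x, pars = c("alpha", "sigma"),
#' grid_args = list(top = "Pairs plot of alpha and sigma"))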
#'
#' # plot chain 1 above diagonal and chains 2, 3, and 4 below
#' color_scheme_set("brightblue")
#' mcmc_pairs(x, pars = "alpha", regex_pars = "beta\\[[1,4]\\]",
#' diag_fun = "dens", off_diag_fun = "hex",
#' condition = pairs_condition(chains = list(1, 2:4)))
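#'
#' # a sketch added for illustration (the 0.25 is an arbitrary choice): split by
#' # draws instead of by chains, so that the first 25% of the draws in each
#' # chain go below the diagonal and the remaining draws go above it
#' mcmc_pairs(x, pars = "alpha", regex_pars = "beta\\[[1,4]\\]",
#' condition = pairs_condition(draws = 0.25))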
#' }
#'
#' \dontrun{
#' ### Adding NUTS diagnostics to scatterplots and pairs plots
#'
#' # examples using rstanarm package
#' library(rstanarm)
#'
#' # for demonstration purposes, intentionally fit a model that
#' # will (almost certainly) have some divergences
#' fit <- stan_glm(
#' mpg ~ ., data = mtcars,
#' iter = 1000, refresh = 0,
#' # this combo of prior and adapt_delta should lead to some divergences
#' prior = hs(),
#' adapt_delta = 0.9
#' )
#' posterior <- as.array(fit)
#' np <- nuts_params(fit)
#'
#' # mcmc_scatter with divergences highlighted
#' color_scheme_set("brightblue")
#' mcmc_scatter(posterior, pars = c("wt", "sigma"), np = np)
#'
#' color_scheme_set("darkgray")
#' div_style <- scatter_style_np(div_color = "green", div_shape = 4, div_size = 4)
#' mcmc_scatter(posterior, pars = c("sigma", "(Intercept)"),
#' np = np, np_style = div_style)
#'
#' # split the draws according to above/below median accept_stat__
#' # and show approximate location of divergences (red points)
#' color_scheme_set("brightblue")
#' mcmc_pairs(
#' posterior,
#' pars = c("wt", "cyl", "sigma"),
#' off_diag_args = list(size = 1, alpha = 1/3),
#' condition = pairs_condition(nuts = "accept_stat__"),
#' np = np
#' )
#'
#' # more customizations:
#' # - transform sigma to log(sigma)
#' # - median log-posterior as 'condition'
#' # - hex instead of scatter for off-diagonal plots
#' # - show points where max treedepth hit in blue
#' color_scheme_set("darkgray")
#' mcmc_pairs(
#' posterior,
#' pars = c("wt", "cyl", "sigma"),
#' transform = list(sigma = "log"),
#' off_diag_fun = "hex",
#' condition = pairs_condition(nuts = "lp__"),
#' lp = log_posterior(fit),
#' np = np,
#' np_style = pairs_style_np(div_color = "firebrick",
#' td_color = "blue",
#' td_size = 2),
#' # for demonstration purposes, set max_treedepth to a value that will
#' # result in at least a few max treedepth warnings
#' max_treedepth = with(np, -1 + max(Value[Parameter == "treedepth__"]))
#' )
#' }
#'
mcmc_pairs <- function(x,
pars = character(),
regex_pars = character(),
transformations = list(),
...,
diag_fun = c("hist", "dens"),
off_diag_fun = c("scatter", "hex"),
diag_args = list(),
off_diag_args = list(),
condition = pairs_condition(),
lp = NULL,
np = NULL,
np_style = pairs_style_np(),
max_treedepth = NULL,
grid_args = list(),
save_gg_objects = TRUE) {
check_ignored_arguments(...)
stopifnot(
is.list(diag_args),
is.list(off_diag_args),
inherits(np_style, "nuts_style"),
inherits(condition, "pairs_condition")
)
diag_fun <- match.arg(diag_fun)
off_diag_fun <- match.arg(off_diag_fun)
plot_diagonal <- pairs_plotfun(diag_fun)
plot_off_diagonal <- pairs_plotfun(off_diag_fun)
x <- prepare_mcmc_array(x, pars, regex_pars, transformations)
x <- drop_constants_and_duplicates(x)
n_iter <- num_iters(x)
n_chain <- num_chains(x)
n_param <- num_params(x)
pars <- parameter_names(x)
if (n_chain == 1) {
warn("Only one chain in 'x'. This plot is more useful with multiple chains.")
}
if (n_param < 2) {
abort("This plot requires at least two parameters in 'x'.")
}
no_np <- is.null(np)
no_lp <- is.null(lp)
no_max_td <- is.null(max_treedepth)
if (!no_np) {
param <- sym("Parameter")
val <- sym("Value")
np <- validate_nuts_data_frame(np, lp)
divs <- dplyr::filter(np, UQ(param) == "divergent__") %>% pull(UQ(val))
divergent__ <- matrix(divs, nrow = n_iter * n_chain, ncol = n_param)[, 1]
if (!no_max_td) {
gt_max_td <- (dplyr::filter(np, UQ(param) == "treedepth__") %>% pull(UQ(val))) >= max_treedepth
max_td_hit__ <- matrix(gt_max_td, nrow = n_iter * n_chain, ncol = n_param)[, 1]
}
}
cond <- handle_condition(x, condition, np, lp)
x <- merge_chains(cond[["x"]])
mark <- cond[["mark"]]
all_pairs <- expand.grid(pars, pars,
stringsAsFactors = FALSE,
KEEP.OUT.ATTRS = FALSE)
plots <- vector("list", length = nrow(all_pairs))
use_default_binwidth <- is.null(diag_args[["binwidth"]])
for (j in seq_len(nrow(all_pairs))) {
pair <- as.character(all_pairs[j,])
if (identical(pair[1], pair[2])) {
# Diagonal
diag_args[["x"]] <- x[, pair[1], drop = FALSE]
# silence ggplot2's "Pick better value with `binwidth`" message
if (diag_fun == "hist" && use_default_binwidth)
diag_args[["binwidth"]] <- diff(range(diag_args[["x"]]))/30
plots[[j]] <-
do.call(plot_diagonal, diag_args) +
labs(subtitle = pair[1]) +
theme(axis.line.y = element_blank(),
plot.subtitle = element_text(hjust = 0.5))
} else {
# Off-diagonal
# use mark if above diagonal and !mark if below the diagonal
mark2 <- if (is_lower_tri(j, n_param)) !mark else mark
x_j <- x[mark2, pair, drop = FALSE]
if (!no_np) {
divs_j <- divergent__[mark2]
max_td_hit_j <- if (no_max_td) NULL else max_td_hit__[mark2]
} else {
divs_j <- max_td_hit_j <- NULL
}
off_diag_args[["x"]] <- x_j
plots[[j]] <- do.call(plot_off_diagonal, off_diag_args)
if (isTRUE(any(divs_j == 1))) {
divs_j_fac <- factor(as.logical(divs_j),
levels = c(FALSE, TRUE),
labels = c("NoDiv", "Div"))
plots[[j]] <- plots[[j]] +
geom_point(
aes(color = divs_j_fac, size = divs_j_fac),
shape = np_style$shape[["div"]],
alpha = np_style$alpha[["div"]],
na.rm = TRUE
)
}
if (isTRUE(any(max_td_hit_j == 1))) {
max_td_hit_j_fac <- factor(max_td_hit_j, levels = c(FALSE, TRUE),
labels = c("NoHit", "Hit"))
plots[[j]] <- plots[[j]] +
geom_point(
aes(color = max_td_hit_j_fac, size = max_td_hit_j_fac),
shape = np_style$shape[["td"]],
alpha = np_style$alpha[["td"]],
na.rm = TRUE
)
}
if (isTRUE(any(divs_j == 1)) ||
isTRUE(any(max_td_hit_j == 1)))
plots[[j]] <- format_nuts_points(plots[[j]], np_style)
}
}
plots <- lapply(plots, function(x)
x + xaxis_title(FALSE) + yaxis_title(FALSE))
bayesplot_grid(plots = plots,
legends = FALSE,
grid_args = grid_args,
save_gg_objects = save_gg_objects)
}
#' @rdname MCMC-scatterplots
#' @export
#' @param div_color,div_shape,div_size,div_alpha,td_color,td_shape,td_size,td_alpha
#' Optional arguments to the `scatter_style_np()` or `pairs_style_np()`
#' helper functions that are eventually passed to
#' [ggplot2::geom_point()]. The default values are displayed in the
#' **Usage** section above.
scatter_style_np <-
function(div_color = "red",
div_shape = 16,
div_size = 2.5,
div_alpha = 1) {
stopifnot(
is.numeric(div_shape) || is.character(div_shape),
is.character(div_color),
is.numeric(div_size),
is.numeric(div_alpha) && div_alpha >= 0 && div_alpha <= 1
)
style <- list(
color = c(div = div_color),
shape = c(div = div_shape),
size = c(div = div_size),
alpha = c(div = div_alpha)
)
structure(style, class = c(class(style), "nuts_style"))
}
#' @rdname MCMC-scatterplots
#' @export
pairs_style_np <-
function(div_color = "red",
div_shape = 4,
div_size = 1,
div_alpha = 1,
td_color = "yellow2",
td_shape = 3,
td_size = 1,
td_alpha = 1) {
stopifnot(
is.numeric(div_shape) || is.character(div_shape),
is.numeric(td_shape) || is.character(td_shape),
is.character(div_color),
is.character(td_color),
is.numeric(div_size),
is.numeric(td_size),
is.numeric(div_alpha) && div_alpha >= 0 && div_alpha <= 1,
is.numeric(td_alpha) && td_alpha >= 0 && td_alpha <= 1
)
style <- list(
color = c(div = div_color, td = td_color),
shape = c(div = div_shape, td = td_shape),
size = c(div = div_size, td = td_size),
alpha = c(div = div_alpha, td = td_alpha)
)
structure(style, class = c(class(style), "nuts_style"))
}
#' @rdname MCMC-scatterplots
#' @export
#' @param chains,draws,nuts Optional arguments to the `pairs_condition()`
#' helper function, which is used to specify the `condition` argument for
#' `mcmc_pairs()`.
#' \itemize{
#' \item The `chains` argument can be used to select some subset of the
#' chains. If `chains` is an integer vector then the behavior is the same
#' as the default (half the chains above the diagonal and half below) except
#' using only the specified subset of chains. Alternatively, `chains` can
#' be a list of two integer vectors with the first specifying the chains to be
#' shown in the plots above the diagonal and the second for below the
#' diagonal.
#' \item The `draws` argument to `pairs_condition()` can be used to
#' directly specify which realizations are plotted above and below the
#' diagonal. `draws` can be a single proportion, which is interpreted as
#' the proportion of realizations (among all chains) to plot in the lower
#' panel starting with the first realization in each chain, with the
#' complement (from the end of each chain) plotted in the upper panel.
#' Alternatively `draws` can be a logical vector with length equal to the
#' product of the number of iterations and the number of chains, in which case
#' realizations corresponding to `FALSE` and `TRUE` will be plotted
#' in the lower and upper panels, respectively.
#' \item For models fit using NUTS, the `nuts` argument to
#' `pairs_condition()` can be used. It takes a (possibly abbreviated)
#' string to select among `"accept_stat__"`, `"stepsize__"`,
#' `"treedepth__"`, `"n_leapfrog__"`, `"divergent__"`,
#' `"energy__"`, and `"lp__"`. These are the sampler parameters
#' associated with [NUTS()] (and `"lp__"` is the log-posterior
#' up to an additive constant). In this case, plots below the diagonal will
#' contain realizations that are below the median of the indicated variable
#' (or are zero in the case of `"divergent__"`), and plots above the
#' diagonal will contain realizations that are greater than or equal to the
#' median of the indicated variable (or are one in the case of
#' `"divergent__"`). If `"lp__"` is used then the `lp`
#' argument to `mcmc_pairs()` must also be specified. For the other NUTS
#' parameters the `np` argument to `mcmc_pairs()` must also be
#' specified.
#' }
#'
pairs_condition <- function(chains = NULL, draws = NULL, nuts = NULL) {
.ignore_args <- function(..., why = NULL) {
dots <- list(...)
nms <- names(dots)[!sapply(dots, is.null)]
if (length(nms)) {
inform(paste0(
"The following specified arguments were ignored by 'pairs_condition' ",
"because ", why, ": ",
paste(sQuote(nms), collapse = ", ")
))
}
}
.error_duplicate_chains <- function() {
abort("Each chain can only be specified once in the 'chains' argument to 'pairs_condition'.")
}
if (is.null(chains) &&
is.null(draws) &&
is.null(nuts)) {
# default: half of the chains above diag, half below
cond <- list()
cond_type <- "default"
} else if (!is.null(chains)) {
# Using 'chains' argument
.ignore_args(
draws = draws, nuts = nuts,
why = "they are superseded by 'chains'"
)
if (is.list(chains)) {
# list of two integer vectors, each specifying a subset of the chains
stopifnot(length(chains) == 2)
chain_vec <- unlist(chains, use.names = FALSE)
if (length(chain_vec) != length(unique(chain_vec)))
.error_duplicate_chains()
cond <- list(upper = as.integer(chains[[1]]),
lower = as.integer(chains[[2]]))
cond_type <- "chain_list"
} else if (is.numeric(chains)) {
# single vector specifying a subset of chains
stopifnot(NCOL(chains) == 1, all(chains == as.integer(chains)))
if (length(chains) != length(unique(chains)))
.error_duplicate_chains()
cond <- as.integer(chains)
cond_type <- "chain_vector"
} else {
abort(paste(
"The 'chains' argument to 'pairs_condition' must be",
"an integer vector or a list of two integer vectors."
))
}
} else if (!is.null(draws)) {
# Using 'draws' argument
.ignore_args(nuts = nuts, why = "they are superseded by 'draws'")
if (is.numeric(draws)) {
# proportion of realizations (among all chains) to plot in the lower panel
stopifnot(draws > 0 && draws < 1)
cond <- draws
cond_type <- "draws_proportion"
} else if (is.logical(draws)) {
# T/F for each iteration to split into upper/lower panels
cond <- draws
cond_type <- "draws_selection"
} else {
abort(paste(
"The 'draws' argument to 'pairs_condition' must be",
"a single proportion or a logical vector."
))
}
} else {
# Using 'nuts' argument
if (!is.character(nuts) || length(nuts) > 1) {
abort("The 'nuts' argument to 'pairs_condition' must be a single string.")
}
cond_type <- "nuts"
cond <- match.arg(nuts, several.ok = FALSE,
choices = c("accept_stat__", "stepsize__",
"treedepth__", "n_leapfrog__",
"divergent__", "energy__", "lp__"))
}
structure(
cond,
class = c(class(cond), "pairs_condition"),
type = cond_type # this attribute is used later by handle_condition()
)
}
# internal ----------------------------------------------------------------
#' @importFrom dplyr pull
.mcmc_scatter <- function(x,
pars = character(),
regex_pars = character(),
transformations = list(),
hex = FALSE,
size = 2.5,
alpha = 0.8,
bins = 30,
binwidth = NULL,
np = NULL,
np_style = scatter_style_np()) {
x <- prepare_mcmc_array(x, pars, regex_pars, transformations)
if (num_params(x) != 2) {
abort(paste(
"For 'mcmc_scatter' and 'mcmc_hex' exactly 2 parameters must be selected.",
"'mcmc_pairs' can be used for more than 2 parameters."
))
}
x <- merge_chains(x)
parnames <- colnames(x)[1:2]
has_divs <- !is.null(np)
xydata <- data.frame(x = c(x[, 1]), y = c(x[, 2]))
if (has_divs) {
if (hex) {
warn("'np' is currently ignored for hex plots.")
}
stopifnot(inherits(np_style, "nuts_style"))
np <- validate_nuts_data_frame(np)
param <- sym("Parameter")
val <- sym("Value")
divg <- sym("Divergent")
xydata$Divergent <-
np %>%
dplyr::filter(UQ(param) == "divergent__") %>%
pull(UQ(val))
divdata <- dplyr::filter(xydata, UQ(divg) == 1)
xydata <- dplyr::filter(xydata, UQ(divg) == 0)
}
graph <- ggplot(data = xydata, aes(x = .data$x, y = .data$y)) +
bayesplot_theme_get()
if (!hex) { # scatterplot
graph <- graph +
geom_point(
shape = 21,
color = get_color("dh"),
fill = get_color("d"),
size = size,
alpha = alpha
)
if (has_divs) {
graph <- graph +
geom_point(
data = divdata,
color = np_style$color[["div"]],
size = np_style$size[["div"]],
alpha = np_style$alpha[["div"]],
shape = np_style$shape[["div"]]
)
}
} else { # hex binning
graph <- graph +
geom_hex(
aes(fill = scales::rescale(after_stat(density))),
bins = bins,
binwidth = binwidth
) +
scale_fill_gradientn(
"Density",
colors = unlist(color_scheme_get()),
breaks = c(.1, .9),
labels = c("low", "high")
)
}
graph + labs(x = parnames[1], y = parnames[2])
}
# internal for mcmc_pairs -------------------------------------------------
#' Get plotting functions from user-specified
#' `diag_fun` and `off_diag_fun` arguments
#'
#' @noRd
#' @param x User specified `diag_fun` or `off_diag_fun` argument to `mcmc_pairs()`
pairs_plotfun <- function(x) {
fun <- paste0("mcmc_", x)
utils::getFromNamespace(fun, "bayesplot")
}
#' Unstack molten data frame
#'
#' @noRd
#' @param df A data frame (from `nuts_params()`, `log_posterior()`, etc)
#' @param .form Same as `form` arg to `utils::unstack()`
unstack_to_matrix <- function(df, .form) {
x <- utils::unstack(df, form = .form)
as.matrix(x)
}
#' Check if off-diagonal plot is above or below the diagonal
#'
#' @noRd
#' @param j integer (index)
#' @param n Number of parameters (number of plots = `n^2`)
#' @return `TRUE` if below the diagonal, `FALSE` if above the diagonal
is_lower_tri <- function(j, n) {
idx <- array_idx_j(j, n)
lower_tri <- lower_tri_idx(n)
row_match_found(idx, lower_tri)
}
#' Get array indices of the jth element in the plot matrix
#'
#' @noRd
#' @param j integer (index)
#' @param n number of parameters (number of plots = n^2)
#' @return row vector (1-row matrix) containing the array indices of the jth
#' element in the plot matrix
array_idx_j <- function(j, n) {
jj <- matrix(seq_len(n^2), nrow = n, byrow = TRUE)[j]
arrayInd(jj, .dim = c(n, n))
}
#' Get indices of lower triangular elements of a square matrix
#' @noRd
#' @param n number of rows (columns) in the square matrix
lower_tri_idx <- function(n) {
a <- rev(abs(sequence(seq.int(n - 1)) - n) + 1)
b <- rep.int(seq.int(n - 1), rev(seq.int(n - 1)))
cbind(row = a, col = b)
}
#' Find which (if any) row in y is a match for x
#' @noRd
#' @param x a row vector (i.e., a matrix with 1 row)
#' @param y a matrix
#' @return either a row number in `y` or `NA` if no match
row_match_found <- function(x, y) {
stopifnot(is.matrix(x), is.matrix(y), nrow(x) == 1)
x <- as.data.frame(x)
y <- as.data.frame(y)
res <- match(
do.call(function(...) paste(..., sep=":::"), x),
do.call(function(...) paste(..., sep=":::"), y)
)
isTRUE(!is.na(res) && length(res) == 1)
}
#' Drop any constant or duplicate variables
#' @noRd
#' @param x 3-D array
drop_constants_and_duplicates <- function(x) {
x2 <- drop_consts(x)
x2 <- drop_dupes(x2)
class(x2) <- c(class(x2), "mcmc_array")
x2
}
drop_consts <- function(x) {
varying <- apply(x, 3, FUN = function(y) length(unique(c(y))) > 1)
if (all(varying))
return(x)
warn(paste(
"The following parameters were dropped because they are constant:",
paste(names(varying)[!varying], collapse = ", ")
))
x[, , varying, drop = FALSE]
}
drop_dupes <- function(x) {
dupes <- duplicated(x, MARGIN = 3)
if (!any(dupes))
return(x)
warn(paste(
"The following parameters were dropped because they are duplicative:",
paste(parameter_names(x)[dupes], collapse = ", ")
))
x[, , !dupes, drop = FALSE]
}
#' Handle user's specified `condition`
#' @noRd
#' @param x 3-D mcmc array.
#' @param condition Object returned by `pairs_condition()`.
#' @param np,lp User-specified arguments to `mcmc_pairs()`.
#' @return A named list containing `"x"` (`x`, possibly modified) and `"mark"`
#' (logical or integer vector for eventually splitting `x`).
handle_condition <- function(x, condition=NULL, np=NULL, lp=NULL) {
n_iter <- num_iters(x)
n_chain <- num_chains(x)
no_np <- is.null(np)
no_lp <- is.null(lp)
cond_type <- attr(condition, "type")
if (cond_type == "default") {
k <- ncol(x) %/% 2
mark <- c(rep(FALSE, n_iter * k), rep(TRUE, n_iter * (n_chain - k)))
} else if (cond_type == "chain_vector") {
x <- x[, condition, , drop = FALSE]
k <- ncol(x) %/% 2
n_chain <- length(condition)
mark <- c(rep(FALSE, n_iter * k), rep(TRUE, n_iter * (n_chain - k)))
} else if (cond_type == "chain_list") {
x <- x[, c(condition[[1]], condition[[2]]), , drop = FALSE]
k1 <- length(condition[[1]])
k2 <- length(condition[[2]])
mark <- c(rep(TRUE, n_iter * k1), rep(FALSE, n_iter * k2))
} else if (cond_type == "draws_proportion") {
mark <- rep(1:n_iter > (condition * n_iter), times = n_chain)
} else if (cond_type == "draws_selection") {
# T/F for each iteration to split into upper and lower
stopifnot(length(condition) == (n_iter * n_chain))
mark <- !condition
} else if (cond_type == "nuts") {
# NUTS sampler param or lp__
if (no_np && condition != "lp__")
abort(paste(
"To use this value of 'condition' the 'np' argument",
"to 'mcmc_pairs' must also be specified."
))
if (condition == "lp__") {
if (no_lp)
abort(paste(
"If 'condition' is 'lp__' then the 'lp' argument",
"to 'mcmc_pairs' must also be specified."
))
mark <- unstack_to_matrix(lp, Value ~ Chain)
} else {
param <- sym("Parameter")
mark <- dplyr::filter(np, UQ(param) == condition)
mark <- unstack_to_matrix(mark, Value ~ Chain)
}
if (condition == "divergent__") {
mark <- as.logical(mark)
} else {
mark <- c(mark) >= median(mark)
}
if (length(unique(mark)) == 1)
abort(paste(condition, "is constant so it cannot be used as a condition."))
}
list(x = x, mark = mark)
}
#' Apply scale_color_manual and scale_size_manual if plotting divergences and
#' hitting max_treedepth
#'
#' @noRd
#' @param graph ggplot object
#' @param np_args list of style arguments returned by `pairs_style_np()`
#' @return `graph`, updated
format_nuts_points <- function(graph, np_args) {
graph +
scale_color_manual(
values = set_names(c(NA, np_args$color[["div"]], NA, np_args$color[["td"]]),
c("NoDiv", "Div", "NoHit", "Hit"))
) +
scale_size_manual(
values = set_names(c(NA, rel(np_args$size[["div"]]), NA, rel(np_args$size[["td"]])),
c("NoDiv", "Div", "NoHit", "Hit"))
)
}