File: summary_stats_helpers.R

package info (click to toggle)
r-cran-shinystan 2.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm, forky, sid, trixie
  • size: 3,172 kB
  • sloc: sh: 15; makefile: 7
file content (73 lines) | stat: -rw-r--r-- 2,428 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# param_summary -----------------------------------------------------------
# Summary stats for a single parameter.
#
# param: row name (or position) of the parameter in `summary`.
# summary: table of summary statistics with at least the columns
#   Rhat, n_eff, mean, sd, 2.5%, 50%, 97.5%.
#
# Returns a 1-row matrix of those statistics (in that order) with column
# names and no row names; n_eff is rounded to a whole number.
.param_summary <- function(param, summary) {
  wanted <- c("Rhat", "n_eff", "mean", "sd", "2.5%", "50%", "97.5%")
  row_vals <- summary[param, wanted]
  # The effective sample size is displayed as a whole number.
  row_vals[["n_eff"]] <- round(row_vals[["n_eff"]])
  matrix(row_vals, nrow = 1, ncol = length(row_vals),
         dimnames = list(NULL, names(row_vals)))
}


# all_summary -------------------------------------------------------------
# Summary stats for all parameters, as a data frame.
#
# summary: matrix of summary statistics, one row per parameter.
# digits: number of decimal places every selected column is rounded to.
# cols: columns of `summary` to keep, by name or position; defaults to all.
#
# An "n_eff" column, if selected, is additionally rounded to a whole number.
.all_summary <- function(summary, digits = 2, cols) {
  if (missing(cols)) {
    cols <- seq_len(ncol(summary))
  }
  # drop = FALSE keeps a matrix (one row per parameter) even when a single
  # row or a single column is selected; otherwise as.data.frame() would turn
  # a one-parameter summary into a transposed table.
  df <- as.data.frame(summary[, cols, drop = FALSE])
  df <- round(df, digits)
  # Test the resulting column names rather than `cols`: `cols` may hold
  # numeric positions (as the default does) and then never contains "n_eff".
  if ("n_eff" %in% colnames(df)) {
    df[, "n_eff"] <- round(df[, "n_eff"])
  }
  df
}

# tex_summary -------------------------------------------------------------
# Prep for latex table: a data frame with a leading "Parameter" column.
#
# summary: matrix of summary statistics whose rownames are parameter names.
# params: rows (parameters) to include; if missing or empty, all rows.
# cols: columns to include, by name or position; defaults to all columns.
#
# An "n_eff" column, if selected, is rounded to a whole number; the other
# columns are left unrounded.
.tex_summary <- function(summary, params, cols) {
  if (missing(params) || length(params) == 0L) {
    params <- seq_len(nrow(summary))
  }
  if (missing(cols)) {
    cols <- seq_len(ncol(summary))
  }
  # Restrict to the requested parameters; drop = FALSE keeps a matrix even
  # when a single row or a single column is selected.
  df <- as.data.frame(summary[params, cols, drop = FALSE])
  # Test the resulting column names rather than `cols`, which may be numeric
  # positions and then never contains "n_eff".
  if ("n_eff" %in% colnames(df)) {
    df[, "n_eff"] <- round(df[, "n_eff"])
  }
  cbind(Parameter = rownames(df), df)
}

# sampler_summary ---------------------------------------------------------
# Per-chain summary of a single sampler parameter.
#
# X: list with one element per chain; each element is indexed as x[, param].
# param: name (or position) of the column to summarise in every chain.
# report: "maximum", "minimum", "sd", or anything else for the mean.
#
# Returns one value per chain, named "chain1", "chain2", ...
.sampler_stuff <- function(X, param, report) {
  # Choose the reducer once instead of rebuilding a list of closures.
  stat_fun <- switch(report,
    maximum = max,
    minimum = min,
    sd = sd,
    mean
  )
  out <- sapply(X, FUN = function(x) stat_fun(x[, param]))
  # seq_along() rather than 1:length(): with no chains, 1:0 would try to
  # attach two names to a zero-length result and error.
  names(out) <- paste0("chain", seq_along(out))
  out
}

# Summary statistics for algorithm=NUTS or algorithm=HMC sampler parameters.
#
# sampler_params: list with one element per chain; parameter names are taken
#   from the colnames of the first chain.
# warmup_val: not used by this function; kept so existing callers still work.
# report: "maximum", "minimum", "sd", or anything else for the mean.
# digits: number of decimal places to round to.
#
# With several chains, returns a character matrix (formatted by formatC) with
# an "All chains" row followed by one row per chain. With a single chain,
# returns a 1-row numeric matrix. "__" is stripped from the parameter names.
.sampler_summary <- function(sampler_params, warmup_val,
                             report = "average", digits = 4) {
  params <- colnames(sampler_params[[1]])
  out <- sapply(params, FUN = function(p)
    .sampler_stuff(X = sampler_params, param = p, report = report))

  if (length(dim(out)) > 1) { # if multiple chains
    # For an extreme, "All chains" must be the extreme over all chains, not
    # the average of the per-chain extremes. For "sd" and the mean it is the
    # average of the per-chain values.
    all_chains <- switch(report,
      maximum = apply(out, 2, max),
      minimum = apply(out, 2, min),
      colMeans(out)
    )
    out <- rbind("All chains" = all_chains, out)
    colnames(out) <- gsub("__", "", colnames(out))
    out <- formatC(round(out, digits), format = "f", digits = digits)
  } else { # if only 1 chain
    # sapply() names the values "<param>.chain1"; strip the literal suffix
    # (fixed = TRUE so "." is not treated as a regex wildcard).
    names(out) <- gsub("__.chain1", "", names(out), fixed = TRUE)
    out <- round(t(out), digits)
  }
  out
}