1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206
|
#' Define class unions
#'
#' Each union lets an S4 slot accept either a value of the named base class
#' or NULL.
#' @keywords internal
#' @noRd
setClassUnion(name = "matrixOrNULL", members = c("matrix", "NULL"))
#' @keywords internal
#' @noRd
setClassUnion(name = "numericOrNULL", members = c("numeric", "NULL"))
#' @keywords internal
#' @noRd
setClassUnion(name = "characterOrNULL", members = c("character", "NULL"))
#' @keywords internal
#' @noRd
setClassUnion(name = "logicalOrNULL", members = c("logical", "NULL"))
#' @keywords internal
#' @noRd
setClassUnion(name = "functionOrNULL", members = c("function", "NULL"))
#' Internal S4 class for marginaleffects
#'
#' This S4 class is used internally to hold common arguments passed between
#' functions to simplify the function signatures and reduce redundant argument passing.
#' Objects are normally created with `new_marginaleffects_internal()`.
#'
#' @slot by Aggregation/grouping specification
#' @slot byfun Function for aggregation when using by
#' @slot call The original function call
#' @slot call_model The model's own call as returned by `insight::get_call()`,
#'   or NULL when that lookup fails
#' @slot calling_function The name of the calling function (comparisons, predictions, hypotheses),
#'   or "unknown" when it cannot be determined from `call`
#' @slot comparison Comparison function specification
#' @slot conf_level Confidence level (the constructor defaults to 0.95)
#' @slot cross Boolean flag for cross-contrasts
#' @slot df The degrees of freedom
#' @slot draws Matrix of draws, or NULL
#' @slot draws_chains Numeric value associated with `draws`; presumably the
#'   number of chains (constructor default 0) — TODO confirm
#' @slot eps Epsilon value for numerical derivatives
#' @slot hypothesis Hypothesis specification (any type; NULL by default)
#' @slot hypothesis_null Null value for the hypothesis (any type; NULL by default)
#' @slot hypothesis_direction Direction of the hypothesis test (any type; NULL by default)
#' @slot inferences Any object; not set by the constructor, so it keeps the
#'   default prototype for an "ANY" slot
#' @slot jacobian The jacobian matrix or NULL
#' @slot model The fitted model object
#' @slot modeldata The model data frame (typed "ANY" because some models,
#'   e.g. lmerTest, return non-data.frame objects)
#' @slot modeldata_available Logical flag; the constructor always sets it to TRUE
#' @slot newdata The new data frame for predictions (typed "ANY" to allow
#'   deferred processing of mira objects; the constructor initializes it to
#'   an empty data frame)
#' @slot numderiv List specifying the numerical-derivative method
#'   (constructor default `list("fdforward")`)
#' @slot type The sanitized type from sanitize_type()
#' @slot variables List of variable specifications (constructor default `list()`)
#' @slot variable_class Result of `detect_variable_class()` on the model data, or NULL
#' @slot variable_names_datagrid Character vector of variable names; the
#'   constructor initializes it to `character()`
#' @slot variable_names_predictors Predictor names from
#'   `insight::find_predictors()`, or NULL
#' @slot variable_names_response Response names from `insight::find_response()`;
#'   `character(0)` when none are found
#' @slot variable_names_by Character vector of `by` variable names
#' @slot variable_names_by_hypothesis Character vector of `by` variable names
#'   used for hypotheses
#' @slot variable_names_wts Weight variable names from `insight::find_weights()`, or NULL
#' @slot vcov_model The variance-covariance matrix
#' @slot vcov_type Character label for the variance-covariance type, or NULL
#' @slot wts The weights specification
#' @keywords internal
setClass(
"marginaleffects_internal",
slots = c(
by = "ANY",
byfun = "functionOrNULL",
call = "ANY",
call_model = "ANY",
calling_function = "character",
comparison = "ANY",
conf_level = "numeric",
cross = "logicalOrNULL",
df = "ANY",
draws = "matrixOrNULL",
draws_chains = "numeric",
eps = "numericOrNULL",
hypothesis = "ANY",
hypothesis_null = "ANY",
hypothesis_direction = "ANY",
inferences = "ANY",
jacobian = "matrixOrNULL",
model = "ANY",
modeldata = "ANY", # TODO: lmerTest returns nfnGroupedData
modeldata_available = "logical",
newdata = "ANY", # Changed from "data.frame" to handle mira deferred processing
numderiv = "list",
type = "characterOrNULL",
variables = "list",
variable_class = "characterOrNULL",
variable_names_datagrid = "characterOrNULL",
variable_names_predictors = "characterOrNULL",
variable_names_response = "characterOrNULL",
variable_names_by = "characterOrNULL",
variable_names_by_hypothesis = "characterOrNULL",
variable_names_wts = "characterOrNULL",
vcov_model = "ANY",
vcov_type = "characterOrNULL",
wts = "ANY"
)
)
#' Constructor for marginaleffects_internal class
#'
#' Builds a `marginaleffects_internal` object from a fitted model and the
#' user-level call. Model-derived slots (model data, variable classes,
#' response/predictor/weight names, the model's own call, and the name of the
#' calling function) are filled in automatically; everything else is stored
#' as supplied.
#'
#' @param model The fitted model object (required)
#' @param call The original function call (required)
#' @param by Aggregation/grouping specification (default `FALSE`)
#' @param byfun Function for aggregation when using `by`, or `NULL`
#' @param comparison Comparison function specification
#' @param conf_level Confidence level (default 0.95)
#' @param cross Boolean flag for cross-contrasts
#' @param df The degrees of freedom
#' @param draws Matrix of draws, or `NULL`
#' @param draws_chains Numeric value stored alongside `draws` (default 0)
#' @param eps Epsilon value for numerical derivatives
#' @param hypothesis,hypothesis_null,hypothesis_direction Hypothesis
#'   specification, null value, and direction
#' @param jacobian The jacobian matrix or `NULL`
#' @param modeldata The model data frame. When `NULL`, it is extracted with
#'   `get_modeldata()`, except for `mira`/`amest` objects, which receive an
#'   empty placeholder data frame (their data are handled in
#'   `process_imputation()`).
#' @param numderiv List specifying the numerical-derivative method
#'   (default `list("fdforward")`)
#' @param type The sanitized type from sanitize_type()
#' @param variables Variable specification list (default `list()`)
#' @param variable_names_by,variable_names_by_hypothesis Character vectors of
#'   `by` variable names
#' @param vcov_model The variance-covariance matrix
#' @param vcov_type Character label for the variance-covariance type, or `NULL`
#' @param wts The weights specification
#' @return An object of class marginaleffects_internal
#' @keywords internal
new_marginaleffects_internal <- function(
    model,
    call,
    by = FALSE,
    byfun = NULL,
    comparison = NULL,
    conf_level = 0.95,
    cross = FALSE,
    df = NULL,
    draws = NULL,
    draws_chains = 0,
    eps = NULL,
    hypothesis = NULL,
    hypothesis_null = NULL,
    hypothesis_direction = NULL,
    jacobian = NULL,
    modeldata = NULL,
    numderiv = list("fdforward"),
    type = NULL,
    variables = list(),
    variable_names_by = character(),
    variable_names_by_hypothesis = character(),
    vcov_model = NULL,
    vcov_type = NULL,
    wts = NULL) {
  # Resolve the model data only when the caller did not pass it in.
  # Imputation objects get a placeholder; process_imputation() deals with
  # their data later.
  if (is.null(modeldata)) {
    is_imputation <- inherits(model, c("mira", "amest"))
    modeldata <- if (is_imputation) {
      data.frame()
    } else {
      get_modeldata(model, additional_variables = TRUE)
    }
  }

  # Metadata derived from the model and its data.
  detected_class <- detect_variable_class(modeldata, model = model)
  response_names <- hush(insight::find_response(
    model,
    combine = TRUE,
    component = "all",
    flatten = TRUE
  ))
  if (is.null(response_names)) {
    response_names <- character(0)
  }
  weight_names <- hush(insight::find_weights(model))
  predictor_names <- hush(
    insight::find_predictors(model, flatten = TRUE, verbose = FALSE)
  )

  # Identify the user-facing entry point and the model's own call.
  entry_point <- extract_calling_function(call)
  model_call <- tryCatch(
    insight::get_call(model),
    error = function(e) NULL
  )

  methods::new(
    "marginaleffects_internal",
    by = by,
    byfun = byfun,
    call = call,
    call_model = model_call,
    calling_function = entry_point,
    comparison = comparison,
    conf_level = conf_level,
    cross = cross,
    df = df,
    draws = draws,
    draws_chains = draws_chains,
    eps = eps,
    hypothesis = hypothesis,
    hypothesis_null = hypothesis_null,
    hypothesis_direction = hypothesis_direction,
    jacobian = jacobian,
    model = model,
    modeldata = modeldata,
    modeldata_available = TRUE,
    newdata = data.frame(),
    numderiv = numderiv,
    type = type,
    variables = variables,
    variable_class = detected_class,
    variable_names_by = variable_names_by,
    variable_names_by_hypothesis = variable_names_by_hypothesis,
    variable_names_datagrid = character(),
    variable_names_predictors = predictor_names,
    variable_names_response = response_names,
    variable_names_wts = weight_names,
    vcov_model = vcov_model,
    vcov_type = vcov_type,
    wts = wts
  )
}
#' Extract calling function name from call object
#'
#' Maps the head of a call to one of the main user-facing functions.
#' Namespaced or accessor heads (`pkg::fun`, `pkg:::fun`, `obj$fun`) are
#' reduced to the bare function name, and an `avg_` prefix is stripped so
#' that e.g. `avg_comparisons()` maps to `"comparisons"`.
#'
#' @param call The call object from which to extract the function name
#' @return Character string of the calling function ("comparisons",
#'   "predictions", or "hypotheses"), or "unknown" when `call` is not a call,
#'   when its head is not a name (e.g. a function object inlined by
#'   `do.call()`), or when the name is not one of the main functions.
#' @keywords internal
extract_calling_function <- function(call) {
  if (!is.call(call)) {
    return("unknown")
  }
  head <- call[[1]]
  # For `pkg::fun`, `pkg:::fun` and `obj$fun`, the function name is the last
  # element of the head call.
  if (is.call(head)) {
    head <- head[[length(head)]]
  }
  # Calls built by do.call(fun, args) carry the function object itself as the
  # head; as.character() errors on closures, so bail out instead.
  if (!is.symbol(head) && !is.character(head)) {
    return("unknown")
  }
  func_name <- as.character(head)
  if (length(func_name) != 1L) {
    return("unknown")
  }
  # Map avg_ functions to their base function
  func_name <- sub("^avg_", "", func_name)
  if (func_name %in% c("comparisons", "predictions", "hypotheses")) {
    return(func_name)
  }
  "unknown"
}
|