File: class.R

package info (click to toggle)
r-cran-marginaleffects 0.32.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 2,784 kB
  • sloc: sh: 13; makefile: 8
file content (206 lines) | stat: -rw-r--r-- 6,930 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# ---- Class unions -----------------------------------------------------------
# Each union lets an S4 slot accept either a value of the named base class or
# NULL. They are referenced by the slot definitions of the
# marginaleffects_internal class.
setClassUnion(name = "matrixOrNULL", members = c("matrix", "NULL"))
setClassUnion(name = "numericOrNULL", members = c("numeric", "NULL"))
setClassUnion(name = "characterOrNULL", members = c("character", "NULL"))
setClassUnion(name = "logicalOrNULL", members = c("logical", "NULL"))
setClassUnion(name = "functionOrNULL", members = c("function", "NULL"))

#' Internal S4 class for marginaleffects
#'
#' This S4 class is used internally to hold common arguments passed between
#' functions to simplify the function signatures and reduce redundant argument
#' passing. Instances are normally built with new_marginaleffects_internal().
#'
#' @slot by Aggregation/grouping specification.
#' @slot byfun Function for aggregation when using by, or NULL.
#' @slot call The original function call.
#' @slot call_model The call used to fit the model, or NULL when it cannot be
#'   retrieved.
#' @slot calling_function The name of the calling function ("comparisons",
#'   "predictions", "hypotheses", or "unknown").
#' @slot comparison Comparison function specification.
#' @slot conf_level Confidence level.
#' @slot cross Logical flag for cross-contrasts, or NULL.
#' @slot df The degrees of freedom.
#' @slot draws Matrix of draws, or NULL.
#' @slot draws_chains Numeric value associated with the draws (0 by default in
#'   the constructor).
#' @slot eps Epsilon value for numerical derivatives, or NULL.
#' @slot hypothesis Hypothesis specification.
#' @slot hypothesis_null Null-value specification for the hypothesis.
#' @slot hypothesis_direction Direction specification for the hypothesis.
#' @slot inferences Inference specification (not set by the constructor).
#' @slot jacobian The jacobian matrix, or NULL.
#' @slot model The fitted model object.
#' @slot modeldata The model data. Typed ANY because some models (e.g.
#'   lmerTest) return non-data.frame classes such as nfnGroupedData.
#' @slot modeldata_available Logical flag indicating whether modeldata is
#'   available.
#' @slot newdata The new data for predictions. Typed ANY so that deferred
#'   processing of mira objects can be supported.
#' @slot numderiv List specifying the numerical differentiation method.
#' @slot type The sanitized type from sanitize_type(), or NULL.
#' @slot variables List of variable specifications.
#' @slot variable_class Variable classes detected from the model data, or NULL.
#' @slot variable_names_datagrid Names of data grid variables, or NULL.
#' @slot variable_names_predictors Names of the model predictors, or NULL.
#' @slot variable_names_response Names of the model response(s), or NULL.
#' @slot variable_names_by Names of the by variables, or NULL.
#' @slot variable_names_by_hypothesis Names of the by variables used for
#'   hypotheses, or NULL.
#' @slot variable_names_wts Names of the weight variables, or NULL.
#' @slot vcov_model The variance-covariance matrix.
#' @slot vcov_type Variance-covariance type label, or NULL.
#' @slot wts The weights specification.
#' @keywords internal
setClass(
    "marginaleffects_internal",
    slots = c(
        by = "ANY",
        byfun = "functionOrNULL",
        call = "ANY",
        call_model = "ANY",
        calling_function = "character",
        comparison = "ANY",
        conf_level = "numeric",
        cross = "logicalOrNULL",
        df = "ANY",
        draws = "matrixOrNULL",
        draws_chains = "numeric",
        eps = "numericOrNULL",
        hypothesis = "ANY",
        hypothesis_null = "ANY",
        hypothesis_direction = "ANY",
        inferences = "ANY",
        jacobian = "matrixOrNULL",
        model = "ANY",
        modeldata = "ANY", # TODO: lmerTest returns nfnGroupedData
        modeldata_available = "logical",
        newdata = "ANY", # Changed from "data.frame" to handle mira deferred processing
        numderiv = "list",
        type = "characterOrNULL",
        variables = "list",
        variable_class = "characterOrNULL",
        variable_names_datagrid = "characterOrNULL",
        variable_names_predictors = "characterOrNULL",
        variable_names_response = "characterOrNULL",
        variable_names_by = "characterOrNULL",
        variable_names_by_hypothesis = "characterOrNULL",
        variable_names_wts = "characterOrNULL",
        vcov_model = "ANY",
        vcov_type = "characterOrNULL",
        wts = "ANY"
    )
)

#' Constructor for marginaleffects_internal class
#'
#' Builds a marginaleffects_internal object from a fitted model and the
#' user's call. Model-derived metadata (model data, variable classes,
#' response/predictor/weight names, the model's own call, and the name of
#' the calling function) is computed here; everything else is stored as given.
#'
#' @param model The fitted model object (required).
#' @param call The original function call (required). Used to derive the
#'   calling_function slot.
#' @param by Aggregation/grouping specification.
#' @param byfun Function for aggregation when using by, or NULL.
#' @param comparison Comparison function specification.
#' @param conf_level Confidence level.
#' @param cross Logical flag for cross-contrasts.
#' @param df The degrees of freedom.
#' @param draws Matrix of draws, or NULL.
#' @param draws_chains Numeric value stored in the draws_chains slot.
#' @param eps Epsilon value for numerical derivatives, or NULL.
#' @param hypothesis Hypothesis specification.
#' @param hypothesis_null Null-value specification for the hypothesis.
#' @param hypothesis_direction Direction specification for the hypothesis.
#' @param jacobian The jacobian matrix, or NULL.
#' @param modeldata The model data. When NULL it is extracted with
#'   get_modeldata(), except for mira and amest models, which get an empty
#'   data frame placeholder (their data is handled in process_imputation()).
#' @param numderiv List specifying the numerical differentiation method.
#' @param type The sanitized type from sanitize_type().
#' @param variables List of variable specifications.
#' @param variable_names_by Character vector of by variable names.
#' @param variable_names_by_hypothesis Character vector of by variable names
#'   used for hypotheses.
#' @param vcov_model The variance-covariance matrix.
#' @param vcov_type Variance-covariance type label, or NULL.
#' @param wts The weights specification.
#' @return An object of class marginaleffects_internal. Its newdata slot is an
#'   empty data frame, variable_names_datagrid is character(0), and
#'   modeldata_available is TRUE.
#' @keywords internal
new_marginaleffects_internal <- function(
    model,
    call,
    by = FALSE,
    byfun = NULL,
    comparison = NULL,
    conf_level = 0.95,
    cross = FALSE,
    df = NULL,
    draws = NULL,
    draws_chains = 0,
    eps = NULL,
    hypothesis = NULL,
    hypothesis_null = NULL,
    hypothesis_direction = NULL,
    jacobian = NULL,
    modeldata = NULL,
    numderiv = list("fdforward"),
    type = NULL,
    variables = list(),
    variable_names_by = character(),
    variable_names_by_hypothesis = character(),
    vcov_model = NULL,
    vcov_type = NULL,
    wts = NULL) {
    # For mice objects, modeldata handling is deferred to process_imputation()
    if (is.null(modeldata)) {
        if (inherits(model, c("mira", "amest"))) {
            modeldata <- data.frame() # placeholder - actual processing happens in process_imputation()
        } else {
            modeldata <- get_modeldata(model, additional_variables = TRUE)
        }
    }

    variable_class <- detect_variable_class(modeldata, model = model)

    # NOTE(review): hush() presumably silences warnings/errors from insight
    # and yields NULL on failure -- confirm against its definition.
    variable_names_response <- hush(insight::find_response(model,
        combine = TRUE, component = "all", flatten = TRUE))

    # Normalize a missing response to an empty character vector.
    if (is.null(variable_names_response)) {
        variable_names_response <- character(0)
    }

    variable_names_wts <- hush(insight::find_weights(model))

    variable_names_predictors <- hush(insight::find_predictors(model, flatten = TRUE, verbose = FALSE))

    # Extract calling function from call
    calling_function <- extract_calling_function(call)

    # Not every model class supports get_call(); store NULL when it fails.
    call_model <- tryCatch(insight::get_call(model), error = function(e) NULL)

    methods::new(
        "marginaleffects_internal",
        by = by,
        byfun = byfun,
        call = call,
        call_model = call_model,
        calling_function = calling_function,
        comparison = comparison,
        conf_level = conf_level,
        cross = cross,
        df = df,
        draws = draws,
        draws_chains = draws_chains,
        eps = eps,
        hypothesis = hypothesis,
        hypothesis_null = hypothesis_null,
        hypothesis_direction = hypothesis_direction,
        jacobian = jacobian,
        model = model,
        modeldata = modeldata,
        modeldata_available = TRUE,
        newdata = data.frame(),
        numderiv = numderiv,
        type = type,
        variables = variables,
        variable_class = variable_class,
        variable_names_by = variable_names_by,
        variable_names_by_hypothesis = variable_names_by_hypothesis,
        variable_names_datagrid = character(),
        variable_names_predictors = variable_names_predictors,
        variable_names_response = variable_names_response,
        variable_names_wts = variable_names_wts,
        vcov_model = vcov_model,
        vcov_type = vcov_type,
        wts = wts
    )
}

#' Extract calling function name from call object
#'
#' Maps a captured call to the user-facing marginaleffects function that
#' produced it. Namespaced calls (e.g. marginaleffects::avg_predictions(...))
#' are reduced to the bare function name, and the avg_ prefix is dropped.
#'
#' @param call The call object from which to extract the function name.
#' @return Character string of the calling function ("comparisons",
#'   "predictions", or "hypotheses"), or "unknown" when call is not a call,
#'   its head cannot be converted to a name, or the name is not recognized.
#' @keywords internal
extract_calling_function <- function(call) {
    if (!is.call(call)) {
        return("unknown")
    }
    # A call built by do.call(fun_object, ...) has the closure itself as its
    # head, and as.character() cannot coerce closures/primitives. Treat any
    # such head as unknown instead of erroring.
    func_name <- tryCatch(
        as.character(call[[1]]),
        error = function(e) character(0)
    )
    if (length(func_name) == 0) {
        return("unknown")
    }
    # Handle namespaced calls like marginaleffects::predictions, which
    # coerce to c("::", "marginaleffects", "predictions"): keep the last part.
    func_name <- func_name[length(func_name)]
    # Map avg_ functions to their base function; sub() is a no-op without the
    # prefix and, unlike startsWith() inside if(), is safe for NA.
    func_name <- sub("^avg_", "", func_name)
    if (func_name %in% c("comparisons", "predictions", "hypotheses")) {
        return(func_name)
    }
    # Default fallback
    "unknown"
}