File: partykit.R

package info (click to toggle)
r-cran-plotmo 3.7.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 3,400 kB
  • sloc: sh: 13; makefile: 2
file content (202 lines) | stat: -rw-r--r-- 8,682 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
# partykit.R: hackery for plotmo to support the partykit package

plotmo.prolog.party <- function(object, object.name, trace, ...) # called when plotmo starts
{
    # Prepare a party object for use inside plotmo.
    #
    # object       the party model
    # object.name  name of the model, used only in the trace message
    # trace        trace level, passed on to the helpers and to trace2
    #
    # Returns the model with class "party_plotmo", with the "plotmo.importance"
    # attribute attached, and with the extra fields "call" and "original.class".

    # Issues an error with a suggested workaround if this is a mob model
    # whose fit function uses an unsupported formula.
    check.mob.object(object)

    # Attach the character vector plotmo.importance to the model.
    object <- attach.party.plotmo.importance(object, trace)

    # Save the call in the object now, while the object still has its original
    # class.  Once the class is changed below, getCall.party no longer applies
    # and getCall.default is what gets used.  The call is needed to get the
    # data used to build the model (object$data can't be used for that because
    # it may contain "variable names" like "log(lstat)").
    object$call <- getCall(object)

    # Remember the class so plotmo.predict.party_plotmo can restore it.
    class.before <- class(object)

    # partykit defines "[[.party", which changes the meaning of "[[" for party
    # objects.  The plotmo code needs the standard "[[" (for things like
    # object[["x"]]), so the object gets a different class while in plotmo.
    trace2(trace,
        "changing class of %s from %s to \"party_plotmo\" for standard \"[[\"\n",
        object.name, quote.with.c(class.before))

    class(object) <- "party_plotmo"
    object$original.class <- class.before
    object
}
plotmo.predict.party_plotmo <- function(object, newdata, type, ..., TRACE)
{
    # Predict method for objects whose class was changed to "party_plotmo"
    # by plotmo.prolog.party.
    #
    # object   model with an "original.class" field (a character vector)
    # newdata, type, ..., TRACE   passed on unchanged to plotmo.predict
    #
    # Returns whatever plotmo.predict returns for the model once its
    # original class has been restored.

    stopifnot(is.character(object$original.class))
    class(object) <- object$original.class

    # suppress warnings:
    #    Warning: 'newdata' had 2 rows but variables found have 297 rows
    #    Warning in rval[ix[[i]]] <- preds[[i]] : number of items to replace is not a multiple of replacement length
    #
    # options() returns the previous settings, so the caller's "warn" level is
    # captured before it is overwritten and is restored when we exit (also on error).
    old.options <- options(warn=-1)
    on.exit(options(old.options), add=TRUE)

    predict <- plotmo.predict(object, newdata, type=type,  ..., TRACE=TRACE)
    predict
}
# Attach plotmo.importance (a character vector of variable names) to the model.
# The names come from varimp() when that works for the object, else from the
# formula and varimp_party().  Each name is passed through naken.collapse.
# Returns the object with the attribute "plotmo.importance" set.
attach.party.plotmo.importance <- function(object, trace)
{
    imp <- try(varimp(object), silent=TRUE)
    if(!is.try.err(imp))
        varnames <- names(sort(imp, decreasing=TRUE))
    else { # only some party objects support varimp
        # the variable(s) before the | in the formula
        formula.vars <- colnames(attr(object$info$terms$response, "factors"))
        # the variables actually used in the tree, in order of importance
        tree.vars <- names(varimp_party(object))
        varnames <- c(formula.vars, tree.vars)
    }
    for(ivar in seq_along(varnames)) {
        # e.g. log(lstat) becomes lstat
        varnames[ivar] <- naken.collapse(varnames[ivar])
    }
    if(trace >= 1)
        cat("variable importance:", varnames, "\n")
    attr(object, "plotmo.importance") <- varnames
    object
}
# Variable importance for party trees.
# Like varimp.constparty but works for all party trees, including mob trees.
# Each split adds the node's number of observations (node$info$nobs, or 1 if
# that is not available) to the importance of the split variable, so splits
# that affect more observations get more weight.
# A small penalty proportional to the depth of the node is subtracted, so
# splits near the root get slightly more weight.
# (This is to disambiguate vars that have equal importance otherwise.)
# Returns a named numeric vector, sorted on decreasing importance, with
# an entry only for variables whose importance is nonzero.
varimp_party <- function(object)
{
    # Walk the tree rooted at node, returning the updated varimp.
    accumulate <- function(node, varimp, depth)
    {
        varid <- node$split$varid # NULL if there is no split at this node
        if(!is.null(varid)) {
            check.index(varid, "varid", varimp) # paranoia
            nobs <- node$info$nobs
            if(is.null(nobs))
                nobs <- 1
            varimp[varid] <- varimp[varid] + nobs - .0001 * depth
        }
        kids <- partykit::kids_node(node)
        for(kid in kids) {
            if(is.null(kid))
                next
            varimp <- accumulate(kid, varimp, depth+1) # recurse
        }
        varimp
    }
    #--- varimp_party starts here
    # one importance entry per column of the model data, initially all zero
    data.names <- colnames(object$data)
    varimp <- repl(0, length(data.names))
    names(varimp) <- data.names
    varimp <- accumulate(object$node, varimp, depth=0)
    in.tree <- varimp[varimp != 0] # discard vars not in tree
    sort(in.tree, decreasing=TRUE)
}
plotmo.singles.party_plotmo <- function(object, x, nresponse, trace, all1, ...)
{
    # Choose the variables for the degree1 plots.
    #
    # object  model with the attribute "plotmo.importance" (a character
    #         vector of variable names, most important first)
    # x       data whose column names are matched against those names
    # all1    if TRUE, select all columns of x
    #
    # Returns column indices into x: all the columns if all1 is TRUE, else
    # the columns named in plotmo.importance in that order, at most 10 of them.
    # nresponse, trace, and ... are not used here.

    all <- seq_along(colnames(x))
    if(all1)
        return(all)
    varnames <- attr(object, "plotmo.importance")
    stopifnot(!is.null(varnames))
    i <- match(varnames, colnames(x))
    ina <- which(is.na(i)) # sanity check
    if(length(ina) > 0) {
        warnf(
"could not find \"%s\" in %s\nWorkaround: use all1=TRUE to plot all variables",
             varnames[ina[1]], quote.with.c(colnames(x)))
        i <- i[!is.na(i)]
    }
    if(length(i) == 0) {
        warnf("could not estimate variable importance")
        # Something went wrong, use all vars.  (Note that
        # seq_along(length(colnames(x))) would be wrong here: it is always 1.)
        i <- all
    }
    # indices of important variables, max of 10 variables
    # (10 because plotmo.pairs returns 6, total is 16, therefore 4x4 grid)
    # seq_len rather than 1:n so an empty i stays empty instead of becoming NA
    i[seq_len(min(10, length(i)))]
}
plotmo.pairs.party_plotmo <- function(object, x, nresponse, trace, all2, ...)
{
    # Choose the variable pairs for the degree2 plots: pairs formed (by
    # form.pairs) from the first few variables selected by plotmo.singles.
    # The all2 argument is not used here.

    isingles <- plotmo.singles(object, x, nresponse, trace, all1=FALSE, ...)
    nsingles <- length(isingles)

    # Choose npairs so a total of no more than 16 plots:
    # npairs=5 gives 10 pairplots, npairs=4 gives 6 pairplots.
    npairs <- 4
    if(nsingles <= 6)
        npairs <- 5

    nkeep <- min(npairs, nsingles)
    form.pairs(isingles[1:nkeep])
}
# Check the formula in the fit function of a mob object.  If the formula
# won't work for predictions with new data, print the offending line of the
# fit function and a suggested replacement fit function, then issue an error.
# This prevents err msg: 'newdata' had 1 row but variables found have 167 rows
# Does nothing if the call that built the model has no "fit" argument.
check.mob.object <- function(object)
{
    # was a fit func passed to the model building func?
    call.fit <- getCall(object)$fit
    if(is.null(call.fit))
        return()

    # it's a mob object
    func <- eval(call.fit)
    stopifnot(inherits(func, "function"))
    src <- deparse(func, width.cutoff=500) # source of the fit func, one string per line

    # Look for a "(" followed by "~" followed by a lone "x," in the function body,
    # or a "(" followed by "~" followed by "x - 1,".
    lone.x      <- "\\(.*\\~.*[^a-zA-Z0-9_\\.]x,"
    x.minus.one <- "\\(.*\\~.*x \\- 1,"
    is.unsupported <- grepl(paste0(lone.x, "|", x.minus.one), src)
    if(!any(is.unsupported))
        return(invisible(NULL)) # the formula looks ok

    # Issue the following message (details will vary depending on the fit func):
    #
    # The following formula in the mob fit function is not supported by plotmo:
    #
    #     glm(y ~ 0 + x, family = binomial, start = start, ...)
    #
    # Possible workaround: Replace the fit function with:
    #
    #     function (y, x, start = NULL, weights = NULL, offset = NULL, ...)
    #     {
    #         glm(as.formula(paste("y ~ ", paste(colnames(x)[-1], collapse="+"))),
    #            family = binomial, start = start, ...)
    #     }
    #
    # Error: The formula in the mob fit function is not supported by plotmo (see above)

    ibad <- which(is.unsupported)[1] # first offending line

    printf("\nThe following formula in the mob fit function is not supported by plotmo:\n\n")
    cat(src[ibad])

    # In the offending line, replace the text from the first "(" up to
    # the next "," with a formula built from the column names of x.
    src[ibad] <-
        sub("\\([^,]+,",
"(as.formula(paste(\"y ~ \", paste(colnames(x)[-1], collapse=\"+\"))),\n            data=x,",
            src[ibad])

    printf("\n\nPossible workaround: Replace the fit function with:\n\n")
    printf("    %s <- ", as.character(call.fit))
    for(line in src)
        printf("%s\n    ", line)
    printf("\n")

    stop0("The formula in the mob fit function is not supported by plotmo (see above)\n",
"       This is because predict.mob often fails with newdata and type=\"response\"\n",
"       e.g. example(mob); predict(pid_tree, newdata=PimaIndiansDiabetes[1:3,], type=\"response\")")
}
# cforest objects
plotmo.prolog.parties <- function(object, object.name, trace, ...) # called when plotmo starts
{
    # Attach plotmo.importance (a char vector of variable names, as returned
    # by order.parties.vars.on.importance) to the model and return the model.
    # The object.name argument is not used here.

    importance <- order.parties.vars.on.importance(object, trace)
    attr(object, "plotmo.importance") <- importance
    object
}
order.parties.vars.on.importance <- function(object, trace) # a char vector
{
    # Return the variable names for the model as a character vector.
    # If varimp() works for the object, the names are sorted on decreasing
    # importance.  Else the column names of object$data are used, without
    # the first column.  The names are printed if trace >= 1.

    imp <- try(varimp(object), silent=TRUE)
    if(is.try.err(imp)) {
        # -1 to drop response TODO is this reliable?
        varnames <- colnames(object$data)[-1]
    } else {
        varnames <- names(sort(imp, decreasing=TRUE))
    }
    if(trace >= 1)
        cat("variable importance:", varnames, "\n")
    varnames
}
plotmo.singles.parties <- function(object, x, nresponse, trace, all1, ...)
{
    # Variables for the degree1 plots are chosen in the same way as for
    # party_plotmo objects, so simply hand over to that method.
    plotmo.singles.party_plotmo(object=object, x=x, nresponse=nresponse,
                                trace=trace, all1=all1, ...)
}
plotmo.pairs.parties <- function(object, x, nresponse, trace, all2, ...)
{
    # Variable pairs for the degree2 plots are chosen in the same way as for
    # party_plotmo objects, so simply hand over to that method.
    plotmo.pairs.party_plotmo(object=object, x=x, nresponse=nresponse,
                              trace=trace, all2=all2, ...)
}