# partykit.R: hackery for plotmo to support the partykit package
# Prepare a partykit "party" model for use inside plotmo.
#
#   object       the party model
#   object.name  name of the model, used only in the trace message
#   trace        trace level, handed to the helpers and to trace2
#   ...          not used here
#
# Returns the model with its class replaced by "party_plotmo" and with
# three additions: the "plotmo.importance" attribute, a "call" element,
# and an "original.class" element holding the class it had on entry.
plotmo.prolog.party <- function(object, object.name, trace, ...) # called when plotmo starts
{
    check.mob.object(object)

    # plotmo.importance (a character vector) travels with the model as an attribute
    object <- attach.party.plotmo.importance(object, trace)

    # Record the call while the object still has its partykit class: after
    # the class change below getCall.party won't work and we must rely on
    # getCall.default.  The call is needed to get the data used to build
    # the model (object$data can't be used because it may contain
    # "variable names" like "log(lstat)").
    object$call <- getCall(object)

    # partykit defines "[[.party", which redefines "[[" for party objects.
    # The plotmo code needs the standard "[[" (for things like
    # object[["x"]]), so the object gets a different class while in plotmo.
    party.class <- class(object) # remembered for plotmo.predict.party_plotmo
    trace2(trace,
           "changing class of %s from %s to \"party_plotmo\" for standard \"[[\"\n",
           object.name, quote.with.c(party.class))
    class(object) <- "party_plotmo"
    object$original.class <- party.class
    object
}
# Predict method for models whose class was changed to "party_plotmo"
# by plotmo.prolog.party.
#
#   object   model carrying a character "original.class" element
#   newdata  passed unchanged to plotmo.predict
#   type     passed unchanged to plotmo.predict
#   ...      passed unchanged to plotmo.predict
#   TRACE    passed unchanged to plotmo.predict
#
# Restores the model's original class, then returns whatever
# plotmo.predict returns for it.  Warnings are switched off for the
# duration of the call and the caller's "warn" option is put back on
# exit, including when plotmo.predict raises an error.
plotmo.predict.party_plotmo <- function(object, newdata, type, ..., TRACE)
{
    stopifnot(is.character(object$original.class))
    class(object) <- object$original.class
    # suppress warnings:
    # Warning: 'newdata' had 2 rows but variables found have 297 rows
    # Warning in rval[ix[[i]]] <- preds[[i]] : number of items to replace is not a multiple of replacement length
    #
    # options() returns the previous settings, so the old value is captured
    # in the same step that changes it.  (Reading getOption("warn") after
    # the change would save -1 and leave warnings disabled on return.)
    old.options <- options(warn=-1)
    on.exit(options(old.options), add=TRUE)
    predict <- plotmo.predict(object, newdata, type=type, ..., TRACE=TRACE)
    predict
}
# Attach the "plotmo.importance" attribute (a character vector of
# variable names, most important first) to the model and return the model.
#
#   object  the party model
#   trace   if >= 1, the variable names are printed
#
# The order comes from varimp(object) when that call succeeds.  Otherwise
# the names are the column names of the "factors" attribute of
# object$info$terms$response, followed by the names from varimp_party.
attach.party.plotmo.importance <- function(object, trace)
{
    importance <- try(varimp(object), silent=TRUE)
    varnames <-
        if(is.try.err(importance)) # only some party objects support varimp
            c(# the variable(s) before the | in the formula
              colnames(attr(object$info$terms$response, "factors")),
              # then the variables actually used in the tree, in order of importance
              names(varimp_party(object)))
        else
            names(sort(importance, decreasing=TRUE))

    # strip each name down to the bare variable, e.g. log(lstat) becomes lstat
    for(ivar in seq_along(varnames))
        varnames[ivar] <- naken.collapse(varnames[ivar])

    if(trace >= 1)
        cat("variable importance:", varnames, "\n")

    attr(object, "plotmo.importance") <- varnames
    object
}
# Like varimp.constparty but works for all party trees, including mob trees.
#
# Walks the tree from object$node and scores each column of object$data:
#   - every split adds the node's info$nobs (or 1 if that is absent) to
#     the score of its split variable, so splits that affect more
#     observations get more weight
#   - .0001 * depth is subtracted, so splits near the root get slightly
#     more weight (this disambiguates vars that would otherwise tie)
#
# Returns a named numeric vector sorted in decreasing order, with the
# zero-score variables (those not used in the tree) removed.
varimp_party <- function(object)
{
    # Add the contribution of the subtree rooted at node to varimp and
    # return the updated vector.
    accumulate <- function(node, varimp, depth)
    {
        varid <- node$split$varid
        if(!is.null(varid)) {
            check.index(varid, "varid", varimp) # paranoia
            weight <- node$info$nobs
            if(is.null(weight))
                weight <- 1
            varimp[varid] <- varimp[varid] + weight - .0001 * depth
        }
        for(kid in partykit::kids_node(node)) {
            if(!is.null(kid))
                varimp <- accumulate(kid, varimp, depth+1) # recurse
        }
        varimp
    }
    #--- varimp_party starts here
    data.names <- colnames(object$data)
    scores <- repl(0, length(data.names))
    names(scores) <- data.names
    scores <- accumulate(object$node, scores, depth=0)
    in.tree <- scores != 0 # discard vars not in tree
    sort(scores[in.tree], decreasing=TRUE)
}
# Choose which columns of x get a single-variable plot.
#
#   object     model carrying the "plotmo.importance" attribute
#              (a character vector of variable names, most important first)
#   x          matrix or data frame whose column names are matched
#   nresponse  unused here
#   trace      unused here
#   all1       if TRUE, return the indices of all columns of x
#   ...        unused here
#
# Returns column indices into x: all of them if all1 is TRUE, otherwise
# at most 10, in order of importance.  Importance names not found in x
# are dropped with a warning; if none are found, a warning is issued and
# all columns are used (still subject to the limit of 10).
plotmo.singles.party_plotmo <- function(object, x, nresponse, trace, all1, ...)
{
    all <- seq_along(colnames(x))
    if(all1)
        return(all)
    varnames <- attr(object, "plotmo.importance")
    stopifnot(!is.null(varnames))
    i <- match(varnames, colnames(x))
    ina <- which(is.na(i)) # sanity check
    if(length(ina) > 0) {
        warnf(
"could not find \"%s\" in %s\nWorkaround: use all1=TRUE to plot all variables",
            varnames[ina[1]], quote.with.c(colnames(x)))
        i <- i[!is.na(i)]
    }
    if(length(i) == 0) {
        warnf("could not estimate variable importance")
        i <- all # something went wrong, use all vars
    }
    # indices of important variables, max of 10 variables
    # (10 because plotmo.pairs returns 6, total is 16, therefore 4x4 grid)
    # seq_len (not 1:n) so that an empty i stays empty instead of becoming NA
    i[seq_len(min(10, length(i)))]
}
# Choose which variable pairs get a two-variable plot.
#
#   object, x, nresponse, trace, ...  passed on to plotmo.singles
#   all2                              unused here
#
# Takes the leading variables returned by plotmo.singles (called with
# all1=FALSE) and returns form.pairs of them.
plotmo.pairs.party_plotmo <- function(object, x, nresponse, trace, all2, ...)
{
    important <- plotmo.singles(object, x, nresponse, trace, all1=FALSE, ...)
    nimportant <- length(important)

    # Keep the total to no more than 16 plots:
    # 5 variables give 10 pairplots, 4 variables give 6 pairplots.
    npairs <- 4
    if(nimportant <= 6)
        npairs <- 5

    leading <- important[1: min(npairs, nimportant)]
    form.pairs(leading)
}
# Check the mob object formula and issue a work-around message when
# the formula won't work for predictions with new data.
# This prevents err msg: 'newdata' had 1 row but variables found have 167 rows
#
#   object  the model; its call is inspected for a "fit" argument
#
# Returns NULL when the call has no "fit" argument, and invisible NULL
# when the fit function has no offending formula.  Otherwise prints the
# offending line and a suggested replacement fit function, then stops
# with an error.
check.mob.object <- function(object)
{
    call.fit <- getCall(object)$fit # was a fit func passed to the model building func?
    if(is.null(call.fit))
        return(NULL)

    # it's a mob object: get the source text of the fit function, one string per line
    func <- eval(call.fit)
    stopifnot(inherits(func, "function"))
    func <- deparse(func, width.cutoff=500)

    # Is there a "(" followed by "~" followed by a lone "x," in the function body?
    # Or a "(" followed by "~" followed by "x - 1,".
    lone.x  <- "\\(.*\\~.*[^a-zA-Z0-9_\\.]x,"
    x.less1 <- "\\(.*\\~.*x \\- 1,"
    is.bad <- grepl(paste0(lone.x, "|", x.less1), func)
    if(!any(is.bad))
        return(invisible(NULL))

    # What follows prints a message like this (details will vary
    # depending on the fit func) and then raises the error:
    #
    #   The following formula in the mob fit function is not supported by plotmo:
    #
    #   glm(y ~ 0 + x, family = binomial, start = start, ...)
    #
    #   Possible workaround: Replace the fit function with:
    #
    #   function (y, x, start = NULL, weights = NULL, offset = NULL, ...)
    #   {
    #   glm(as.formula(paste("y ~ ", paste(colnames(x)[-1], collapse="+"))),
    #   family = binomial, start = start, ...)
    #   }
    #
    #   Error: The formula in the mob fit function is not supported by plotmo (see above)
    printf("\nThe following formula in the mob fit function is not supported by plotmo:\n\n")
    ibad <- which(is.bad)[1] # only the first offending line is reported
    cat(func[ibad])

    # rewrite that line: everything from the first "(" up to the first ","
    # becomes a formula built from the column names of x
    first.arg <- "\\([^,]+,"
    func[ibad] <- sub(first.arg,
        "(as.formula(paste(\"y ~ \", paste(colnames(x)[-1], collapse=\"+\"))),\n data=x,",
        func[ibad])

    printf("\n\nPossible workaround: Replace the fit function with:\n\n")
    printf(" %s <- ", as.character(call.fit))
    for(line in func)
        printf("%s\n ", line)
    printf("\n")
    stop0("The formula in the mob fit function is not supported by plotmo (see above)\n",
          " This is because predict.mob often fails with newdata and type=\"response\"\n",
          " e.g. example(mob); predict(pid_tree, newdata=PimaIndiansDiabetes[1:3,], type=\"response\")")
}
# cforest objects
# Prepare a "parties" model for plotmo: attach the "plotmo.importance"
# attribute (a character vector of variable names from
# order.parties.vars.on.importance) and return the model.
#
#   object       the model
#   object.name  unused here
#   trace        passed on to order.parties.vars.on.importance
#   ...          unused here
plotmo.prolog.parties <- function(object, object.name, trace, ...) # called when plotmo starts
{
    importance <- order.parties.vars.on.importance(object, trace) # a char vector
    attr(object, "plotmo.importance") <- importance
    object
}
# Return the model's variable names as a character vector, most
# important first.
#
#   object  the model
#   trace   if >= 1, the variable names are printed
#
# Uses varimp(object) sorted in decreasing order when that call succeeds;
# otherwise falls back to the column names of object$data without the
# first column.
order.parties.vars.on.importance <- function(object, trace) # a char vector
{
    importance <- try(varimp(object), silent=TRUE)
    if(is.try.err(importance)) {
        # -1 to drop response TODO is this reliable?
        varnames <- colnames(object$data)[-1]
    } else {
        varnames <- names(sort(importance, decreasing=TRUE))
    }
    if(trace >= 1)
        cat("variable importance:", varnames, "\n")
    varnames
}
# Single-variable plot selection for "parties" models.
# Forwards every argument unchanged to plotmo.singles.party_plotmo and
# returns its result.
plotmo.singles.parties <- function(object, x, nresponse, trace, all1, ...)
{
# arguments are passed by position, in the same order as received
plotmo.singles.party_plotmo(object, x, nresponse, trace, all1, ...)
}
# Variable-pair plot selection for "parties" models.
# Forwards every argument unchanged to plotmo.pairs.party_plotmo and
# returns its result.
plotmo.pairs.parties <- function(object, x, nresponse, trace, all2, ...)
{
# arguments are passed by position, in the same order as received
plotmo.pairs.party_plotmo(object, x, nresponse, trace, all2, ...)
}
|