## File: lmtree.R
## Package: r-cran-partykit 1.2-23-1 (area: main; in suites: forky, sid, trixie)
## Package size: 3,832 kB; sloc: ansic: 91; sh: 75; makefile: 38
## File content: 134 lines, 4,176 bytes (stat: -rw-r--r--)

## simple wrapper function to specify fitter and return class
lmtree <- function(formula, data, subset, na.action, weights, offset, cluster, ...)
{
  ## TODO: variance as model parameter

  ## use dots for setting up mob_control
  control <- mob_control(...)
  if(control$vcov != "opg") {
    warning('only vcov = "opg" supported in lmtree')
    control$vcov <- "opg"
  }
  if(!is.null(control$prune)) {
    if(is.character(control$prune)) {
      control$prune <- tolower(control$prune)
      control$prune <- match.arg(control$prune, c("aic", "bic", "none"))
      control$prune <- switch(control$prune,
        "aic" = {
	  function(objfun, df, nobs) (nobs[1L] * log(objfun[1L]) + 2 * df[1L]) < (nobs[1L] * log(objfun[2L]) + 2 * df[2L])
	}, "bic" = {
	  function(objfun, df, nobs) (nobs[1L] * log(objfun[1L]) + log(nobs[2L]) * df[1L]) < (nobs[1L] * log(objfun[2L]) + log(nobs[2L]) * df[2L])
	}, "none" = {
	  NULL
	})      
    }
    if(!is.function(control$prune)) {
      warning("unknown specification of 'prune'")
      control$prune <- NULL
    }
  }

  ## keep call
  cl <- match.call(expand.dots = TRUE)

  ## extend formula if necessary
  f <- Formula::Formula(formula)
  if(length(f)[2L] == 1L) {
    attr(f, "rhs") <- c(list(1), attr(f, "rhs"))
    formula[[3L]] <- formula(f)[[3L]]
  } else {
    f <- NULL
  }

  ## call mob
  m <- match.call(expand.dots = FALSE)
  if(!is.null(f)) m$formula <- formula
  m$fit <- lmfit
  m$control <- control
  if("..." %in% names(m)) m[["..."]] <- NULL
  m[[1L]] <- as.call(quote(partykit::mob))
  rval <- eval(m, parent.frame())
  
  ## extend class and keep original call
  rval$info$call <- cl
  class(rval) <- c("lmtree", class(rval))
  return(rval)
}

## actual fitting function for mob()
lmfit <- function(y, x, start = NULL, weights = NULL, offset = NULL, cluster = NULL, ...,
  estfun = FALSE, object = FALSE)
{
  ## add intercept-only regressor matrix (if missing)
  ## NOTE: does not have terms/formula
  if(is.null(x)) x <- matrix(1, nrow = NROW(y), ncol = 1L,
    dimnames = list(NULL, "(Intercept)"))
  
  ## call lm fitting function
  if(is.null(weights) || identical(as.numeric(weights), rep.int(1, length(weights)))) {
    z <- lm.fit(x, y, offset = offset, ...)
    weights <- 1
  } else {
    z <- lm.wfit(x, y, w = weights, offset = offset, ...)
  }

  ## list structure
  rval <- list(
    coefficients = z$coefficients,
    objfun = sum(weights * z$residuals^2),
    estfun = NULL,
    object = NULL
  )

  ## add estimating functions (if desired)
  if(estfun) {
    rval$estfun <- as.vector(z$residuals) * weights * x[, !is.na(z$coefficients), drop = FALSE]
  }

  ## add model (if desired)
  if(object) {
    class(z) <- c(if(is.matrix(z$fitted)) "mlm", "lm")
    z$offset <- if(is.null(offset)) 0 else offset
    z$contrasts <- attr(x, "contrasts")
    z$xlevels <- attr(x, "xlevels")    

    cl <- as.call(expression(lm))
    cl$formula <- attr(x, "formula")
    if(!is.null(offset)) cl$offset <- attr(x, "offset")
    z$call <- cl
    z$terms <- attr(x, "terms")

    rval$object <- z
  }

  return(rval)
}

## methods
print.lmtree <- function(x,
  title = "Linear model tree", objfun = "residual sum of squares", ...)
{
  print.modelparty(x, title = title, objfun = objfun, ...)
}

predict.lmtree <- function(object, newdata = NULL, type = "response", ...)
{
  ## FIXME: possible to get default?
  if(is.null(newdata) & !identical(type, "node")) stop("newdata has to be provided")
  predict.modelparty(object, newdata = newdata, type = type, ...)
}

plot.lmtree <- function(x, terminal_panel = node_bivplot,
  tp_args = list(), tnex = NULL, drop_terminal = NULL, ...)
{
  nreg <- if(is.null(tp_args$which)) x$info$nreg else length(tp_args$which)
  if(nreg < 1L & missing(terminal_panel)) {
    plot.constparty(as.constparty(x),
      tp_args = tp_args, tnex = tnex, drop_terminal = drop_terminal, ...)
  } else {
    if(is.null(tnex)) tnex <- if(is.null(terminal_panel)) 1L else 2L * nreg
    if(is.null(drop_terminal)) drop_terminal <- !is.null(terminal_panel)
    plot.modelparty(x, terminal_panel = terminal_panel,
      tp_args = tp_args, tnex = tnex, drop_terminal = drop_terminal, ...)
  }
}