File: RLearner_regr_mob.R

package info (click to toggle)
r-cran-mlr 2.19.2%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 8,264 kB
  • sloc: ansic: 65; sh: 13; makefile: 5
file content (68 lines) | stat: -rw-r--r-- 2,532 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Constructor for the "regr.mob" regression learner.
# Declares the hyperparameters the learner accepts, the packages it needs
# ("party" and "modeltools"), and the properties it advertises (numeric and
# factor features, observation weights).
#' @export
makeRLearner.regr.mob = function() {
  # Hyperparameter space. The order of the entries is kept as-is because it
  # determines the order of the resulting parameter set.
  mob.params = makeParamSet(
    makeNumericLearnerParam(id = "alpha", default = 0.05, lower = 0, upper = 1),
    makeLogicalLearnerParam(id = "bonferroni", default = TRUE),
    makeIntegerLearnerParam(id = "minsplit", default = 20L, lower = 1L),
    makeNumericLearnerParam(id = "trim", default = 0.1, lower = 0, upper = 1),
    makeLogicalLearnerParam(id = "breakties", default = FALSE),
    makeLogicalLearnerParam(id = "verbose", default = FALSE, tunable = FALSE),
    makeDiscreteLearnerParam(id = "model", default = "glinearModel", values = list("glinearModel", "linearModel")),
    # Untyped: feature-name selections for the partitioning and the node models.
    makeUntypedLearnerParam(id = "part.feats"),
    makeUntypedLearnerParam(id = "term.feats")
  )

  makeRLearnerRegr(
    cl = "regr.mob",
    package = c("party", "modeltools"),
    par.set = mob.params,
    par.vals = list(),
    properties = c("numerics", "factors", "weights"),
    name = "Model-based Recursive Partitioning  Yielding a Tree with Fitted Models Associated with each Terminal Node",
    short.name = "mob",
    callees = c("mob", "mob_control")
  )
}

# Training method for the "regr.mob" learner: fits party::mob on the task.
#
# Arguments:
#   .learner, .task, .subset  learner object, task, and the subset passed on to
#                             getTaskData() to obtain the training data.
#   .weights                  optional observation weights; only forwarded to
#                             party::mob when not NULL.
#   alpha, bonferroni, minsplit, trim, breakties, verbose
#                             collected into a party::mob_control object.
#   part.feats                feature names used on the partitioning side of the
#                             formula (right of "|"); defaults to all features.
#   term.feats                feature names used in the node model (left of "|");
#                             defaults to all features.
#   model                     either the name of a model object in the
#                             "modeltools" namespace (e.g. "glinearModel",
#                             "linearModel") or the model object itself.
#   ...                       forwarded unchanged to party::mob.
#
# Returns the object produced by party::mob. Stops with an error if any
# estimated coefficient is NA.
#' @export
trainLearner.regr.mob = function(.learner, .task, .subset, .weights = NULL, alpha, bonferroni, minsplit,
  trim, breakties, verbose, part.feats, term.feats, model, ...) {

  cntrl = learnerArgsToControl(party::mob_control, alpha, bonferroni, minsplit, trim, breakties, verbose)

  feats = getTaskFeatureNames(.task)
  # FIXME: document stuff
  # FIXME: think about these defaults, also ask julia
  # Without an explicit choice, every feature is used both for partitioning
  # and inside the node models.
  if (missing(part.feats)) {
    part.feats = feats
  }
  if (missing(term.feats)) {
    term.feats = feats
  }

  # Formula of the form: target ~ <term.feats> | <part.feats>
  target = getTaskTargetNames(.task)
  f = as.formula(stri_paste(target, "~", collapse(term.feats, sep = " + "), "|", collapse(part.feats, sep = " + "), sep = " "))

  args = list(f, data = getTaskData(.task, .subset), control = cntrl, ...)
  if (!is.null(.weights)) {
    args$weights = .weights
  }
  if (!missing(model)) {
    if (is.character(model)) {
      # The selectable models ("glinearModel", "linearModel") are objects of the
      # modeltools package, which this learner declares as a dependency, so the
      # name has to be resolved in that namespace.
      args$model = getFromNamespace(model, "modeltools")
    } else {
      args$model = model # this allows to set the model if on.par.out.of.bounds is set to "warn" or "quiet"
    }
  }

  # Use a separate name for the fit so the `model` argument is not shadowed.
  fit = do.call(party::mob, args)

  # sometimes mob fails to fit a model but does not signal an exception.
  if (anyMissing(coef(fit))) {
    stop("Failed to fit party::mob. Some coefficients are estimated as NA")
  }
  fit
}

# Prediction method for the "regr.mob" learner: hands the stored fitted model
# and the new data to predict(), forwarding any extra arguments unchanged.
#' @export
predictLearner.regr.mob = function(.learner, .model, .newdata, ...) {
  fitted.mob = .model$learner.model
  predict(fitted.mob, newdata = .newdata, ...)
}