File: HomogeneousEnsemble.R

package info (click to toggle)
r-cran-mlr 2.19.2%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 8,264 kB
  • sloc: ansic: 65; sh: 13; makefile: 5
file content (75 lines) | stat: -rw-r--r-- 2,685 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# Construct a wrapper learner for a homogeneous ensemble (many copies of one
# base learner). Everything is forwarded to makeBaseWrapper(); the only thing
# added here is the "HomogeneousEnsemble" / "HomogeneousEnsembleModel" marker
# classes, appended after the caller's own subclasses so that the ensemble
# S3 methods in this file are reachable by dispatch.
makeHomogeneousEnsemble = function(id, type, next.learner, package, par.set = makeParamSet(),
  learner.subclass, model.subclass, ...) {
  lrn.classes = c(learner.subclass, "HomogeneousEnsemble")
  mod.classes = c(model.subclass, "HomogeneousEnsembleModel")
  makeBaseWrapper(id, type, next.learner, package, par.set,
    learner.subclass = lrn.classes, model.subclass = mod.classes, ...)
}

##############################        HomogeneousEnsembleModel            ##############################

# An ensemble is reported as failed as soon as at least one of its member
# models is a failure model.
#' @export
isFailureModel.HomogeneousEnsembleModel = function(model) {
  members = getLearnerModel(model, more.unwrap = FALSE)
  member.failed = vlapply(members, isFailureModel)
  any(member.failed)
}

# Return the failure message of the first ensemble member that has one
# (i.e. whose message is not NA), or NA_character_ if no member carries a
# message. An empty ensemble also yields NA_character_.
#' @export
getFailureModelMsg.HomogeneousEnsembleModel = function(model) {
  mods = getLearnerModel(model, more.unwrap = FALSE)
  msgs = vcapply(mods, getFailureModelMsg)
  # Find() yields NULL when no element matches. The previous
  # `ifelse(which.first(...) == 0L, ...)` broke here, because which.first()
  # returns integer(0) (not 0) when nothing is found. That made the whole
  # result logical(0) instead of NA.
  msg = Find(Negate(is.na), msgs)
  if (is.null(msg)) NA_character_ else msg
}

# Return the error dump of the first ensemble member that has one (a non-NULL
# getFailureModelDump() result), or NULL if no member has a dump. An empty
# ensemble also yields NULL.
#' @export
getFailureModelDump.HomogeneousEnsembleModel = function(model) {
  mods = getLearnerModel(model, more.unwrap = FALSE)
  dumps = lapply(mods, getFailureModelDump)
  # Test each member's dump individually. The previous code called is.null()
  # on the whole list, which is always FALSE, so it always picked member 1.
  # It then passed that value through ifelse(). That errored when the value
  # was NULL and stripped a real dump object down to its first element.
  Find(Negate(is.null), dumps)
}

#' Deprecated, use `getLearnerModel` instead.
#' @param model Deprecated.
#' @param learner.models Deprecated.
#' @export
getHomogeneousEnsembleModels = function(model, learner.models = FALSE) {
  # Emit the standard deprecation warning that names the replacement function.
  .Deprecated("getLearnerModel")
  # The old `learner.models` flag maps directly onto `more.unwrap`.
  members = getLearnerModel(model, more.unwrap = learner.models)
  members
}

# Return the list of member models stored in `model$learner.model$next.model`.
# With `more.unwrap = TRUE`, each member's own `learner.model` element is
# extracted instead (non-simplified, so the result stays a list).
#' @export
getLearnerModel.HomogeneousEnsembleModel = function(model, more.unwrap = FALSE) {
  members = model$learner.model$next.model
  if (more.unwrap) {
    members = extractSubList(members, "learner.model", simplify = FALSE)
  }
  members
}

##############################               helpers                      ##############################

# Internal helper: collect the response predictions of every ensemble member
# into one matrix, with one row per newdata point and one column per member.
# Factor responses (classif) are turned into character so the result is a
# plain character matrix; numeric responses (regr) stay numeric.
# Only `$data$response` is used, so probabilities, standard errors, etc. are
# not supported. `.learner` is unused; it is kept for the shared signature.
predictHomogeneousEnsemble = function(.learner, .model, .newdata, .subset = NULL, ...) {
  members = getLearnerModel(.model, more.unwrap = FALSE)
  member.responses = lapply(members, function(member) {
    resp = predict(member, newdata = .newdata, subset = .subset, ...)$data$response
    if (is.factor(resp)) as.character(resp) else resp
  })
  do.call(cbind, member.responses)
}

# call this at end of trainLearner.CostSensRegrWrapper
# FIXME: potentially remove this when ChainModel is removed
# Wrap the fitted member `models` in a ChainModel. Its class vector is the
# learner's model.subclass followed by "HomogeneousEnsembleModel".
makeHomChainModel = function(learner, models) {
  model.classes = c(learner$model.subclass, "HomogeneousEnsembleModel")
  makeChainModel(next.model = models, cl = model.classes)
}