File: FilterWrapper.R

package info (click to toggle)
r-cran-mlr 2.13-1
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 6,760 kB
  • sloc: ansic: 65; sh: 13; makefile: 2
file content (113 lines) | stat: -rw-r--r-- 4,335 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#' @title Fuse learner with a feature filter method.
#'
#' @description
#' Combines a base learner with a feature filter method. The result is a learner
#' object that can be used like any other learner object.
#' Before every model fit, [filterFeatures] is applied internally.
#'
#' Once a model has been trained, [getFilteredFeatures] returns the features
#' that were selected.
#'
#' Observation weights have no effect on the filtering; they are only passed
#' down to the next learner.
#'
#' @template arg_learner
#' @param fw.method (`character(1)`)\cr
#'   Name of the filter method, see [listFilterMethods].
#'   Default is \dQuote{randomForestSRC.rfsrc}.
#' @param fw.perc (`numeric(1)`)\cr
#'   If set, keep the top scoring `fw.perc`*100 percent of the features.
#'   Cannot be combined with `fw.abs` or `fw.threshold`.
#' @param fw.abs (`numeric(1)`)\cr
#'   If set, keep the `fw.abs` top scoring features.
#'   Cannot be combined with `fw.perc` or `fw.threshold`.
#' @param fw.threshold (`numeric(1)`)\cr
#'   If set, keep the features whose score exceeds `fw.threshold`.
#'   Cannot be combined with `fw.perc` or `fw.abs`.
#' @param fw.mandatory.feat ([character])\cr
#'   Mandatory features which are always included regardless of their scores.
#' @param ... (any)\cr
#'   Further named parameters that are passed down to the filter.
#' @template ret_learner
#' @export
#' @family filter
#' @family wrapper
#' @examples
#' task = makeClassifTask(data = iris, target = "Species")
#' lrn = makeLearner("classif.lda")
#' inner = makeResampleDesc("Holdout")
#' outer = makeResampleDesc("CV", iters = 2)
#' lrn = makeFilterWrapper(lrn, fw.perc = 0.5)
#' mod = train(lrn, task)
#' print(getFilteredFeatures(mod))
#' # now nested resampling, where we extract the features that the filter method selected
#' r = resample(lrn, task, outer, extract = function(model) {
#'   getFilteredFeatures(model)
#' })
#' print(r$extract)
makeFilterWrapper = function(learner, fw.method = "randomForestSRC.rfsrc", fw.perc = NULL, fw.abs = NULL, fw.threshold = NULL, fw.mandatory.feat = NULL, ...) {
  learner = checkLearner(learner)

  # The registered filter names serve both as the validation set for
  # `fw.method` and as the value set of the corresponding hyperparameter.
  filter.names = ls(.FilterRegister)
  assertChoice(fw.method, choices = filter.names)
  filter = .FilterRegister[[fw.method]]

  # Extra filter arguments must all be named; they are stored on the learner
  # below rather than in the parameter set.
  ddd = list(...)
  assertList(ddd, names = "named")

  # Hyperparameters exposed by the wrapper itself.
  wrapper.par.set = makeParamSet(
    makeDiscreteLearnerParam(id = "fw.method", values = filter.names),
    makeNumericLearnerParam(id = "fw.perc", lower = 0, upper = 1),
    makeIntegerLearnerParam(id = "fw.abs", lower = 0),
    makeNumericLearnerParam(id = "fw.threshold"),
    makeUntypedLearnerParam(id = "fw.mandatory.feat")
  )

  # Only settings that were actually supplied become parameter values.
  wrapper.par.vals = filterNull(list(
    fw.method = fw.method,
    fw.perc = fw.perc,
    fw.abs = fw.abs,
    fw.threshold = fw.threshold,
    fw.mandatory.feat = fw.mandatory.feat
  ))

  wrapped = makeBaseWrapper(
    id = stri_paste(learner$id, "filtered", sep = "."),
    type = learner$type,
    next.learner = learner,
    package = filter$pkg,
    par.set = wrapper.par.set,
    par.vals = wrapper.par.vals,
    learner.subclass = "FilterWrapper",
    model.subclass = "FilterModel"
  )
  wrapped$more.args = ddd
  wrapped
}

#' @export
trainLearner.FilterWrapper = function(.learner, .task, .subset = NULL, .weights = NULL,
  fw.method = "randomForestSRC.rfsrc", fw.perc = NULL, fw.abs = NULL, fw.threshold = NULL, fw.mandatory.feat = NULL, ...) {

  # Training method of the filter wrapper: subset the task, filter its
  # features, fit the next learner on the result and return a chain model.

  # Restrict the task to the requested observations before filtering.
  sub.task = subsetTask(.task, subset = .subset)

  # Wrapper settings come first, followed by the extra filter arguments kept
  # in `.learner$more.args`. list() retains NULL entries, so settings that
  # were not given are still handed to filterFeatures explicitly.
  filter.args = c(
    list(
      task = sub.task,
      method = fw.method,
      perc = fw.perc,
      abs = fw.abs,
      threshold = fw.threshold,
      mandatory.feat = fw.mandatory.feat
    ),
    .learner$more.args
  )
  filtered.task = do.call(filterFeatures, filter.args)

  # Fit the wrapped learner on the filtered task; weights are passed through.
  fitted = train(.learner$next.learner, filtered.task, weights = .weights)
  makeChainModel(next.model = fitted, cl = "FilterModel")
}


#' @export
predictLearner.FilterWrapper = function(.learner, .model, .newdata, ...) {
  # Prediction method of the filter wrapper: reduce the new data to the
  # features reported by getFilteredFeatures() for this model, then hand over
  # to the next method with the reduced data.
  kept.features = getFilteredFeatures(.model)
  # drop = FALSE keeps a data.frame even when a single feature remains.
  NextMethod(.newdata = .newdata[, kept.features, drop = FALSE])
}

#' Returns the filtered features.
#'
#' S3 generic; the method that runs is selected by the class of `model`.
#'
#' @param model ([WrappedModel])\cr
#'   Trained Model created with [makeFilterWrapper].
#' @return ([character]).
#' @export
#' @family filter
getFilteredFeatures = function(model) UseMethod("getFilteredFeatures")

#' @export
getFilteredFeatures.default = function(model) {
  # Fallback method: descend into the nested model, if there is one, and ask
  # again; a model without a nested `next.model` yields NULL.
  nested = model$learner.model$next.model
  if (!is.null(nested)) {
    return(getFilteredFeatures(nested))
  }
  NULL
}

#' @export
getFilteredFeatures.FilterModel = function(model) {
  # The features of a FilterModel are read from the model nested directly
  # below it (its `next.model`).
  nested = model$learner.model$next.model
  nested$features
}