File: RLearner.R

package info (click to toggle)
r-cran-mlr 2.19.2%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 8,264 kB
  • sloc: ansic: 65; sh: 13; makefile: 5
file content (167 lines) | stat: -rw-r--r-- 6,474 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
#' @title Internal construction / wrapping of learner object.
#'
#' @description
#' Wraps an already implemented learning method from R to make it accessible to mlr.
#' Call this method in your constructor. You have to pass an id (name), the required
#' package(s), a description object for all changeable parameters (you do not have to do this for the
#' learner to work, but it is strongly recommended), and use property tags to define
#' features of the learner.
#'
#' For a general overview on how to integrate a learning algorithm into mlr's system, please read the
#' section in the online tutorial:
#' <https://mlr.mlr-org.com/articles/tutorial/create_learner.html>
#'
#' To see all possible properties of a learner, go to: [LearnerProperties].
#'
#' @template arg_lrncl
#' @param package ([character])\cr
#'   Package(s) to load for the implementation of the learner.
#' @param properties ([character])\cr
#'   Set of learner properties. See above.
#'   Default is `character(0)`.
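#' @param class.weights.param (`character(1)`)\cr
#'   Name of the parameter that can be used for providing class weights.
#'   Must be given, and must name a parameter contained in `par.set`, when
#'   `properties` contains `"class.weights"`; it is not used otherwise.
#'   Default is `NULL`.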
#' @param par.set ([ParamHelpers::ParamSet])\cr
#'   Parameter set of (hyper)parameters and their constraints.
#'   Dependent parameters with a `requires` field must use `quote` and not
#'   `expression` to define it.
#' @param par.vals ([list])\cr
#'   Always set hyperparameters to these values when the object is constructed.
#'   Useful when default values are missing in the underlying function.
#'   The values can later be overwritten when the user sets hyperparameters.
#'   Default is empty list.
#' @param name (`character(1)`)\cr
#'   Meaningful name for learner.
#'   Default is `id`.
#' @param short.name (`character(1)`)\cr
#'   Short name for learner.
#'   Should only be a few characters so it can be used in plots and tables.
#'   Default is `id`.
#' @param note (`character(1)`)\cr
#'   Additional notes regarding the learner and its integration in mlr.
#'   Default is \dQuote{}.
#' @param callees ([character])\cr
#'   Character vector naming all functions of the learner's package being called which
#'   have a relevant R help page.
#'   Default is `character(0)`.
#' @return ([RLearner]). The specific subclass is one of [RLearnerClassif],
#'   [RLearnerCluster], [RLearnerMultilabel],
#'   [RLearnerRegr], [RLearnerSurv], or `RLearnerCostSens`.
#' @name RLearner
#' @rdname RLearner
#' @aliases RLearnerClassif RLearnerCluster RLearnerMultilabel RLearnerRegr RLearnerSurv
NULL

#' @export
#' @rdname RLearner
# S3 generic named "makeRLearner". It declares no formal arguments; per S3
# convention its methods are functions named `makeRLearner.<class>`.
makeRLearner = function() UseMethod("makeRLearner")

# Shared constructor behind the makeRLearner<Type>() wrappers in this file.
#
# Validates its arguments, requires the learner's package(s), and builds an
# "RLearner" object through makeLearnerBaseConstructor().
#
# Arguments:
#   id          single string; also the default for `name` and `short.name`.
#   type        one of "classif", "regr", "multilabel", "surv", "cluster",
#               "costsens".
#   package     character vector (no NAs) of packages the learner requires.
#   par.set     object of class "ParamSet"; its `pars` elements are checked
#               against class "LearnerParam".
#   par.vals    properly named list of hyperparameter values.
#   properties  must be a subset of listLearnerProperties(type); duplicates
#               are removed before being stored.
#   name, short.name, note
#               single strings stored on the learner.
#   callees     character vector (no NAs); stored on the learner and passed
#               to makeParamHelpList() to build `help.list`.
#
# Returns the constructed learner; its predict.type is set to "response".
makeRLearnerInternal = function(id, type, package, par.set, par.vals, properties,
  name = id, short.name = id, note = "", callees) {

  # `id` is interpolated into the `why` message below, so validate it before
  # it is used. This check does not touch par.set.
  assertString(id)

  # The packages must be required before par.set is accessed for the first
  # time. Arguments are evaluated lazily, so par.set is not evaluated until
  # the checks further down -- one case where lazy eval is actually helpful.
  assertCharacter(package, any.missing = FALSE)
  requirePackages(package, why = stri_paste("learner", id, sep = " "), default.method = "load")

  assertChoice(type, choices = c("classif", "regr", "multilabel", "surv", "cluster", "costsens"))
  assertSubset(properties, listLearnerProperties(type))
  assertClass(par.set, classes = "ParamSet")
  checkListElementClass(par.set$pars, "LearnerParam")
  assertList(par.vals)
  if (!isProperlyNamed(par.vals)) {
    stop("Argument par.vals must be a properly named list!")
  }
  assertString(name)
  assertString(short.name)
  assertString(note)
  assertCharacter(callees, any.missing = FALSE)

  learner = makeLearnerBaseConstructor("RLearner",
    id = id,
    type = type,
    package = package,
    properties = unique(properties),
    par.set = par.set,
    par.vals = par.vals,
    predict.type = "response"
  )

  # Descriptive metadata that is not part of the base constructor call.
  learner$name = name
  learner$short.name = short.name
  learner$note = note
  learner$callees = callees
  learner$help.list = makeParamHelpList(callees, package, par.set)
  return(learner)
}

#' @export
#' @rdname RLearner
makeRLearnerClassif = function(cl, package, par.set, par.vals = list(),
  properties = character(0L), name = cl, short.name = cl, note = "",
  class.weights.param = NULL, callees = character(0L)) {
  # Constructor for classification learners: builds the generic learner of
  # type "classif" and adds the classes `cl` and "RLearnerClassif".
  base.learner = makeRLearnerInternal(
    id = cl, type = "classif", package = package, par.set = par.set,
    par.vals = par.vals, properties = properties, name = name,
    short.name = short.name, note = note, callees = callees
  )
  learner = addClasses(base.learner, c(cl, "RLearnerClassif"))

  # Without the "class.weights" property, class.weights.param is not looked at.
  if (!("class.weights" %in% getLearnerProperties(learner))) {
    return(learner)
  }

  # With the property set, the weight parameter must be a single string that
  # names an entry of the parameter set.
  assertString(class.weights.param)
  if (is.null(par.set$pars[[class.weights.param]])) {
    stopf("'%s' needs to be defined in the parameter set as well.", class.weights.param)
  } else {
    learner$class.weights.param = class.weights.param
  }
  return(learner)
}

#' @export
#' @rdname RLearner
makeRLearnerMultilabel = function(cl, package, par.set, par.vals = list(),
  properties = character(0L), name = cl, short.name = cl, note = "",
  callees = character(0L)) {
  # Constructor for multilabel learners: builds the generic learner of type
  # "multilabel" and adds the classes `cl` and "RLearnerMultilabel".
  base.learner = makeRLearnerInternal(
    id = cl, type = "multilabel", package = package, par.set = par.set,
    par.vals = par.vals, properties = properties, name = name,
    short.name = short.name, note = note, callees = callees
  )
  addClasses(base.learner, c(cl, "RLearnerMultilabel"))
}

#' @export
#' @rdname RLearner
makeRLearnerRegr = function(cl, package, par.set, par.vals = list(),
  properties = character(0L), name = cl, short.name = cl, note = "",
  callees = character(0L)) {
  # Constructor for regression learners: builds the generic learner of type
  # "regr" and adds the classes `cl` and "RLearnerRegr".
  base.learner = makeRLearnerInternal(
    id = cl, type = "regr", package = package, par.set = par.set,
    par.vals = par.vals, properties = properties, name = name,
    short.name = short.name, note = note, callees = callees
  )
  addClasses(base.learner, c(cl, "RLearnerRegr"))
}

#' @export
#' @rdname RLearner
makeRLearnerSurv = function(cl, package, par.set, par.vals = list(),
  properties = character(0L), name = cl, short.name = cl, note = "",
  callees = character(0L)) {
  # Constructor for survival learners: builds the generic learner of type
  # "surv" and adds the classes `cl` and "RLearnerSurv".
  base.learner = makeRLearnerInternal(
    id = cl, type = "surv", package = package, par.set = par.set,
    par.vals = par.vals, properties = properties, name = name,
    short.name = short.name, note = note, callees = callees
  )
  addClasses(base.learner, c(cl, "RLearnerSurv"))
}

#' @export
#' @rdname RLearner
makeRLearnerCluster = function(cl, package, par.set, par.vals = list(),
  properties = character(0L), name = cl, short.name = cl, note = "",
  callees = character(0L)) {
  # Constructor for cluster learners: builds the generic learner of type
  # "cluster" and adds the classes `cl` and "RLearnerCluster".
  base.learner = makeRLearnerInternal(
    id = cl, type = "cluster", package = package, par.set = par.set,
    par.vals = par.vals, properties = properties, name = name,
    short.name = short.name, note = note, callees = callees
  )
  addClasses(base.learner, c(cl, "RLearnerCluster"))
}

#' @export
#' @rdname RLearner
makeRLearnerCostSens = function(cl, package, par.set, par.vals = list(),
  properties = character(0L), name = cl, short.name = cl, note = "",
  callees = character(0L)) {
  # Constructor for cost-sensitive learners: builds the generic learner of
  # type "costsens" and adds the classes `cl` and "RLearnerCostSens".
  base.learner = makeRLearnerInternal(
    id = cl, type = "costsens", package = package, par.set = par.set,
    par.vals = par.vals, properties = properties, name = name,
    short.name = short.name, note = note, callees = callees
  )
  learner = addClasses(base.learner, c(cl, "RLearnerCostSens"))
  return(learner)
}