File: WeightedClassesWrapper.R

Package: r-cran-mlr 2.19.2+dfsg-1
#' @title Wraps a classifier for weighted fitting where each class receives a weight.
#'
#' @description
#' Creates a wrapper, which can be used like any other learner object.
#'
#' Fitting is performed in a weighted fashion where each observation receives a weight,
#' depending on the class it belongs to (see `wcw.weight`).
#' This can help to mitigate problems caused by imbalanced class distributions.
#'
#' This weighted fitting can be achieved in two ways:
#'
#' a) The learner already has a parameter for class weighting, so one weight can be defined
#' directly per class. Example: \dQuote{classif.ksvm} and parameter `class.weights`.
#' In this case nothing fancy happens: `wcw.weight` is converted slightly and its value is
#' simply bound to the class weighting parameter.
#' The wrapper then offers a convenient, consistent way to do class weighting -
#' and tuning! See the example below.
#'
#' b) The learner does not have a direct parameter to support class weighting, but
#' supports observation weights, so `hasLearnerProperties(learner, 'weights')` is `TRUE`.
#' This means that an individual, arbitrary weight can be set per observation during training.
#' We set this weight depending on the class internally in the wrapper. Basically we introduce
#' something like a new \dQuote{class.weights} parameter for the learner via observation weights.
#'
#' @template arg_learner_classif
#' @param wcw.param (`character(1)`)\cr
#'   Name of an already existing learner parameter that allows class weighting.
#'   The default (`wcw.param = NULL`) uses the parameter defined in
#'   the learner (`class.weights.param`). During training, the parameter
#'   must accept a named vector of class weights whose length equals the
#'   number of classes.
#' @param wcw.weight ([numeric])\cr
#'   Weight for each class.
#'   Must be a vector with as many elements as there are classes in the task,
#'   in the same order as the class levels in
#'   `getTaskDesc(task)$class.levels`.
#'   For convenience, in the case of binary classification a single number must be passed;
#'   it is then taken as the weight of the positive class, while the negative class receives
#'   a weight of 1.
#'   Default is 1.
#' @template ret_learner
#' @family wrapper
#' @export
#' @examples
#' \dontshow{ if (requireNamespace("kernlab")) \{ }
#' \donttest{
#' set.seed(123)
#' # using the direct parameter of the SVM (which is already defined in the learner)
#' lrn = makeWeightedClassesWrapper("classif.ksvm", wcw.weight = 0.01)
#' res = holdout(lrn, sonar.task)
#' print(calculateConfusionMatrix(res$pred))
#'
#' # using the observation weights of logreg
#' lrn = makeWeightedClassesWrapper("classif.logreg", wcw.weight = 0.01)
#' res = holdout(lrn, sonar.task)
#' print(calculateConfusionMatrix(res$pred))
#'
#' # tuning the imbalance parameter and the SVM parameters in one go
#' lrn = makeWeightedClassesWrapper("classif.ksvm", wcw.param = "class.weights")
#' ps = makeParamSet(
#'   makeNumericParam("wcw.weight", lower = 1, upper = 10),
#'   makeNumericParam("C", lower = -12, upper = 12, trafo = function(x) 2^x),
#'   makeNumericParam("sigma", lower = -12, upper = 12, trafo = function(x) 2^x)
#' )
#' ctrl = makeTuneControlRandom(maxit = 3L)
#' rdesc = makeResampleDesc("CV", iters = 2L, stratify = TRUE)
#' res = tuneParams(lrn, sonar.task, rdesc, par.set = ps, control = ctrl)
#' print(res)
#' # print(res$opt.path)
#' }
#' \dontshow{ \} }
makeWeightedClassesWrapper = function(learner, wcw.param = NULL, wcw.weight = 1) {

  learner = checkLearner(learner, "classif")
  pv = list()

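  # Resolve the class weighting parameter: default to the learner's own
  # class.weights.param, and reject a conflicting user-supplied name.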
  if (is.null(wcw.param)) {
    wcw.param = learner$class.weights.param
  } else if (!is.null(learner$class.weights.param) && (learner$class.weights.param != wcw.param)) {
    stopf("wcw.param (%s) differs from the class.weights.param (%s) of the learner!",
      wcw.param, learner$class.weights.param)
  }

  if (is.null(wcw.param)) {
    if (!hasLearnerProperties(learner, "weights")) {
      stopf("Learner '%s' does not support observation weights. You have to set 'wcw.param' to the learner parameter that allows setting class weights (if such a parameter exists).", learner$id)
    }
  } else {
    assertSubset(wcw.param, getParamIds(learner$par.set))
  }

  if (!missing(wcw.weight)) {
    assertNumeric(wcw.weight, lower = 0, any.missing = FALSE)
    pv$wcw.weight = wcw.weight
  }
  id = stri_paste("weightedclasses", learner$id, sep = ".")
  ps = makeParamSet(
    makeNumericVectorLearnerParam(id = "wcw.weight", len = NA_integer_, lower = 0)
  )
  x = makeBaseWrapper(id, learner$type, learner, package = learner$package, par.set = ps, par.vals = pv,
    learner.subclass = "WeightedClassesWrapper", model.subclass = "WeightedClassesModel")
  x$wcw.param = wcw.param
  x
}
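
# A minimal construction sketch (assuming the "kernlab" package is available), showing
# what the wrapper built above stores: `wcw.weight` becomes a hyperparameter of the
# wrapper, and `wcw.param` records which learner parameter (here taken from the
# `class.weights.param` of "classif.ksvm") will receive the named class-weight vector
# at training time.
#
#   lrn = makeWeightedClassesWrapper("classif.ksvm", wcw.weight = 0.01)
#   getHyperPars(lrn)$wcw.weight   # 0.01
#   lrn$wcw.param                  # "class.weights"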

#' @export
trainLearner.WeightedClassesWrapper = function(.learner, .task, .subset = NULL, .weights, wcw.weight = 1, ...) {
  .task = subsetTask(.task, .subset)
  td = getTaskDesc(.task)
  levs = td$class.levels
  p = .learner$wcw.param
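  # Expand wcw.weight into a named vector with one weight per class level:
  # in the binary case a single number is interpreted as the positive-class weight.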
  if (length(levs) == 2L) {
    assertNumber(wcw.weight, lower = 0)
    wcw.weight = c(wcw.weight, 1)
    names(wcw.weight) = c(td$positive, td$negative)
  } else {
    assertNumeric(wcw.weight, len = length(levs), lower = 0)
    names(wcw.weight) = levs
  }
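  # Either translate the class weights into per-observation weights, or bind them
  # directly to the learner's class weighting parameter.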
  if (is.null(p)) {
    y = as.character(getTaskTargets(.task))
    weights = wcw.weight[y]
    m = train(.learner$next.learner, task = .task, weights = weights)
  } else {
    .learner = setHyperPars(.learner, par.vals = setNames(list(wcw.weight), p))
    m = train(.learner$next.learner, task = .task)
  }
  makeChainModel(next.model = m, cl = "WeightedClassesModel")
}
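
# A small standalone sketch (hypothetical values, not part of mlr) of the two weight
# expansions performed above: a single binary-case `wcw.weight` becomes a named vector
# over the positive/negative classes, and per-observation weights are then looked up
# from that vector by class label.
#
#   wcw.weight = 0.01
#   wcw.weight = c(wcw.weight, 1)
#   names(wcw.weight) = c("M", "R")   # positive = "M", negative = "R"
#
#   y = c("M", "R", "M")              # as.character(getTaskTargets(.task))
#   wcw.weight[y]                     # 0.01 1.00 0.01, used as observation weights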

#' @export
getLearnerProperties.WeightedClassesWrapper = function(learner) {
  setdiff(getLearnerProperties(learner$next.learner), "weights")
}
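
# A brief consequence sketch (assuming "classif.logreg" itself supports observation
# weights): the wrapper drops the "weights" property because it uses observation
# weights internally to implement class weighting, so they cannot also be supplied
# from the outside.
#
#   lrn = makeWeightedClassesWrapper("classif.logreg", wcw.weight = 0.01)
#   hasLearnerProperties(lrn, "weights")   # FALSE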