File: BaseWrapper_operators.R

package info (click to toggle)
r-cran-mlr 2.19.2%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 8,264 kB
  • sloc: ansic: 65; sh: 13; makefile: 5
file content (61 lines) | stat: -rw-r--r-- 1,707 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Parameter set of a wrapper: the wrapper's own parameters (x$par.set)
# concatenated with the parameter set of the wrapped learner. The inner set is
# obtained through the getParamSet() generic, so a wrapped learner that is
# itself a BaseWrapper is expanded by this same method.
#' @export
getParamSet.BaseWrapper = function(x) {
  own.params = x$par.set
  inner.params = getParamSet(x$next.learner)
  c(own.params, inner.params)
}


# Hyperparameter values of a wrapper: first the values of the wrapped learner
# (via the getHyperPars() generic, so nested wrappers are handled by this same
# method), then the wrapper's own values as returned by getHyperPars.Learner().
# `for.fun` is passed through unchanged to both calls.
#' @export
getHyperPars.BaseWrapper = function(learner, for.fun = c("train", "predict", "both")) {
  inner.vals = getHyperPars(learner$next.learner, for.fun)
  outer.vals = getHyperPars.Learner(learner, for.fun)
  c(inner.vals, outer.vals)
}


# Route each hyperparameter value, one at a time and in the order given, to
# either the wrapper itself (when its name is an id in the wrapper's own
# par.set$pars) or to the wrapped learner (via the setHyperPars2() generic).
# The updated learner is threaded through the sequence with Reduce(); an empty
# `par.vals` returns `learner` unchanged.
#' @export
setHyperPars2.BaseWrapper = function(learner, par.vals) {
  par.names = names(par.vals)
  own.ids = names(learner$par.set$pars)
  apply.one = function(lrn, k) {
    if (par.names[k] %in% own.ids) {
      setHyperPars2.Learner(lrn, par.vals = par.vals[k])
    } else {
      lrn$next.learner = setHyperPars2(lrn$next.learner, par.vals = par.vals[k])
      lrn
    }
  }
  Reduce(apply.one, seq_along(par.vals), learner)
}

# Remove hyperparameter settings from a wrapper and the learner it wraps.
# Ids that are currently set on the wrapper itself (names of
# learner$par.vals) are removed via removeHyperPars.Learner(); every other
# requested id is forwarded to the wrapped learner through the
# removeHyperPars() generic. The forward call is always made, even when no
# ids remain to forward.
#' @export
removeHyperPars.BaseWrapper = function(learner, ids) {
  own.set = intersect(names(learner$par.vals), ids)
  forwarded = setdiff(ids, own.set)
  if (length(own.set) != 0L) {
    learner = removeHyperPars.Learner(learner, own.set)
  }
  learner$next.learner = removeHyperPars(learner$next.learner, forwarded)
  learner
}



# Follow the chain of `next.learner` references through any number of nested
# BaseWrapper objects and return the innermost object that does not inherit
# from "BaseWrapper". An object that is not a BaseWrapper is returned as is.
getLeafLearner = function(learner) {
  while (inherits(learner, "BaseWrapper")) {
    learner = learner$next.learner
  }
  learner
}


# Default behaviour: set `predict.type` on the wrapped learner first (via the
# setPredictType() generic, so nested wrappers are updated recursively), then
# on the wrapper itself via setPredictType.Learner(), whose result is returned.
# Wrappers that must not propagate the predict type have to override this method.
#' @export
setPredictType.BaseWrapper = function(learner, predict.type) {
  inner = setPredictType(learner$next.learner, predict.type)
  learner$next.learner = inner
  setPredictType.Learner(learner, predict.type)
}


# Return the class-weights parameter description of the learner wrapped by
# this BaseWrapper: the entry of the leaf learner's `par.set$pars` whose name
# is given by the leaf's `class.weights.param` field.
# The wrapper chain is followed down to the innermost non-wrapper learner via
# getLeafLearner(). Previously only one level was inspected, so for nested
# wrappers the lookup used the intermediate wrapper's fields.
# NOTE(review): presumably wrappers do not carry their own
# `class.weights.param`, which made that nested lookup fail; for a
# single-level wrapper the result is identical to before.
#' @export
getClassWeightParam.BaseWrapper = function(learner, ...) {
  assertClass(learner, "BaseWrapper")
  leaf = getLeafLearner(learner)
  leaf$par.set$pars[[leaf$class.weights.param]]
}