File: setPredictType.R

package info (click to toggle)
r-cran-mlr 2.19.2%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 8,264 kB
  • sloc: ansic: 65; sh: 13; makefile: 5
file content (47 lines) | stat: -rw-r--r-- 1,794 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#' @title Set the type of predictions the learner should return.
#'
#' @description
#' Possible prediction types are:
#' Classification: Labels or class probabilities (including labels).
#' Regression: Numeric response or standard errors (including numeric response).
#' Survival: Linear predictor or survival probability.
#'
#' For complex wrappers the predict type is usually also passed down to the
#' encapsulated learner in a recursive fashion.
#'
#' An error is raised if \dQuote{prob} or \dQuote{se} is requested but the
#' learner does not have the corresponding property.
#'
#' @template arg_learner
#' @param predict.type (`character(1)`)\cr
#'   Classification: \dQuote{response} or \dQuote{prob}.
#'   Multilabel: \dQuote{response} or \dQuote{prob}.
#'   Regression: \dQuote{response} or \dQuote{se}.
#'   Survival: \dQuote{response} (linear predictor) or \dQuote{prob}.
#'   Cost-sensitive: \dQuote{response} only.
#'   Clustering: \dQuote{response} or \dQuote{prob}.
#'   Default is \dQuote{response}.
#' @template ret_learner
#' @family predict
#' @family learner
#' @export
setPredictType = function(learner, predict.type) {
  # Fail early on anything that is not a Learner; `predict.type` itself is
  # validated by the dispatched method, not here.
  assertClass(learner, classes = "Learner")
  # S3 dispatch on the class of `learner`.
  UseMethod("setPredictType")
}

#' @export
setPredictType.Learner = function(learner, predict.type) {
  # Validation is performed in the method (not the generic) because wrappers
  # may call this recursively.

  # Admissible predict types for the learner's task type.
  resp.or.prob = c("response", "prob")
  allowed = switch(learner$type,
    classif = resp.or.prob,
    multilabel = resp.or.prob,
    regr = c("response", "se"),
    surv = resp.or.prob,
    costsens = "response",
    cluster = resp.or.prob
  )
  assertChoice(predict.type, choices = allowed)

  # "prob" and "se" additionally require the learner property of the same
  # name; "response" needs no property, so the lookup yields NULL for it.
  unsupported.fmt = switch(predict.type,
    prob = "Trying to predict probs, but %s does not support that!",
    se = "Trying to predict standard errors, but %s does not support that!"
  )
  if (!is.null(unsupported.fmt) && !hasLearnerProperties(learner, predict.type)) {
    stopf(unsupported.fmt, learner$id)
  }

  # Store the validated predict type and hand back the modified learner.
  learner$predict.type = predict.type
  learner
}