File: predict.R

package info (click to toggle)
r-cran-mlr 2.19.2%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 8,264 kB
  • sloc: ansic: 65; sh: 13; makefile: 5
file content (163 lines) | stat: -rw-r--r-- 5,653 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
#' @title Predict new data.
#'
#' @description
#' Predict the target variable of new data using a fitted model.
#' What is stored exactly in the ([Prediction]) object depends
#' on the `predict.type` setting of the [Learner].
#' If `predict.type` was set to \dQuote{prob}, probability thresholding
#' can be done by calling the [setThreshold] function on the
#' prediction object.
#'
#' The row names of the input `task` or `newdata` are preserved in the output.
#'
#' @param object ([WrappedModel])\cr
#'   Wrapped model, result of [train].
#' @param task ([Task])\cr
#'   The task. If this is passed, data from this task is predicted.
#' @param newdata ([data.frame])\cr
#'   New observations which should be predicted.
#'   Pass this alternatively instead of `task`.
#' @template arg_subset
#' @param ... (any)\cr
#'   Currently ignored.
#' @return ([Prediction]).
#' @family predict
#' @export
#' @examples
#' \dontshow{ if (requireNamespace("MASS")) \{ }
#' # train and predict
#' train.set = seq(1, 150, 2)
#' test.set = seq(2, 150, 2)
#' model = train("classif.lda", iris.task, subset = train.set)
#' p = predict(model, newdata = iris, subset = test.set)
#' print(p)
#' predict(model, task = iris.task, subset = test.set)
#'
#' # now predict probabilities instead of class labels
#' lrn = makeLearner("classif.lda", predict.type = "prob")
#' model = train(lrn, iris.task, subset = train.set)
#' p = predict(model, task = iris.task, subset = test.set)
#' print(p)
#' getPredictionProbabilities(p)
#' \dontshow{ \} }
predict.WrappedModel = function(object, task, newdata, subset = NULL, ...) {

  # exactly one of 'task' and 'newdata' has to be supplied
  if (!xor(missing(task), missing(newdata))) {
    stop("Pass either a task object or a newdata data.frame to predict, but not both!")
  }
  assertClass(object, classes = "WrappedModel")
  model = object
  learner = model$learner
  td = model$task.desc

  # validate the data source, check 'subset' against its size and
  # extract the rows that are going to be predicted
  if (missing(newdata)) {
    assertClass(task, classes = "Task")
    subset = checkTaskSubset(subset, getTaskSize(task))
    # if learner does not support functional, we convert to df cols
    supports.functionals = hasLearnerProperties(learner, "functionals") ||
      hasLearnerProperties(learner, "single.functional")
    functionals.as = if (supports.functionals) "matrix" else "dfcols"
    newdata = getTaskData(task, subset, functionals.as = functionals.as)
  } else {
    assertDataFrame(newdata, min.rows = 1L)
    # subclasses of data.frame are accepted, but converted with a warning
    if (class(newdata)[1] != "data.frame") {
      warningf("Provided data for prediction is not a pure data.frame but from class %s, hence it will be converted.", class(newdata)[1])
      newdata = as.data.frame(newdata)
    }
    subset = checkTaskSubset(subset, nrow(newdata))
    newdata = newdata[subset, , drop = FALSE]
  }

  # if we saved a model and loaded it later just for prediction this is necessary
  requireLearnerPackages(learner)
  t.col = match(td$target, colnames(newdata))

  # get truth and drop target col, if target in newdata
  if (!all(is.na(t.col))) {
    # several target columns: either all of them are present or none
    if (length(t.col) > 1L && anyMissing(t.col)) {
      stop("Some but not all target columns found in data")
    }
    truth = newdata[, t.col, drop = TRUE]
    # more than one target column yields a list here, keep it as a data.frame
    if (is.list(truth)) {
      truth = data.frame(truth)
    }
    newdata = newdata[, -t.col, drop = FALSE]
  } else {
    truth = NULL
  }

  error = NA_character_
  # default to NULL error dump
  dump = NULL
  # was there an error in building the model? --> return NAs
  if (isFailureModel(model)) {
    p = predictFailureModel(model, newdata)
    time.predict = NA_real_
    dump = getFailureModelDump(model)
  } else {
    # FIXME: this copies newdata
    pars = list(
      .learner = learner,
      .model = model,
      .newdata = newdata
    )
    # append the hyperparameters that are relevant at predict time
    pars = c(pars, getHyperPars(learner, c("predict", "both")))
    debug.seed = getMlrOption("debug.seed", NULL)
    if (!is.null(debug.seed)) {
      set.seed(debug.seed)
    }
    opts = getLearnerOptions(learner, c("show.learner.output", "on.learner.error", "on.learner.warning", "on.error.dump"))
    # wrappers around the actual predict call, applied from the inside out:
    # fun3 dumps the frames on error, fun2 turns an error into a try-error
    # object, fun1 optionally captures the output of the learner
    fun1 = if (opts$show.learner.output) identity else capture.output
    fun2 = if (opts$on.learner.error == "stop") identity else function(x) try(x, silent = TRUE)
    fun3 = if (opts$on.learner.error == "stop" || !opts$on.error.dump) {
      identity
    } else {
      function(x) {
        withCallingHandlers(x, error = function(c) utils::dump.frames())
      }
    }
    if (opts$on.learner.warning == "quiet") {
      old.warn.opt = getOption("warn")
      # add = TRUE: do not overwrite exit handlers that are already registered
      on.exit(options(warn = old.warn.opt), add = TRUE)
      options(warn = -1L)
    }

    # 'p' is assigned as a side effect of evaluating the timed expression
    time.predict = measureTime(fun1({
      p = fun2(fun3(do.call(predictLearner2, pars)))
    }))

    # was there an error during prediction?
    if (is.error(p)) {
      if (opts$on.learner.error == "warn") {
        warningf("Could not predict with learner %s: %s", learner$id, as.character(p))
      }
      error = as.character(p)
      p = predictFailureModel(model, newdata)
      time.predict = NA_real_
      if (opts$on.error.dump) {
        # the dump is read from 'last.dump' in the global environment
        dump = addClasses(get("last.dump", envir = .GlobalEnv), "mlr.dump")
      }
    }
    # did the prediction fail otherwise?
    np = nrow(p)
    # no row dimension: fall back to the length of the prediction
    if (is.null(np)) np = length(p)
    if (np != nrow(newdata)) {
      stopf("predictLearner for %s has returned %i predictions instead of %i!", learner$id, np, nrow(newdata))
    }
  }

  # ids are only passed on when a task was predicted
  ids = if (missing(task)) NULL else subset

  makePrediction(
    task.desc = td, row.names = rownames(newdata), id = ids, truth = truth,
    predict.type = learner$predict.type,
    predict.threshold = learner$predict.threshold,
    y = p, time = time.predict, error = error, dump = dump
  )
}