File: predictLearner.R

package info (click to toggle)
r-cran-mlr 2.19.2%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 8,264 kB
  • sloc: ansic: 65; sh: 13; makefile: 5
file content (159 lines) | stat: -rw-r--r-- 6,595 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#' Predict new data with an R learner.
#'
#' Mainly for internal use. Predict new data with a fitted model.
#' You have to implement this method if you want to add another learner to this package.
#'
#' Your implementation must adhere to the following:
#' Predictions for the observations in `.newdata` must be made based on the fitted
#' model (`.model$learner.model`).
#' All parameters in `...` must be passed to the underlying predict function.
#'
#' @param .learner ([RLearner])\cr
#'   Wrapped learner.
#' @param .model ([WrappedModel])\cr
#'   Model produced by training.
#' @param .newdata ([data.frame])\cr
#'   New data to predict. Does not include target column.
#' @param ... (any)\cr
#'   Additional parameters, which need to be passed to the underlying predict function.
#' @return
#' \itemize{
#'   \item For classification: Either a factor with class labels for type
#'     \dQuote{response} or, if the learner supports this, a matrix of class probabilities
#'     for type \dQuote{prob}. In the latter case the columns must be named with the class
#'     labels.
#'   \item For regression: Either a numeric vector for type \dQuote{response} or,
#'     if the learner supports this, a matrix with two columns for type \dQuote{se}.
#'     In the latter case the first column contains the estimated response (mean value)
#'     and the second column the estimated standard errors.
#'   \item For survival: Either a numeric vector with some sort of orderable risk
#'     for type \dQuote{response} or, if supported, a numeric vector with time dependent
#'     probabilities for type \dQuote{prob}.
#'   \item For clustering: Either an integer with cluster IDs for type \dQuote{response}
#'     or, if supported, a matrix of membership probabilities for type \dQuote{prob}.
#'   \item For multilabel: A logical matrix that indicates predicted class labels for type
#'     \dQuote{response} or, if supported, a matrix of class probabilities for type
#'     \dQuote{prob}. The columns must be named with the class labels.
#'  }
#' @export
predictLearner = function(.learner, .model, .newdata, ...) {
  # S3 generic: predicts .newdata with the fitted model held in .model.
  lmod = getLearnerModel(.model)

  # Models of class "NoFeaturesModel" are predicted by predictNofeatures()
  # and never reach the S3 dispatch below.
  if (inherits(lmod, "NoFeaturesModel")) {
    return(predictNofeatures(.model, .newdata))
  }

  # Every other model requires a data.frame with at least one row and one
  # column before control is handed over to the learner-specific method.
  assertDataFrame(.newdata, min.rows = 1L, min.cols = 1L)
  UseMethod("predictLearner")
}

predictLearner2 = function(.learner, .model, .newdata, ...) {
  # Wrapper around predictLearner(): optionally rebuilds the factor columns of
  # .newdata on the level sets stored in the model, runs the prediction and
  # passes the result through checkPredictLearnerOutput().
  if (.learner$fix.factors.prediction) {
    stored.levels = .model$factor.levels
    # restrict to the stored columns that actually occur in .newdata
    cols = intersect(colnames(.newdata), names(stored.levels))
    stored.levels = stored.levels[cols]
    if (length(cols) > 0L) {
      # Re-create one column on the stored levels; warn when the column
      # carries levels that are not among the stored ones.
      recode.levels = function(x, lvls) {
        unseen = setdiff(levels(x), lvls)
        if (length(unseen) > 0) {
          warning("fix.factors.prediction = TRUE produced NAs because of new factor levels in prediction data.")
        }
        factor(x, levels = lvls)
      }
      .newdata[cols] = Map(recode.levels, x = .newdata[cols], lvls = stored.levels)
    }
  }
  # predict first, then sanitize what the learner returned
  raw.pred = predictLearner(.learner, .model, .newdata, ...)
  checkPredictLearnerOutput(.learner, .model, raw.pred)
}

#' @title Check output returned by predictLearner.
#'
#' @description
#' Check the output coming from a Learner's internal
#' `predictLearner` function.
#'
#' This function is for internal use.
#'
#' @param learner ([Learner])\cr
#'   The learner.
#' @param model ([WrappedModel])\cr
#'   Model produced by training.
#' @param p (any)\cr
#'   The prediction made by `learner`.
#' @return (any). A sanitized version of `p`.
#' @keywords internal
#' @export
checkPredictLearnerOutput = function(learner, model, p) {
  # Validates the raw prediction `p` against what is expected for the
  # combination of learner$type and learner$predict.type.
  # Returns `p`; the only case in which `p` is modified is a classification
  # response whose factor levels differ from the task's class levels.
  # Any violated expectation raises an error naming learner$id.
  # Learner types not listed below are passed through unchecked.

  # first class entry of p, only used for error messages and exact class checks
  cl = class(p)[1L]
  if (learner$type == "classif") {
    levs = model$task.desc$class.levels
    if (learner$predict.type == "response") {
      # the levels of the predicted classes might not be complete....
      # be sure to add the levels at the end, otherwise data gets changed!!!
      if (!is.factor(p)) {
        stopf("predictLearner for %s has returned a class %s instead of a factor!", learner$id, cl)
      }
      levs2 = levels(p)
      # re-level only when the level vectors differ in length, content or order
      if (length(levs2) != length(levs) || any(levs != levs2)) {
        p = factor(p, levels = levs)
      }
    } else if (learner$predict.type == "prob") {
      if (!is.matrix(p)) {
        stopf("predictLearner for %s has returned a class %s instead of a matrix!", learner$id, cl)
      }
      cns = colnames(p)
      if (is.null(cns) || length(cns) == 0L) {
        stopf("predictLearner for %s has returned not the class levels as column names, but no column names at all!",
          learner$id)
      }
      # column names must match the class levels as a set; order is not checked
      if (!setequal(cns, levs)) {
        stopf("predictLearner for %s has returned not the class levels as column names: %s",
          learner$id, collapse(colnames(p)))
      }
    }
  } else if (learner$type == "regr") {
    if (learner$predict.type == "response") {
      # exact class comparison: anything whose first class is not "numeric"
      # (e.g. an integer vector or a matrix) is rejected
      if (cl != "numeric") {
        stopf("predictLearner for %s has returned a class %s instead of a numeric!", learner$id, cl)
      }
    } else if (learner$predict.type == "se") {
      if (!is.matrix(p)) {
        stopf("predictLearner for %s has returned a class %s instead of a matrix!", learner$id, cl)
      }
      if (ncol(p) != 2L) {
        stopf("predictLearner for %s has not returned a numeric matrix with 2 columns!", learner$id)
      }
    }
  } else if (learner$type == "surv") {
    if (learner$predict.type == "prob") {
      stop("Survival does not support prediction of probabilities yet.")
    }
    if (!is.numeric(p)) {
      stopf("predictLearner for %s has returned a class %s instead of a numeric!", learner$id, cl)
    }
  } else if (learner$type == "cluster") {
    if (learner$predict.type == "response") {
      # exact class comparison: the first class of p must be "integer"
      if (cl != "integer") {
        stopf("predictLearner for %s has returned a class %s instead of an integer!", learner$id, cl)
      }
    } else if (learner$predict.type == "prob") {
      if (!is.matrix(p)) {
        stopf("predictLearner for %s has returned a class %s instead of a matrix!", learner$id, cl)
      }
    }
  } else if (learner$type == "multilabel") {
    if (learner$predict.type == "response") {
      if (!(is.matrix(p) && typeof(p) == "logical")) {
        stopf("predictLearner for %s has returned a class %s instead of a logical matrix!", learner$id, cl)
      }
    } else if (learner$predict.type == "prob") {
      if (!(is.matrix(p) && typeof(p) == "double")) {
        stopf("predictLearner for %s has returned a class %s instead of a numerical matrix!", learner$id, cl)
      }
    }
  }
  return(p)
}