File: performance.R

#' Measure performance of prediction.
#'
#' Measures the quality of a prediction w.r.t. some performance measure.
#'
#' @template arg_pred
#' @template arg_measures
#' @param task ([Task])\cr
#'   Learning task, might be requested by the performance measure; usually not needed except for clustering or survival.
#' @param model ([WrappedModel])\cr
#'   Model built on training data, might be requested by the performance measure; usually not needed except for survival.
#' @param feats ([data.frame])\cr
#'   Features of the predicted data, usually not needed except for clustering.
#'   If the prediction was generated from a `task`, you can pass that instead and the features
#'   are extracted from it.
#' @param simpleaggr ([logical])\cr
#'   If `TRUE`, aggregation of [ResamplePrediction] objects is skipped. This is used internally for threshold tuning.
#'   Default is `FALSE`.
#' @return (named [numeric]). Performance value(s), named by measure(s).
#' @export
#' @family performance
#' @examples
#' training.set = seq(1, nrow(iris), by = 2)
#' test.set = seq(2, nrow(iris), by = 2)
#'
#' task = makeClassifTask(data = iris, target = "Species")
#' lrn = makeLearner("classif.lda")
#' mod = train(lrn, task, subset = training.set)
#' pred = predict(mod, newdata = iris[test.set, ])
#' performance(pred, measures = mmce)
#'
#' # Compute multiple performance measures at once
#' ms = list("mmce" = mmce, "acc" = acc, "timetrain" = timetrain)
#' performance(pred, measures = ms, task, mod)
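#'
#' # Performance of a resampled prediction: the measure is evaluated per
#' # iteration and then aggregated by its aggregation function
#' # (illustrative sketch using 3-fold cross-validation)
#' rdesc = makeResampleDesc("CV", iters = 3)
#' r = resample(lrn, task, rdesc, measures = mmce)
#' performance(r$pred, measures = mmce)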
performance = function(pred, measures, task = NULL, model = NULL, feats = NULL, simpleaggr = FALSE) {

  if (!is.null(pred)) {
    assertClass(pred, classes = "Prediction")
  }
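  # checkMeasures returns a proper list of Measure objects (a single measure may also be passed)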
  measures = checkMeasures(measures, pred$task.desc)
  res = vnapply(measures, doPerformanceIteration, pred = pred, task = task, model = model, td = NULL, feats = feats, simpleaggr = simpleaggr)
  # FIXME: This is really what the names should be, but it breaks all kinds of other stuff
  # if (inherits(pred, "ResamplePrediction")) {
  #  setNames(res, vcapply(measures, measureAggrName))
  # } else {
  #  setNames(res, extractSubList(measures, "id"))
  # }
  setNames(res, extractSubList(measures, "id"))
}

doPerformanceIteration = function(measure, pred = NULL, task = NULL, model = NULL, td = NULL, feats = NULL, simpleaggr = FALSE) {

  m = measure
  props = getMeasureProperties(m)
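  # the measure's "req.*" properties declare which arguments must be available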
  if ("req.pred" %in% props) {
    if (is.null(pred)) {
      stopf("You need to pass pred for measure %s!", m$id)
    }
  }
  if ("req.truth" %in% props) {
    type = getTaskDesc(pred)$type
    if (type == "surv") {
      if (is.null(pred$data$truth.time) || is.null(pred$data$truth.event)) {
        stopf("You need to have 'truth.time' and 'truth.event' columns in your pred object for measure %s!", m$id)
      }
    } else if (type == "multilabel") {
      if (!(any(stri_detect_regex(colnames(pred$data), "^truth\\.")))) {
        stopf("You need to have 'truth.*' columns in your pred object for measure %s!", m$id)
      }
    } else {
      if (is.null(pred$data$truth)) {
        stopf("You need to have a 'truth' column in your pred object for measure %s!", m$id)
      }
    }
  }
  if ("req.model" %in% props) {
    if (is.null(model)) {
      stopf("You need to pass model for measure %s!", m$id)
    }
    assertClass(model, classes = "WrappedModel")
  }
  if ("req.task" %in% props) {
    if (is.null(task)) {
      stopf("You need to pass task for measure %s!", m$id)
    }
    assertClass(task, classes = "Task")
  }
  if ("req.feats" %in% props) {
    if (is.null(task) && is.null(feats)) {
      stopf("You need to pass either task or features for measure %s!", m$id)
    } else if (is.null(feats)) {
      feats = task$env$data[pred$data$id, , drop = FALSE]
    } else {
      assertClass(feats, "data.frame")
    }
  }
  # look up the task description from pred, model or task (in that order)
  td = if (!is.null(pred)) {
    pred$task.desc
  } else if (!is.null(model)) {
    model$task.desc
  } else if (!is.null(task)) {
    getTaskDesc(task)
  }

  # td is NULL only for custom resampled measures where no individual measurements are done
  if (!is.null(td)) {
    if (td$type %nin% props) {
      stopf("Measure %s does not support task type %s!", m$id, td$type)
    }
    if (td$type == "classif" && length(td$class.levels) > 2L && "classif.multi" %nin% props) {
      stopf("Multiclass problems cannot be used for measure %s!", m$id)
    }

    # if we have multiple req.pred.types, check if we have one of them (currently we only need prob)
    req.pred.types = if ("req.prob" %in% props) "prob" else character(0L)
    if (!is.null(pred) && length(req.pred.types) > 0L && pred$predict.type %nin% req.pred.types) {
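      # the reaction is controlled by the mlr option "on.measure.not.applicable":
      # stop, warn, or silently return NA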
      on.measure.not.applicable = getMlrOption(name = "on.measure.not.applicable")
      msg = sprintf("Measure %s requires predict type to be: '%s'!", m$id, collapse(req.pred.types))
      if (on.measure.not.applicable == "stop") {
        stop(msg)
      } else if (on.measure.not.applicable == "warn") {
        warning(msg)
      }
      return(NA_real_)
    }
  }

  # simpleaggr skips per-iteration aggregation (used for threshold tuning);
  # otherwise a ResamplePrediction is evaluated per iteration and aggregated
  if (simpleaggr) {
    measure$fun(task, model, pred, feats, m$extra.args)
  } else {
    if (inherits(pred, "ResamplePrediction")) {
      if (is.null(pred$data$iter)) pred$data$iter = 1L
      if (is.null(pred$data$set)) pred$data$set = "test"
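      # evaluate the measure on the train rows (if any) and the test rows of a
      # single resampling iteration; train performance is NA without train-set predictions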
      fun = function(ss) {
        is.train = ss$set == "train"
        if (any(is.train)) {
          pred$data = as.data.frame(ss[is.train, ])
          perf.train = measure$fun(task, model, pred, feats, m$extra.args)
        } else {
          perf.train = NA_real_
        }
        pred$data = as.data.frame(ss[!is.train, ])
        perf.test = measure$fun(task, model, pred, feats, m$extra.args)
        list(perf.train = perf.train, perf.test = perf.test)
      }
      perfs = as.data.table(pred$data)[, fun(.SD), by = "iter"]
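      # aggregate per-iteration test (and train) performances with the measure's aggregation function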
      measure$aggr$fun(task, perfs$perf.test, perfs$perf.train, measure, perfs$iter, pred)
    } else {
      measure$fun(task, model, pred, feats, m$extra.args)
    }
  }
}