File: FailureModel.R

package info (click to toggle)
r-cran-mlr 2.19.2%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 8,264 kB
  • sloc: ansic: 65; sh: 13; makefile: 5
file content (88 lines) | stat: -rw-r--r-- 2,503 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#' @title Failure model.
#'
#' @description
#'
#' A subclass of [WrappedModel]. It is created
#' - if you set the respective option in [configureMlr] -
#' when a model internally crashed during training.
#' The model always predicts NAs.
#'
#' If the mlr option `on.error.dump` is `TRUE`, the
#' `FailureModel` contains the debug trace of the error.
#' It can be accessed with `getFailureModelDump` and
#' inspected with `debugger`.
#'
#' Its encapsulated `learner.model` is simply a string:
#' The error message that was generated when the model crashed.
#' The following code shows how to access the message.
#'
#' @name FailureModel
#' @family debug
#' @rdname FailureModel
#' @examples
#' configureMlr(on.learner.error = "warn")
#' data = iris
#' data$newfeat = 1 # will make LDA crash
#' task = makeClassifTask(data = data, target = "Species")
#' m = train("classif.lda", task) # LDA crashed, but mlr catches this
#' print(m)
#' print(m$learner.model) # the error message
#' p = predict(m, task) # this will predict NAs
#' print(p)
#' print(performance(p))
#' configureMlr(on.learner.error = "stop")
NULL

# Build the all-NA prediction that stands in for a real prediction when the
# underlying model failed to train.
#
# @param model A model object; this function reads `model$learner$type`,
#   `model$learner$predict.type` and, for "classif" and "costsens",
#   `model$task.desc$class.levels`.
# @param newdata Data to predict on. Only its number of rows is used.
# @return An NA-filled object with one entry (or row) per row of `newdata`:
#   - "classif": factor with the task's class levels for predict type
#     "response", otherwise a numeric matrix with one column per class level.
#   - "regr": numeric vector for "response", otherwise a two-column numeric
#     matrix with columns "response" and "se".
#   - "surv": numeric vector for "response"; any other predict type is an error.
#   - "costsens": factor with the task's class levels.
#   - "cluster": character vector.
#   Any other learner type raises an error naming the type.
predictFailureModel = function(model, newdata) {
  lrn = model$learner
  type = lrn$type
  ptype = lrn$predict.type
  n = nrow(newdata)
  if (type == "classif") {
    levs = model$task.desc$class.levels
    res = if (ptype == "response") {
      factor(rep(NA_character_, n), levels = levs)
    } else {
      matrix(NA_real_, nrow = n, ncol = length(levs), dimnames = list(NULL, levs))
    }
  } else if (type == "regr") {
    res = if (ptype == "response") {
      rep(NA_real_, n)
    } else {
      matrix(NA_real_, nrow = n, ncol = 2L, dimnames = list(NULL, c("response", "se")))
    }
  } else if (type == "surv") {
    if (ptype == "response") {
      res = rep.int(NA_real_, n)
    } else {
      stop("Predict type 'prob' for survival not yet supported")
    }
  } else if (type == "costsens") {
    levs = model$task.desc$class.levels
    res = factor(rep(NA_character_, n), levels = levs)
  } else if (type == "cluster") {
    res = rep(NA_character_, n)
  } else {
    # Without this branch `res` would stay undefined and the caller would only
    # see "object 'res' not found"; fail with a message that names the cause.
    stop(sprintf("Cannot create NA predictions for unsupported learner type '%s'", type))
  }
  return(res)
}

# Print method for failure models: delegates to print.WrappedModel (defined
# elsewhere) and then emits "Training failed: <message>" through catf, where
# the message comes from getFailureModelMsg(x).
# Returns `x` invisibly, as print methods are expected to, so that the result
# of print() can be assigned or chained without losing the object.
#' @export
print.FailureModel = function(x, ...) {
  print.WrappedModel(x)
  catf("Training failed: %s", getFailureModelMsg(x))
  invisible(x)
}

# S3 method for objects of class FailureModel: always answers TRUE; the
# argument itself is not inspected.
#' @export
isFailureModel.FailureModel = function(model) {
  TRUE
}

# S3 method for objects of class FailureModel: returns the stored
# `learner.model` entry coerced to character.
#' @export
getFailureModelMsg.FailureModel = function(model) {
  stored = model$learner.model
  as.character(stored)
}

# S3 method for objects of class FailureModel: returns the model's `dump`
# entry unchanged (NULL when the model carries none).
#' @export
getFailureModelDump.FailureModel = function(model) {
  dumped = model$dump
  dumped
}