1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
|
#' @title Generates a learning curve.
#'
#' @description
#' Observe how the performance changes with an increasing number of observations.
#'
#' @family generate_plot_data
#' @family learning_curve
#' @aliases LearningCurveData
#'
#' @param learners ((list of) [Learner])\cr
#'   Learning algorithms which should be compared.
#'   Learners may also be given by their id as a string (e.g. `"classif.rpart"`),
#'   as in the example below.
#' @template arg_task
#' @param resampling ([ResampleDesc] | [ResampleInstance])\cr
#'   Resampling strategy to evaluate the performance measure.
#'   If no strategy is given, a default "Holdout" resampling is performed.
#' @param percs ([numeric])\cr
#'   Vector of percentages to be drawn from the training split.
#'   These values represent the x-axis.
#'   Internally [makeDownsampleWrapper] is used in combination with [benchmark].
#'   Thus, for each percentage a different set of observations is drawn, resulting in
#'   noisy performance measures as the quality of the sample can differ.
#' @param measures ((list of) [Measure])\cr
#'   Performance measures to generate learning curves for, representing the y-axis.
#' @param stratify (`logical(1)`)\cr
#'   Only for classification:
#'   Should the downsampled data be stratified according to the target classes?
#' @template arg_showinfo
#' @return ([LearningCurveData]). A `list` containing:
#' - `task` The [Task].
#' - `measures` List of [Measure]s used as performance measures.
#' - `data` ([data.frame]) with columns:
#'   - `learner` Names of learners.
#'   - `percentage` Percentages drawn from the training split.
#'   - One column for each [Measure] passed to [generateLearningCurveData].
#' @examples
#' \dontshow{ if (requireNamespace("class")) \{ }
#' \dontshow{ if (requireNamespace("rpart")) \{ }
#' r = generateLearningCurveData(list("classif.rpart", "classif.knn"),
#' task = sonar.task, percs = seq(0.2, 1, by = 0.2),
#' measures = list(tp, fp, tn, fn),
#' resampling = makeResampleDesc(method = "Subsample", iters = 5),
#' show.info = FALSE)
#' plotLearningCurve(r)
#' \dontshow{ \} }
#' \dontshow{ \} }
#' @export
generateLearningCurveData = function(learners, task, resampling = NULL,
  percs = seq(0.1, 1, by = 0.1), measures, stratify = FALSE, show.info = getMlrOption("show.info")) {

  # --- argument checks ---
  learners = ensureVector(learners, 1, "Learner")
  learners = lapply(learners, checkLearner)
  assertClass(task, "Task")
  assertNumeric(percs, lower = 0L, upper = 1L, min.len = 2L, any.missing = FALSE)
  measures = checkMeasures(measures, task)
  assertFlag(stratify)
  if (is.null(resampling)) {
    resampling = makeResampleInstance("Holdout", task = task)
  } else {
    assert(checkClass(resampling, "ResampleDesc"), checkClass(resampling, "ResampleInstance"))
  }

  # --- one downsample wrapper per (learner, percentage) pair ---
  # The wrapper id is suffixed with the index into `percs` (not the value itself)
  # so that every wrapper in the benchmark gets a distinct, syntactically simple id.
  dsws = lapply(learners, function(lrn) {
    lapply(seq_along(percs), function(p.id) {
      dsw = makeDownsampleWrapper(learner = lrn, dw.perc = percs[p.id], dw.stratify = stratify)
      setLearnerId(dsw, stri_paste(lrn$id, ".", p.id))
    })
  })
  # flatten the nested learner -> percentage list into a single list of learners
  dsws = unlist(dsws, recursive = FALSE)

  bench.res = benchmark(dsws, task, resampling, measures, show.info = show.info)
  perfs = getBMRAggrPerformances(bench.res, as.df = TRUE)

  # --- recover the x-axis value and the original learner for every row ---
  # NOTE(review): this relies on the rows of `perfs` being in the same order as
  # `bench.res$learners` (there is only a single task here) -- confirm if
  # getBMRAggrPerformances() ever reorders its output.
  perc = extractSubList(bench.res$learners, c("par.vals", "dw.perc")) # downsample rate
  learner = extractSubList(bench.res$learners, c("next.learner", "id")) # id of the unwrapped learner
  perfs = dropNamed(perfs, c("task.id", "learner.id"))

  # --- use short measure ids as column names and as names of `measures` ---
  # (replaceDupeMeasureNames presumably disambiguates repeated ids -- see its definition)
  mids = replaceDupeMeasureNames(measures, "id")
  names(measures) = mids
  colnames(perfs) = mids
  out = cbind(learner = learner, percentage = perc, perfs)

  makeS3Obj("LearningCurveData",
    task = task,
    measures = measures,
    data = out)
}
#' @export
print.LearningCurveData = function(x, ...) {
  # Short textual summary: header, task id, measure names, then the head of the data.
  catf("LearningCurveData:")
  catf("Task: %s", x$task$task.desc$id)
  catf("Measures: %s", collapse(extractSubList(x$measures, "name")))
  # `...` is forwarded to printHead(), e.g. to control how many rows are shown
  printHead(x$data, ...)
  # S3 print methods conventionally return their argument invisibly
  invisible(x)
}
#' @title Plot learning curve data using ggplot2.
#'
#' @family learning_curve
#' @family plot
#'
#' @description
#' Visualizes data size (percentage used for model) vs. performance measure(s).
#'
#' @param obj ([LearningCurveData])\cr
#'   Result of [generateLearningCurveData], with class `LearningCurveData`.
#' @param facet (`character(1)`)\cr
#'   Selects \dQuote{measure} or \dQuote{learner} to be the faceting variable.
#'   The variable mapped to `facet` must have more than one unique value, otherwise it will
#'   be ignored. The variable not chosen is mapped to color if it has more than one unique value.
#'   The default is \dQuote{measure}.
#' @param pretty.names (`logical(1)`)\cr
#'   Whether to use the [Measure] name instead of the id in the plot.
#'   Default is `TRUE`.
#' @template arg_facet_nrow_ncol
#' @template ret_gg2
#' @export
plotLearningCurve = function(obj, facet = "measure", pretty.names = TRUE,
  facet.wrap.nrow = NULL, facet.wrap.ncol = NULL) {

  assertClass(obj, "LearningCurveData")
  mappings = c("measure", "learner")
  assertChoice(facet, mappings)
  assertFlag(pretty.names)
  # whichever of "measure"/"learner" is not faceted on is mapped to colour
  color = mappings[mappings != facet]

  if (pretty.names) {
    # swap measure-id column names for (de-duplicated) measure names
    mnames = replaceDupeMeasureNames(obj$measures, "name")
    colnames(obj$data) = mapValues(colnames(obj$data),
      names(obj$measures), mnames)
  }

  # long format: one row per (learner, percentage, measure)
  data = melt(as.data.table(obj$data), id.vars = c("learner", "percentage"),
    variable.name = "measure", value.name = "performance")

  # A variable with a single distinct value is neither faceted on nor coloured by.
  n.distinct = c(
    learner = length(unique(data$learner)),
    measure = length(unique(data$measure))
  )
  if (n.distinct[[color]] == 1L) {
    color = NULL
  }
  if (n.distinct[[facet]] == 1L) {
    facet = NULL
  }

  # TODO(review): aes_string() is deprecated in ggplot2 >= 3.0.0; migrating to
  # aes(.data[[...]]) needs the `.data` pronoun imported into the package namespace.
  if (is.null(color)) {
    mapping = aes_string(x = "percentage", y = "performance")
  } else {
    mapping = aes_string(x = "percentage", y = "performance", colour = color)
  }
  plt = ggplot(data, mapping) + geom_point() + geom_line()

  if (!is.null(facet)) {
    # free y-scales: different measures generally live on different ranges
    plt = plt + ggplot2::facet_wrap(as.formula(stri_paste("~", facet, sep = " ")),
      scales = "free_y", nrow = facet.wrap.nrow, ncol = facet.wrap.ncol)
  }
  plt
}
|