File: test_base_tuneThreshold.R

package info (click to toggle)
r-cran-mlr 2.18.0+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 7,088 kB
  • sloc: ansic: 65; sh: 13; makefile: 2
file content (26 lines) | stat: -rw-r--r-- 1,217 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
context("tuneThreshold")

# Checks the shape of tuneThreshold() results for binary and multiclass
# predictions, and that tuning a maximized measure (acc) agrees with tuning
# its minimized counterpart (mmce) on identical predictions.
test_that("tuneThreshold", {
  res = makeResampleDesc("Holdout")
  lrn = makeLearner("classif.rpart", predict.type = "prob")

  # binary classification: one threshold (for the positive class) and one
  # performance value
  rf = resample(lrn, task = binaryclass.task, resampling = res,
    measures = list(mmce))
  th = tuneThreshold(rf$pred, measure = mmce)
  expect_length(th$th, 1L)
  expect_length(th$perf, 1L)

  # multiclass classification: one named threshold per class level
  rf2 = resample(lrn, task = multiclass.task, resampling = res,
    measures = list(mmce))
  th2 = tuneThreshold(rf2$pred, measure = mmce, control = list(max.call = 5))
  cls = getTaskClassLevels(multiclass.task)
  expect_length(th2$perf, 1L) # 1d-performance value
  expect_length(th2$th, length(cls)) # no. of thresholds = no. of classes
  expect_named(th2$th, cls) # threshold names = class names
  expect_equal(sum(th2$th), 1) # thresholds sum to 1

  # a measure that has to be maximized: acc = 1 - mmce, so tuning acc on the
  # SAME predictions is expected to give the same thresholds as tuning mmce.
  # rf2$pred is reused on purpose: a new Holdout resample would draw a
  # different random split, i.e. different predictions, making the
  # comparison meaningless.
  th3 = tuneThreshold(rf2$pred, measure = acc, control = list(max.call = 5))
  expect_equal(th3$th, th2$th, tolerance = 1e-4)
})