File: test_base_BaseEnsemble.R

package info (click to toggle)
r-cran-mlr 2.19.1%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 8,392 kB
  • sloc: ansic: 65; sh: 13; makefile: 5
file content (65 lines) | stat: -rw-r--r-- 2,956 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65

test_that("BaseEnsemble", {
  # Compare a learner's hyperparameters against an expected named list,
  # ignoring element order. A plain setequal() on two lists matches element
  # values only and silently ignores names, so it would accept values that
  # are attached to the wrong parameter ids. Sorting both lists by name and
  # using expect_equal() checks names and values (and tolerates 11 vs. 11L).
  expectHyperPars = function(learner, expected) {
    actual = getHyperPars(learner)
    expect_equal(actual[order(names(actual))],
      expected[order(names(expected))])
  }

  bl1 = makeLearner("classif.rpart", minsplit = 2L, id = "a")
  bl2 = makeLearner("classif.ksvm", C = 2, id = "b")
  ps = makeParamSet(makeNumericLearnerParam("foo"))
  pv = list(foo = 3)
  be = makeBaseEnsemble(id = "foo", base.learners = list(bl1, bl2),
    par.set = ps, par.vals = pv, cl = "mywrapper")
  # base-learner settings appear with a "<learner id>." prefix next to the
  # ensemble's own parameter "foo"
  expectHyperPars(be, list(a.xval = 0L, a.minsplit = 2L,
    b.fit = FALSE, b.C = 2, foo = 3))

  # a prefixed name updates the matching base learner ...
  be = setHyperPars(be, a.minsplit = 11)
  expectHyperPars(be, list(a.xval = 0L, a.minsplit = 11L,
    b.fit = FALSE, b.C = 2, foo = 3))

  # ... and an unprefixed name updates the ensemble's own parameter
  be = setHyperPars(be, foo = 12)
  expectHyperPars(be, list(a.xval = 0L, a.minsplit = 11L,
    b.fit = FALSE, b.C = 2, foo = 12))

  # check removing hyperpars
  be1 = removeHyperPars(be, names(getHyperPars(be)))
  expect_length(getHyperPars(be1), 0L)

  # nested ensemble: the ensemble above as a base learner next to a wrapped
  # ksvm. Only checks that removing all reported hyperpars runs without error.
  bl1 = be
  bl2 = makeOversampleWrapper(makeFilterWrapper(bl2, fw.perc = 0.5),
    osw.rate = 1)
  be = makeBaseEnsemble(id = "foo", base.learners = list(bl1, bl2),
    par.set = ps, par.vals = pv, cl = "mywrapper")
  expect_error(removeHyperPars(be, names(getHyperPars(be))), NA)

  # check that we get error if predict types are unequal
  bl1 = makeLearner("classif.rpart", predict.type = "prob")
  bl2 = makeLearner("classif.ksvm", predict.type = "response")
  expect_error(makeBaseEnsemble(id = "foo", base.learners = list(bl1, bl2),
    par.set = ps, par.vals = pv, cl = "mywrapper"), "predict.type")

  # check getHyperPars when we have multiple wrappers
  bl1 = makeLearner("classif.rpart", minsplit = 2L, id = "rpart")
  bl2 = makeLearner("classif.ksvm", C = 2, id = "ksvm")
  # now make a wrapper around bl2 (ksvm); its id becomes
  # "ksvm.filtered.oversampled", and the wrappers' settings are reported
  # under that same prefix
  bl2 = makeOversampleWrapper(makeFilterWrapper(bl2, fw.perc = 0.5),
    osw.rate = 1)
  be = makeBaseEnsemble(id = "foo", base.learners = list(bl1, bl2),
    cl = "mywrapper")
  expect_output(print(be), "mywrapper")
  expectHyperPars(be, list(rpart.xval = 0L, rpart.minsplit = 2L,
    ksvm.filtered.oversampled.fit = FALSE,
    ksvm.filtered.oversampled.C = 2,
    ksvm.filtered.oversampled.fw.method = "FSelectorRcpp_information.gain",
    ksvm.filtered.oversampled.fw.perc = 0.5,
    ksvm.filtered.oversampled.osw.rate = 1))

  # check removing hyperpars
  be.rm = removeHyperPars(be, names(getHyperPars(be)))
  expect_length(getHyperPars(be.rm), 0L)

  # check setPredictType: the new type is set on the ensemble and propagated
  # to every base learner. Each iteration starts from the same `be`.
  basePredictTypes = function(ens) {
    lapply(ens$base.learners, function(x) x$predict.type)
  }
  for (pt in c("prob", "response")) {
    be.pt = setPredictType(be, predict.type = pt)
    expect_equal(be.pt$predict.type, pt)
    expect_equal(basePredictTypes(be.pt),
      list(rpart = pt, ksvm.filtered.oversampled = pt))
  }
})