File: test_base_PreprocWrapper.R

package info (click to toggle)
r-cran-mlr 2.19.1+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 8,392 kB
  • sloc: ansic: 65; sh: 13; makefile: 5
file content (55 lines) | stat: -rwxr-xr-x 2,038 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55

test_that("PreprocWrapper", {
  # Order-insensitive comparison of two named hyperparameter lists.
  # setequal() on lists compares values only and ignores names, so a swap such
  # as x = 2, y = 1 would still pass it; here both the name set and the
  # name -> value pairing are checked.
  expect_same_pars = function(actual, expected) {
    expect_setequal(names(actual), names(expected))
    expect_equal(actual[names(expected)], expected)
  }

  # Preprocessing train function: scales the second column by args$x and
  # returns the transformed data together with an (empty) control object.
  f1 = function(data, target, args) {
    data[, 2] = args$x * data[, 2]
    return(list(data = data, control = list()))
  }
  # Matching predict function: applies the same scaling to new data.
  f2 = function(data, target, args, control) {
    data[, 2] = args$x * data[, 2]
    return(data)
  }
  # Wrapper-level parameters; only x is used by f1/f2, y just has to be carried
  # along in the hyperparameter bookkeeping.
  ps = makeParamSet(
    makeNumericLearnerParam(id = "x"),
    makeNumericLearnerParam(id = "y")
  )
  lrn1 = makeLearner("classif.rpart", minsplit = 10)
  lrn2 = makePreprocWrapper(lrn1, train = f1, predict = f2, par.set = ps, par.vals = list(x = 1, y = 2))
  # Smoke test: printing the wrapper must not fail.
  capture.output(print(lrn2))

  # getHyperPars() merges the inner learner's settings with the wrapper's own;
  # $par.vals holds only the wrapper-level values.
  expect_same_pars(getHyperPars(lrn2), list(xval = 0, minsplit = 10, x = 1, y = 2))
  expect_same_pars(getHyperPars(lrn2, "train"), list(xval = 0, minsplit = 10, x = 1, y = 2))
  expect_same_pars(lrn2$par.vals, list(x = 1, y = 2))

  # setHyperPars() routes each value to the right level (inner vs. wrapper).
  lrn3 = setHyperPars(lrn2, minsplit = 77, x = 88)
  expect_same_pars(getHyperPars(lrn3), list(xval = 0, minsplit = 77, x = 88, y = 2))
  expect_same_pars(lrn3$par.vals, list(x = 88, y = 2))

  # The trained model keeps the learner's hyperparameters unchanged.
  m = train(lrn2, task = multiclass.task)
  capture.output(print(m))
  expect_same_pars(getHyperPars(m$learner), list(xval = 0, minsplit = 10, x = 1, y = 2))
})

test_that("getLearnerModel on nested PreprocWrapper", {
  # Imputation rules for the outer wrapper.
  imputation.classes = list(numeric = imputeMax(5), factor = imputeConstant("NA"))
  # rpart, wrapped first by dummy-feature encoding, then by imputation.
  wrapped = makeImputeWrapper(
    makeDummyFeaturesWrapper(makeLearner("classif.rpart")),
    classes = imputation.classes
  )
  fitted = train(wrapped, binaryclass.task)

  # Default call: the returned model is a PreprocModel (not yet unwrapped) ...
  expect_s3_class(getLearnerModel(fitted), "PreprocModel")
  # ... while requesting further unwrapping reaches the underlying rpart fit.
  expect_s3_class(getLearnerModel(fitted, TRUE), "rpart")
})

test_that("PreprocWrapper with glmnet (#958)", {
  requirePackagesOrSkip("glmnet", default.method = "load")
  lrn = makeLearner("classif.glmnet", predict.type = "response")
  # Identity preprocessing: the wrapper passes data through unchanged, so any
  # failure comes from the wrapper plumbing itself (regression test for #958).
  lrn2 = makePreprocWrapper(lrn,
    train = function(data, target, args) {
      return(list(data = data, control = list()))
    },
    predict = function(data, target, args, control) {
      return(data)
    }
  )
  # The calls under test run *inside* expect_error(..., NA). Asserting on an
  # already-computed value is vacuous: an error in train()/predict() would be
  # raised before the expectation is ever evaluated.
  expect_error(mod <- train(lrn2, multiclass.task), NA)
  expect_error(pred <- predict(mod, multiclass.task), NA)
  expect_s3_class(pred, "Prediction")
})