#' @title Filter features by thresholding filter values.
#'
#' @description
#' First, calls [generateFilterValuesData].
#' Features are then selected according to exactly one of `perc`, `abs` or `threshold`.
#'
#' @template arg_task
#' @param method (`character(1)`)\cr
#' See [listFilterMethods].
#' Default is \dQuote{randomForestSRC.rfsrc}.
#' @param fval ([FilterValues])\cr
#' Result of [generateFilterValuesData].
#' If you pass this, the filter values in the object are used for feature filtering.
#' `method` and `...` are ignored then.
#' Default is `NULL` and not used.
#' @param perc (`numeric(1)`)\cr
#' If set, select the top `perc` * 100 percent of features by score.
#' Mutually exclusive with arguments `abs` and `threshold`.
#' @param abs (`numeric(1)`)\cr
#' If set, select `abs` top scoring features.
#' Mutually exclusive with arguments `perc` and `threshold`.
#' @param threshold (`numeric(1)`)\cr
#' If set, select features whose score is greater than or equal to `threshold`.
#' Mutually exclusive with arguments `perc` and `abs`.
#' @param mandatory.feat ([character])\cr
#' Mandatory features which are always included regardless of their scores.
#' @param ... (any)\cr
#' Passed down to selected filter method.
#' @template ret_task
#' @export
#' @family filter
filterFeatures = function(task, method = "randomForestSRC.rfsrc", fval = NULL, perc = NULL, abs = NULL,
  threshold = NULL, mandatory.feat = NULL, ...) {
  assertClass(task, "SupervisedTask")
  assertChoice(method, choices = ls(.FilterRegister))
  select = checkFilterArguments(perc, abs, threshold)
  p = getTaskNFeats(task)
  # Number of features to keep; for "threshold" this is only an upper bound until the scores are known
  nselect = switch(select,
    perc = round(perc * p),
    abs = min(abs, p),
    threshold = p
  )
  if (is.null(fval)) {
    fval = generateFilterValuesData(task = task, method = method, nselect = nselect, ...)$data
  } else {
    assertClass(fval, "FilterValues")
    if (!is.null(fval$method)) { ## fval is generated by deprecated getFilterValues
      colnames(fval$data)[which(colnames(fval$data) == "val")] = fval$method
      method = fval$method
      fval = fval$data[, c(1, 3, 2)]
    } else {
      methods = colnames(fval$data[, -which(colnames(fval$data) %in% c("name", "type")), drop = FALSE])
      if (length(methods) > 1) {
        # Several methods are stored in fval: the requested one must be among them
        assertChoice(method, choices = methods)
      } else {
        method = methods
      }
      # Unwrap the data frame in both cases so that fval[[method]] below works
      fval = fval$data
    }
  }
  if (all(is.na(fval[[method]]))) {
    stopf("Filter method returned all NA values!")
  }
  if (!is.null(mandatory.feat)) {
    assertCharacter(mandatory.feat)
    if (!all(mandatory.feat %in% fval$name))
      stop("At least one mandatory feature was not found in the task.")
    if (select != "threshold" && nselect < length(mandatory.feat))
      stop("The number of features to be filtered cannot be smaller than the number of mandatory features.")
    # Set the filter values of the mandatory features to infinity so that they are always selected
    fval[fval$name %in% mandatory.feat, method] = Inf
  }
  if (select == "threshold")
    nselect = sum(fval[[method]] >= threshold, na.rm = TRUE)
  # Take the nselect top-scoring features, then restore their original order in the task
  features = as.character(head(sortByCol(fval, method, asc = FALSE)$name, nselect))
  allfeats = getTaskFeatureNames(task)
  j = match(features, allfeats)
  features = allfeats[sort(j)]
  subsetTask(task, features = features)
}
checkFilterArguments = function(perc, abs, threshold) {
  sum.null = sum(!is.null(perc), !is.null(abs), !is.null(threshold))
  if (sum.null == 0L)
    stop("At least one of 'perc', 'abs' or 'threshold' must be not NULL")
  if (sum.null >= 2L)
    stop("Arguments 'perc', 'abs' and 'threshold' are mutually exclusive")
  if (!is.null(perc)) {
    assertNumber(perc, lower = 0, upper = 1)
    return("perc")
  }
  if (!is.null(abs)) {
    assertCount(abs)
    return("abs")
  }
  if (!is.null(threshold)) {
    assertNumber(threshold)
    return("threshold")
  }
}
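
# Illustrative sketch only, not part of the package: the selection rule applied
# by filterFeatures(), restated in base R so it can be read without the task
# machinery. The helper name `sketchSelectFeatures` and the toy scores in the
# usage comments are hypothetical. The sketch assumes that exactly one of
# `perc`, `abs` or `threshold` is given (no validation is done here) and it
# ranks NA scores last, which is an assumption of this sketch rather than a
# documented property of sortByCol().
sketchSelectFeatures = function(scores, perc = NULL, abs = NULL, threshold = NULL) {
  # scores: named numeric vector holding one filter value per feature
  p = length(scores)
  nselect = if (!is.null(perc)) {
    round(perc * p)
  } else if (!is.null(abs)) {
    min(abs, p)
  } else {
    sum(scores >= threshold, na.rm = TRUE)
  }
  # Rank by descending score, keep the top nselect, then restore the input order
  ranked = order(scores, decreasing = TRUE, na.last = TRUE)
  names(scores)[sort(head(ranked, nselect))]
}
# Usage with made-up scores:
#   s = c(a = 0.1, b = 0.9, c = 0.5, d = 0.7)
#   sketchSelectFeatures(s, abs = 2)          # "b" "d"
#   sketchSelectFeatures(s, threshold = 0.5)  # "b" "c" "d"
#   sketchSelectFeatures(s, perc = 0.25)      # "b"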