File: ParamSet.R

package info (click to toggle)
r-cran-paramhelpers 1.14.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 992 kB
  • sloc: ansic: 102; sh: 13; makefile: 2
file content (175 lines) | stat: -rw-r--r-- 5,759 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
#' @title Construct a parameter set.
#'
#' @description
#' `makeParamSet`: Construct from a bunch of parameters.
#'
#' Multiple sets can be concatenated with `c`.
#'
#' The constructed S3 class is simply a list that contains the element `pars`.
#' `pars` is a list of the passed parameters, named by their ids.
#'
#' If `keys` are provided it will automatically be checked whether all
#' expressions within the provided parameters only contain arguments that are a
#' subset of keys.
#'
#' @param ... ([Param()])\cr
#'   Parameters.
#' @param params (list of [Param()])\cr
#'   List of parameters, alternative way instead of using `...`.
#' @param forbidden (`NULL` | R expression)\cr
#'   States forbidden region of parameter set via an expression. Every setting
#'   which satisfies this expression is considered to be infeasible. This makes
#'   it possible to exclude more complex regions of the parameter space than
#'   through simple constraints or `requires`-conditions (although these should
#'   always be used when possible). If parameters have associated trafos, the
#'   forbidden region must always be specified on the original scale and not the
#'   transformed one. Default is `NULL` which means no forbidden region.
#' @template arg_keys
#' @return [ParamSet()] | `LearnerParamSet`.
#'   If all parameters of the `ParamSet` are learner parameters, the output
#'   will inherit the class `LearnerParamSet`.
#' @aliases ParamSet
#' @export
#' @examples
#' makeParamSet(
#'   makeNumericParam("u", lower = 1),
#'   makeIntegerParam("v", lower = 1, upper = 2),
#'   makeDiscreteParam("w", values = 1:2),
#'   makeLogicalParam("x"),
#'   makeDiscreteVectorParam("y", len = 2, values = c("a", "b"))
#' )
#' makeParamSet(
#'   makeNumericParam("u", lower = expression(ceiling(n))),
#'   makeIntegerParam("v", lower = expression(floor(n)), upper = 2),
#'   keys = c("p", "n")
#' )
#' makeParamSet(
#'   makeNumericParam("min", lower = 0, upper = 0.8),
#'   makeNumericParam("max", lower = 0.2, upper = 1),
#'   forbidden = expression(min > max)
#' )
makeParamSet = function(..., params = NULL, forbidden = NULL, keys = NULL) {

  # parameters may arrive either through ... or through 'params', never both
  pars = list(...)
  if (!is.null(params) && length(pars) > 0) {
    stop("You can only use one of ... or params!")
  }
  if (is.null(params)) {
    assertList(pars, types = "Param")
  } else {
    assertList(params, types = "Param")
    pars = params
  }

  # the list of parameters is named by the parameter ids, which must be unique
  ids = extractSubList(pars, "id")
  if (anyDuplicated(ids) > 0L) {
    stop("All parameters must have unique names!")
  }
  names(pars) = ids
  par.set = makeS3Obj("ParamSet", pars = pars, forbidden = forbidden)

  # an empty set needs neither the learner class nor an expression check
  if (length(pars) == 0L) {
    return(par.set)
  }

  # if every Param is a LearnerParam, the set is treated as a LearnerParamSet
  # and the keys are extended by the default keys from mlr
  # (task, n, p, k and type)
  is.learner.par = vlapply(pars, inherits, what = "LearnerParam")
  if (all(is.learner.par)) {
    par.set = addClasses(par.set, classes = "LearnerParamSet")
    keys = union(keys, c("task", "n", "p", "k", "type"))
  }

  # expressions are only checked against the keys when keys are available
  if (!is.null(keys) && hasExpression(par.set)) {
    checkExpressionFeasibility(par.set = par.set, keys = keys)
  }
  return(par.set)
}

# Collect the print data of every parameter in the set 'x' (one call to
# getParPrintData per parameter) and stack the results with rbind.
# 'trafo', 'used' and 'constr.clip' are handed on unchanged to getParPrintData.
getParSetPrintData = function(x, trafo = TRUE, used = TRUE, constr.clip = 40L) {
  rows = lapply(x$pars, function(par) {
    getParPrintData(par, trafo = trafo, used = used, constr.clip = constr.clip)
  })
  return(do.call(rbind, rows))
}

#' @export
print.ParamSet = function(x, ..., trafo = TRUE, used = TRUE, constr.clip = 40L) {
  # Print either the per-parameter overview or a short notice for an empty
  # set, followed by a hint if a forbidden region is present.
  # Always returns NULL invisibly.
  if (!isEmpty(x)) {
    overview = getParSetPrintData(x, trafo = trafo, used = used, constr.clip = constr.clip)
    print(overview)
  } else {
    print("Empty parameter set.")
  }
  if (hasForbidden(x)) {
    catf("Forbidden region specified.")
  }
  return(invisible(NULL))
}

#' @export
c.ParamSet = function(..., recursive = FALSE) {
  # Concatenate the parameters of all passed sets into one new ParamSet.
  # Only the parameters are forwarded to makeParamSet, so 'forbidden' and
  # 'keys' take the constructor defaults; 'recursive' is not used.
  par.lists = lapply(list(...), function(ps) ps$pars)
  pars = do.call(c, unname(par.lists))
  # Drop the names before calling the constructor: a parameter whose id is
  # 'params' would otherwise be matched to the constructor argument of the
  # same name. Passing an unnamed list is fine, because the constructor
  # names the parameters by their ids again.
  pars = unname(pars)
  return(do.call(makeParamSet, pars))
}

#' Check whether parameter set is empty.
#'
#' @param par.set ([ParamSet()])\cr
#'   Parameter set.
#' @return `logical(1)`.
#' @export
isEmpty = function(par.set) {
  # validate the argument first, then dispatch to the class-specific method
  assertClass(par.set, "ParamSet")
  UseMethod("isEmpty")
}

#' @export
isEmpty.ParamSet = function(par.set) {
  # a set is empty when its 'pars' element holds no parameters
  n.pars = length(par.set$pars)
  return(n.pars == 0)
}

#' `makeNumericParamSet`: Convenience function for numerics.
#'
#' @param id (`character(1)`)\cr
#'   Name of parameter.
#' @param len (`integer(1)`)\cr
#'   Length of vector.
#'   If missing, it is taken from the longest of the bounds that were
#'   actually passed (`lower` and / or `upper`). If neither `len` nor a bound
#'   is passed, an error is raised.
#' @param lower (`numeric`)\cr
#'   Lower bound. A single value is replicated to length `len`.
#'   Default is `-Inf`.
#' @param upper (`numeric`)\cr
#'   Upper bound. A single value is replicated to length `len`.
#'   Default is `Inf`.
#' @param vector (`logical(1)`) \cr
#'   Should a `NumericVectorParam` be used instead of
#'   n `NumericParam` objects?
#'   Default is `TRUE`.
#' @rdname makeParamSet
#' @export
makeNumericParamSet = function(id = "x", len, lower = -Inf, upper = Inf, vector = TRUE) {
  assertString(id)
  if (missing(len)) {
    # derive the length from the bounds the caller supplied; the defaults are
    # scalars and carry no length information, so they are ignored here
    bound.lens = c(
      if (!missing(lower)) length(lower),
      if (!missing(upper)) length(upper)
    )
    if (length(bound.lens) == 0L) {
      stop("You must provide 'len' or at least one of 'lower' and 'upper'!")
    }
    len = max(bound.lens)
  } else {
    len = asInt(len)
  }
  # scalar bounds apply to every component
  if (is.numeric(lower) && length(lower) == 1L) {
    lower = rep(lower, len)
  }
  if (is.numeric(upper) && length(upper) == 1L) {
    upper = rep(upper, len)
  }
  assertNumeric(lower, len = len)
  assertNumeric(upper, len = len)
  assertFlag(vector)
  if (vector) {
    return(makeParamSet(makeNumericVectorParam(id = id, len = len, lower = lower, upper = upper)))
  }
  # one scalar parameter per component, named <id>1, <id>2, ...;
  # seq_len() keeps the loop empty for len == 0 (1:0 would count down)
  pars = lapply(seq_len(len), function(i) {
    makeNumericParam(id = paste0(id, i), lower = lower[i], upper = upper[i])
  })
  return(makeParamSet(params = pars))
}