1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165
|
#' Cluster Cross-Validation
#'
#' Cluster cross-validation splits the data into V groups of
#' disjoint sets using k-means or hierarchical clustering of some variables.
#' A resample of the analysis data consists of V-1 of the
#' folds/clusters while the assessment set contains the final fold/cluster. In
#' basic cross-validation (i.e. no repeats), the number of resamples
#' is equal to V.
#'
#' @details
#' The variables in the `vars` argument are used for k-means clustering of
#' the data into disjoint sets or for hierarchical clustering of the data.
#' These clusters are used as the folds for cross-validation. Depending on how
#' the data are distributed, there may not be an equal number of points
#' in each fold.
#'
#' Distances between data points are computed once and reused for every
#' repeat. Because [stats::hclust()] is deterministic, repeating with
#' `cluster_function = "hclust"` yields identical folds each time; repeats are
#' only informative for stochastic clustering methods such as k-means.
#'
#' You can optionally provide a custom function to `distance_function`. The
#' function should take a data frame (as created via `data[vars]`) and return
#' a [stats::dist()] object with distances between data points.
#'
#' You can optionally provide a custom function to `cluster_function`. The
#' function must take three arguments:
#' - `dists`, a [stats::dist()] object with distances between data points
#' - `v`, a length-1 numeric for the number of folds to create
#' - `...`, to pass any additional named arguments to your function
#'
#' The function should return a vector of cluster assignments of length
#' `nrow(data)`, with each element of the vector corresponding to the matching
#' row of the data frame.
#'
#' @inheritParams vfold_cv
#' @param vars A vector of bare variable names to use to cluster the data.
#' @param repeats The number of times to repeat the clustered partitioning.
#' @param distance_function Which function should be used for distance
#'   calculations? Either a function or the name of a function as a string.
#'   Defaults to `"dist"`, i.e. [stats::dist()]. You can also provide your own
#'   function; see `Details`.
#' @param cluster_function Which function should be used for clustering?
#'   Options are either `"kmeans"` (to use [stats::kmeans()])
#'   or `"hclust"` (to use [stats::hclust()]). You can also provide your own
#'   function; see `Details`.
#' @param ... Extra arguments passed on to `cluster_function`.
#'
#' @return A tibble with classes `clustering_cv`, `rset`, `tbl_df`, `tbl`, and
#'   `data.frame`. The results include a column for the data split objects and
#'   an identification variable `id`. When `repeats > 1`, `id` identifies the
#'   repeat and an additional variable `id2` identifies the fold within it.
#'
#' @examplesIf rlang::is_installed("modeldata")
#' data(ames, package = "modeldata")
#' clustering_cv(ames, vars = c(Sale_Price, First_Flr_SF, Second_Flr_SF), v = 2)
#'
#' @rdname clustering_cv
#' @export
clustering_cv <- function(data,
                          vars,
                          v = 10,
                          repeats = 1,
                          distance_function = "dist",
                          cluster_function = c("kmeans", "hclust"),
                          ...) {
  check_repeats(repeats)
  if (!rlang::is_function(cluster_function)) {
    cluster_function <- rlang::arg_match(cluster_function)
  }

  vars <- tidyselect::eval_select(rlang::enquo(vars), data = data)
  if (rlang::is_empty(vars)) {
    rlang::abort("`vars` are required and must be variables in `data`.")
  }
  vars <- data[vars]

  # The distances depend only on the selected columns, so compute them once
  # and reuse them for every repeat instead of redoing O(n^2) work each time.
  dists <- rlang::exec(distance_function, vars)

  if (repeats == 1) {
    split_objs <- clustering_splits(
      data = data,
      dists = dists,
      v = v,
      cluster_function = cluster_function,
      ...
    )
  } else {
    repeat_ids <- names0(repeats, "Repeat")
    # Collect each repeat in a preallocated list and bind once at the end,
    # rather than growing a data frame with rbind() on every iteration.
    per_repeat <- vector("list", repeats)
    for (i in seq_len(repeats)) {
      tmp <- clustering_splits(
        data = data,
        dists = dists,
        v = v,
        cluster_function = cluster_function,
        ...
      )
      # The fold label moves to `id2`; `id` now names the repeat.
      tmp$id2 <- tmp$id
      tmp$id <- repeat_ids[i]
      per_repeat[[i]] <- tmp
    }
    split_objs <- do.call(rbind, per_repeat)
  }

  split_objs$splits <- map(split_objs$splits, rm_out)

  ## Save some overall information
  cv_att <- list(
    v = v,
    vars = names(vars),
    repeats = repeats,
    distance_function = distance_function,
    cluster_function = cluster_function
  )

  new_rset(
    splits = split_objs$splits,
    ids = split_objs[, grepl("^id", names(split_objs)), drop = FALSE],
    attrib = cv_att,
    subclass = c("clustering_cv", "rset")
  )
}
# Build one set of cluster-based folds for `clustering_cv()`.
#
# The rows of `data` are grouped into `v` clusters using the precomputed
# distances in `dists`. Each cluster becomes the assessment set of one split,
# and the remaining rows form its analysis set.
#
# Arguments:
#   data             The data frame being resampled.
#   dists            Distances between the rows of `data`, as produced by the
#                    distance function in `clustering_cv()`.
#   v                Number of folds (clusters) to create; validated against
#                    `nrow(data)` with `check_v()`.
#   cluster_function `"kmeans"`, `"hclust"`, or a function taking `dists`,
#                    `v` and `...` that returns one cluster label per row.
#   ...              Passed on to the clustering function.
#
# Returns a tibble with a `splits` list column and an `id` column of fold
# names ("Fold1", "Fold2", ...). Errors if the clustering does not return
# exactly one label per row of `data`.
clustering_splits <- function(data,
                              dists,
                              v = 10,
                              cluster_function = c("kmeans", "hclust"),
                              ...) {
  is_custom <- rlang::is_function(cluster_function)
  if (!is_custom) {
    cluster_function <- rlang::arg_match(cluster_function)
  }

  n <- nrow(data)
  # Report a bad `v` as coming from the user-facing caller.
  check_v(v, n, "rows", call = rlang::caller_env())

  # A scalar decision, so a plain `if` rather than the vectorised `ifelse()`.
  clusterer <- if (is_custom) "custom" else cluster_function

  folds <- switch(
    clusterer,
    "kmeans" = {
      clusters <- stats::kmeans(dists, centers = v, ...)
      clusters$cluster
    },
    "hclust" = {
      clusters <- stats::hclust(dists, ...)
      stats::cutree(clusters, k = v)
    },
    # Only reached for "custom", i.e. when `cluster_function` is a function.
    cluster_function(dists = dists, v = v, ...)
  )

  # A custom clusterer could return the wrong number of labels, which would
  # misalign the fold assignments with the rows of `data`.
  if (length(folds) != n) {
    rlang::abort(paste0(
      "`cluster_function` must return one cluster assignment per row of ",
      "`data` (expected ", n, ", got ", length(folds), ")."
    ))
  }

  indices <- split_unnamed(seq_len(n), folds)
  indices <- lapply(indices, default_complement, n = n)
  split_objs <- purrr::map(
    indices,
    make_splits,
    data = data,
    class = "clustering_split"
  )

  tibble::tibble(
    splits = split_objs,
    id = names0(length(split_objs), "Fold")
  )
}
|