File: clustering.R

package info (click to toggle)
r-cran-rsample 1.1.1+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 1,872 kB
  • sloc: sh: 13; makefile: 2
file content (165 lines) | stat: -rw-r--r-- 5,336 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#' Cluster Cross-Validation
#'
#' Cluster cross-validation splits the data into V groups of
#'  disjoint sets using k-means clustering of some variables.
#'  A resample of the analysis data consists of V-1 of the
#'  folds/clusters while the assessment set contains the final fold/cluster. In
#'  basic cross-validation (i.e. no repeats), the number of resamples
#'  is equal to V.
#'
#' @details
#' The variables in the `vars` argument are used for k-means clustering of
#'  the data into disjoint sets or for hierarchical clustering of the data.
#'  These clusters are used as the folds for cross-validation. Depending on how
#'  the data are distributed, there may not be an equal number of points
#'  in each fold.
#'
#' You can optionally provide a custom function to `distance_function`. The
#' function should take a data frame (as created via `data[vars]`) and return
#' a [stats::dist()] object with distances between data points.
#'
#' You can optionally provide a custom function to `cluster_function`. The
#' function must take three arguments:
#' - `dists`, a [stats::dist()] object with distances between data points
#' - `v`, a length-1 numeric for the number of folds to create
#' - `...`, to pass any additional named arguments to your function
#'
#' The function should return a vector of cluster assignments of length
#' `nrow(data)`, with each element of the vector corresponding to the matching
#' row of the data frame.
#'
#' When `repeats > 1`, the distances are computed once and only the clustering
#' step is repeated. Repeats will therefore only differ when the clustering
#' itself is stochastic (for example `"kmeans"`, which picks random initial
#' centers); `"hclust"` produces the same folds in every repeat.
#'
#' @inheritParams vfold_cv
#' @param vars A vector of bare variable names to use to cluster the data.
#' @param repeats The number of times to repeat the clustered partitioning.
#' @param distance_function Which function should be used for distance calculations?
#' Defaults to [stats::dist()]. You can also provide your own
#' function; see `Details`.
#' @param cluster_function Which function should be used for clustering?
#' Options are either `"kmeans"` (to use [stats::kmeans()])
#' or `"hclust"` (to use [stats::hclust()]). You can also provide your own
#' function; see `Details`.
#' @param ... Extra arguments passed on to `cluster_function`.
#'
#' @return A tibble with classes `rset`, `tbl_df`, `tbl`, and `data.frame`.
#'  The results include a column for the data split objects and
#'  an identification variable `id` (plus `id2` for the fold when
#'  `repeats > 1`, in which case `id` identifies the repeat).
#'
#' @examplesIf rlang::is_installed("modeldata")
#' data(ames, package = "modeldata")
#' clustering_cv(ames, vars = c(Sale_Price, First_Flr_SF, Second_Flr_SF), v = 2)
#'
#' @rdname clustering_cv
#' @export
clustering_cv <- function(data,
                          vars,
                          v = 10,
                          repeats = 1,
                          distance_function = "dist",
                          cluster_function = c("kmeans", "hclust"),
                          ...) {
  check_repeats(repeats)

  if (!rlang::is_function(cluster_function)) {
    cluster_function <- rlang::arg_match(cluster_function)
  }

  vars <- tidyselect::eval_select(rlang::enquo(vars), data = data)
  if (rlang::is_empty(vars)) {
    rlang::abort("`vars` are required and must be variables in `data`.")
  }
  vars <- data[vars]

  # The distances depend only on `vars`, so compute them once and reuse them
  # for every repeat; only the clustering step is re-run per repeat.
  dists <- rlang::exec(distance_function, vars)

  if (repeats == 1) {
    split_objs <- clustering_splits(
      data = data,
      dists = dists,
      v = v,
      cluster_function = cluster_function,
      ...
    )
  } else {
    repeat_ids <- names0(repeats, "Repeat")
    # Collect one tibble per repeat and bind once at the end, rather than
    # growing the result with rbind() on every iteration.
    per_repeat <- vector("list", repeats)
    for (i in seq_len(repeats)) {
      tmp <- clustering_splits(
        data = data,
        dists = dists,
        v = v,
        cluster_function = cluster_function,
        ...
      )
      # Fold label moves to `id2`; `id` becomes the repeat label.
      tmp$id2 <- tmp$id
      tmp$id <- repeat_ids[i]
      per_repeat[[i]] <- tmp
    }
    split_objs <- do.call(rbind, per_repeat)
  }

  split_objs$splits <- map(split_objs$splits, rm_out)

  ## Save some overall information

  cv_att <- list(
    v = v,
    vars = names(vars),
    repeats = repeats,
    distance_function = distance_function,
    cluster_function = cluster_function
  )

  new_rset(
    splits = split_objs$splits,
    ids = split_objs[, grepl("^id", names(split_objs))],
    attrib = cv_att,
    subclass = c("clustering_cv", "rset")
  )
}

# Build one set of cluster-based CV splits.
#
# Clusters the rows of `data` using the precomputed distances `dists`. Each
# cluster becomes the assessment set of one split, and the remaining rows form
# the analysis set.
#
# @param data The data frame being resampled.
# @param dists A distances object passed to the clustering step.
# @param v Number of folds/clusters to request.
# @param cluster_function `"kmeans"`, `"hclust"`, or a custom function called
#   as `cluster_function(dists = dists, v = v, ...)`.
# @param ... Passed on to the clustering function.
# @return A tibble with a `splits` list column and a `Fold`-style `id` column.
clustering_splits <- function(data,
                              dists,
                              v = 10,
                              cluster_function = c("kmeans", "hclust"),
                              ...) {
  if (!rlang::is_function(cluster_function)) {
    cluster_function <- rlang::arg_match(cluster_function)
  }

  check_v(v, nrow(data), "rows", call = rlang::caller_env())
  n <- nrow(data)

  # Scalar dispatch key: built-in method name, or "custom" for a user function.
  clusterer <- if (rlang::is_function(cluster_function)) {
    "custom"
  } else {
    cluster_function
  }
  folds <- switch(
    clusterer,
    "kmeans" = {
      clusters <- stats::kmeans(dists, centers = v, ...)
      clusters$cluster
    },
    "hclust" = {
      clusters <- stats::hclust(dists, ...)
      stats::cutree(clusters, k = v)
    },
    do.call(cluster_function, list(dists = dists, v = v, ...))
  )

  # A wrong-length assignment vector would be silently recycled by split().
  # NA assignments would be dropped by split(), leaving those rows in every
  # analysis set. Reject both explicitly.
  if (length(folds) != n || anyNA(folds)) {
    rlang::abort(
      paste0(
        "`cluster_function` must return a non-missing cluster assignment ",
        "for each of the ", n, " rows of `data`."
      )
    )
  }

  idx <- seq_len(n)
  indices <- split_unnamed(idx, folds)
  indices <- lapply(indices, default_complement, n = n)

  split_objs <- purrr::map(
    indices,
    make_splits,
    data = data,
    class = "clustering_split"
  )
  tibble::tibble(
    splits = split_objs,
    id = names0(length(split_objs), "Fold")
  )
}