File: cluster_discrimination.R

package info (click to toggle)
r-cran-parameters 0.24.2-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 3,852 kB
  • sloc: sh: 16; makefile: 2
file content (89 lines) | stat: -rw-r--r-- 3,093 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#' Compute a linear discriminant analysis on classified cluster groups
#'
#' Computes a linear discriminant analysis (LDA) on classified cluster groups
#' and determines the goodness of classification for each cluster group. The
#' discriminant analysis is fitted with `MASS::lda()` using `CV = TRUE`; see
#' `MASS::lda()` for details.
#'
#' @param x A data frame, or an object of class `cluster_analysis` as returned
#'   by [cluster_analysis()]. For a `cluster_analysis` object, the data stored
#'   in its `data` attribute is used.
#' @param cluster_groups Group classification of the cluster analysis, which can
#'   be retrieved from the [cluster_analysis()] function. Must be provided when
#'   `x` is a data frame. If `NULL` and `x` is a `cluster_analysis` object, the
#'   classification is obtained with `stats::predict(x)`.
#' @param ... Other arguments passed on to methods.
#'
#' @return A data frame of class `cluster_discrimination` with the columns
#'   `Group` (the cluster group) and `Accuracy` (the proportion of observations
#'   of that group that were classified correctly). The overall proportion of
#'   correctly classified observations is stored in the attribute
#'   `Overall_Accuracy`.
#'
#' @seealso [`n_clusters()`] to determine the number of clusters to extract,
#' [`cluster_analysis()`] to compute a cluster analysis and
#' [`performance::check_clusterstructure()`] to check suitability of data for
#' clustering.
#'
#' @examplesIf requireNamespace("MASS", quietly = TRUE)
#' # Run a cluster analysis that extracts three clusters
#' clustering <- cluster_analysis(iris[, 1:4], n = 3)
#'
#' # Goodness of group classification
#' cluster_discrimination(clustering)
#' @export
cluster_discrimination <- function(x, cluster_groups = NULL, ...) {
  UseMethod("cluster_discrimination")
}

# Method for `cluster_analysis` objects: takes the data stored in the object's
# `data` attribute and, unless the caller supplies `cluster_groups`, the
# cluster assignments returned by `stats::predict()`, then re-dispatches on
# that data.
#' @export
cluster_discrimination.cluster_analysis <- function(x, cluster_groups = NULL, ...) {
  # use the caller's assignments when given, otherwise predict them from `x`
  groups <- if (is.null(cluster_groups)) stats::predict(x) else cluster_groups

  clustered_data <- attributes(x)$data
  cluster_discrimination(clustered_data, groups, ...)
}


# Default method: runs a cross-validated linear discriminant analysis
# (`MASS::lda()` with `CV = TRUE`) that predicts `cluster_groups` from all
# columns of `x`, and reports how often each group was classified correctly.
#
# Arguments:
#   x              Data with one row per observation.
#   cluster_groups Cluster assignment of each observation; must not be `NULL`.
#   ...            Currently unused by this method.
#
# Returns a data frame of class "cluster_discrimination" with one row per
# group (columns `Group` and `Accuracy`, the latter a proportion) and the
# overall proportion of correct classifications in the attribute
# "Overall_Accuracy".
#' @export
cluster_discrimination.default <- function(x, cluster_groups = NULL, ...) {
  if (is.null(cluster_groups)) {
    insight::format_error("Please provide cluster assignments via `cluster_groups`.")
  }

  if (NROW(x) == length(cluster_groups)) {
    # Observations and their assignments are paired by position, so incomplete
    # cases have to be dropped jointly. Removing missing values from each
    # object on its own would shift the pairing (or produce unequal lengths)
    # whenever the missing values are not in the same positions.
    keep <- stats::complete.cases(x) & !is.na(cluster_groups)
    x <- x[keep, , drop = FALSE]
    cluster_groups <- cluster_groups[keep]
  } else {
    # Lengths differ, e.g. because the assignments were computed on the
    # complete cases of `x` only: drop missing values separately, as before.
    x <- stats::na.omit(x)
    cluster_groups <- stats::na.omit(cluster_groups)
  }

  if (NROW(x) != length(cluster_groups)) {
    insight::format_error(
      "`cluster_groups` must contain one cluster assignment for each (complete) row of `x`."
    )
  }

  # Empty factor levels would add all-zero rows to the classification table
  # that have no counterpart among the observed groups.
  if (is.factor(cluster_groups)) {
    cluster_groups <- droplevels(cluster_groups)
  }

  # compute discriminant analysis of groups on original data frame
  insight::check_if_installed("MASS")
  disc <- MASS::lda(cluster_groups ~ ., data = x, na.action = "na.omit", CV = TRUE)

  # Assess the accuracy of the prediction
  # proportion correct for each category of groups
  classification_table <- table(cluster_groups, disc$class)
  correct <- diag(prop.table(classification_table, 1))

  # total proportion correct
  total_correct <- sum(diag(prop.table(classification_table)))

  # `correct` follows the row order of the table (sorted groups), not the order
  # in which groups first appear in the data. Look up one representative value
  # per table row so that each accuracy is paired with its own group while the
  # original type of `cluster_groups` is kept.
  group_labels <- rownames(classification_table)
  groups <- cluster_groups[match(group_labels, as.character(cluster_groups))]

  out <- data.frame(
    Group = groups,
    Accuracy = correct,
    stringsAsFactors = FALSE
  )

  # Sort by group
  out <- out[order(out$Group), ]

  attr(out, "Overall_Accuracy") <- total_correct
  class(out) <- c("cluster_discrimination", class(out))

  out
}


# Utils -------------------------------------------------------------------


# Print method: shows a coloured headline, the per-group accuracies formatted
# as percentages with two decimals, and the overall accuracy taken from the
# "Overall_Accuracy" attribute. Returns the unmodified input invisibly.
#' @export
print.cluster_discrimination <- function(x, ...) {
  # proportions -> percentage strings such as "97.33%"
  as_percent <- function(proportion) sprintf("%.2f%%", 100 * proportion)

  insight::print_color(
    "# Accuracy of Cluster Group Classification via Linear Discriminant Analysis (LDA)\n\n",
    "blue"
  )

  overall <- as_percent(attributes(x)$Overall_Accuracy)

  # format a copy for display only, so the caller gets back the numeric values
  shown <- x
  shown$Accuracy <- as_percent(shown$Accuracy)
  print.data.frame(shown, row.names = FALSE, ...)

  footer <- sprintf("\nOverall accuracy of classification: %s\n", overall)
  insight::print_color(footer, "yellow")

  invisible(x)
}