File: knn.ani.R

package info (click to toggle)
r-cran-animation 2.7%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm, forky, sid, trixie
  • size: 1,268 kB
  • sloc: javascript: 873; sh: 15; makefile: 2
file content (139 lines) | stat: -rw-r--r-- 6,315 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#' Demonstration of the k-Nearest Neighbour classification
#'
#' Demonstrate the process of k-Nearest Neighbour classification on the 2D
#' plane.
#'
#' For each row of the test set, the \eqn{k} nearest (in Euclidean distance)
#' training set vectors are found, and the classification is decided by majority
#' vote, with ties broken at random. For a single test sample point, the basic
#' steps are:
#'
#' \enumerate{ \item locate the test point \item compute the distances between
#' the test point and all points in the training set \item find \eqn{k} shortest
#' distances and the corresponding training set points \item vote for the result
#' (find the maximum in the table for the true classifications) }
#'
#' As there are four steps in each iteration, the total number of animation
#' frames is \code{4 * min(nrow(test), ani.options('nmax'))}.
#'
#' @param train matrix or data frame of training set cases containing only 2
#'   columns
#' @param test matrix or data frame of test set cases. A vector will be
#'   interpreted as a row vector for a single case. It should also contain only
#'   2 columns. This data set will be \emph{ignored} if \code{interact = TRUE};
#'   see \code{interact} below.
#' @param cl factor of true classifications of training set
#' @param k number of neighbours considered.
#' @param interact logical. If \code{TRUE}, the user chooses the test set
#'   interactively by clicking on the plot with the mouse; otherwise the kNN
#'   classification is computed for the argument \code{test}.
#' @param tt.col a vector of length 2 specifying the colors for the training
#'   data and test data.
#' @param cl.pch a vector specifying symbols for each class
#' @param dist.lty,dist.col the line type and color to annotate the distances
#' @param knn.col the color to annotate the k-nearest neighbour points using a
#'   polygon
#' @param ... additional arguments to create the empty frame for the animation
#'   (passed to \code{\link{plot.default}})
#' @return A vector of class labels for the test set.
#' @note There is a special restriction (only two columns) on the training and
#'   test data sets, purely for the convenience of drawing a scatterplot.
#'   This is only a rough demonstration; for practical applications, please
#'   refer to existing kNN functions such as \code{\link[class]{knn}} in
#'   \pkg{class}, etc.
#'
#'   If \code{train} is missing, a random training matrix is generated
#'   together with matching class labels \code{cl}; if \code{test} is missing,
#'   a random test matrix is generated.
#' @author Yihui Xie
#' @seealso \code{\link[class]{knn}}
#' @references Examples at \url{https://yihui.org/animation/example/knn-ani/}
#' 
#'   Venables, W. N. and Ripley, B. D. (2002) \emph{Modern Applied
#'   Statistics with S}. Fourth edition. Springer.
#'
#' @export
knn.ani = function(
  train, test, cl, k = 10, interact = FALSE, tt.col = c('blue', 'red'),
  cl.pch = seq_along(unique(cl)), dist.lty = 2, dist.col = 'gray', knn.col = 'green', ...
) {
  # upper bound on the number of test points that will be animated
  nmax = ani.options('nmax')
  # without training data, simulate two Gaussian clusters of 20 points each
  # (the first 20 rows around -1, the last 20 rows around 1) and label them
  if (missing(train)) {
    train = matrix(c(rnorm(40, mean = -1), rnorm(40, mean = 1)), ncol = 2, byrow = TRUE)
    cl = rep(c('first class', 'second class'), each = 20)
  }
  if (missing(test))
    test = matrix(rnorm(20, mean = 0, sd = 1.2), ncol = 2)
  train = as.matrix(train)
  if (interact) {
    # let the user pick the test points with the mouse; this replaces 'test'
    plot(train, main = 'Choose test set points', pch = unclass(as.factor(cl)),
         col = tt.col[1])
    lct = locator(n = nmax, type = 'p', pch = '?', col = tt.col[2])
    # no selection gives NULL; fall back to an empty two-column test set so
    # that the checks below still work and nothing is animated
    test = if (is.null(lct)) matrix(numeric(0), ncol = 2) else cbind(lct$x, lct$y)
  }
  # a bare vector is interpreted as a single test case (one row)
  if (is.null(dim(test)))
    dim(test) = c(1, length(test))
  test = as.matrix(test)
  if (anyNA(train) || anyNA(test) || anyNA(cl))
    stop('no missing values are allowed')
  # scalar conditions: use the short-circuit operator
  if (ncol(test) != 2 || ncol(train) != 2)
    stop("both column numbers of 'train' and 'test' must be 2!")
  ntr = nrow(train)
  if (length(cl) != ntr)
    stop("'train' and 'class' have different lengths")
  if (ntr < k) {
    warning(gettextf('k = %d exceeds number %d of patterns', k, ntr), domain = NA)
    k = ntr
  }
  if (k < 1)
    stop(gettextf('k = %d must be at least 1', k), domain = NA)
  nte = nrow(test)
  clf = as.factor(cl)
  # animate at most 'nmax' test points
  nmax = min(nmax, nte)
  # predicted classes, stored as integer codes into levels(clf); preallocated
  # instead of being grown inside the loop
  res = integer(nmax)

  # Draw the common background of a frame for the j-th test point:
  #   pf      - expression handed to plot() as 'panel.first'; it is evaluated
  #             lazily, after the axes are set up and before the points
  #   i.point - whether to mark the current test point with a large '?'
  #   ...     - passed on to plot()
  pre.plot = function(j, pf = NULL, i.point = TRUE, ...) {
    plot(rbind(train, test), type = 'n', xlab = expression(italic(X)[1]),
         ylab = expression(italic(X)[2]), panel.first = pf, ...)
    points(train, col = tt.col[1], pch = cl.pch[unclass(clf)])
    # test points not reached yet
    if (j < nte) {
      todo = (j + 1):nte
      points(test[todo, 1], test[todo, 2], col = tt.col[2], pch = '?')
    }
    # test points classified in earlier iterations, drawn with their class symbol
    if (j > 1) {
      done = seq_len(j - 1)
      points(test[done, 1], test[done, 2], col = tt.col[2],
             pch = cl.pch[res[done]], cex = 2)
    }
    if (i.point)
      points(test[j, 1], test[j, 2], col = tt.col[2], pch = '?', cex = 2)
    legend('topleft', legend = levels(clf), pch = cl.pch[seq_along(levels(clf))],
           bty = 'n', y.intersp = 1.3)
    legend('bottomleft', legend = c('training set', 'test set'),
           fill = tt.col, bty = 'n', y.intersp = 1.3)
  }

  # connect every training point with the i-th test point
  draw.dist = function(i) {
    segments(train[, 1], train[, 2], test[i, 1], test[i, 2], lty = dist.lty, col = dist.col)
  }
  # highlight the k nearest neighbours 'bd' (rows of 'train' selected by 'idx'):
  # their convex hull when k > 1, otherwise the single neighbour itself
  draw.knn = function(bd, idx) {
    if (k > 1) {
      polygon(bd[chull(bd), , drop = FALSE], density = 10, col = knn.col)
    } else {
      points(bd, col = knn.col, pch = cl.pch[unclass(clf)[idx]], cex = 2, lwd = 2)
    }
  }

  # seq_len() makes the loop a no-op when there is nothing to animate
  for (i in seq_len(nmax)) {
    dev.hold()
    # frame 1: locate the current test point
    pre.plot(i, ...)
    ani.pause()
    # Euclidean distances from the current test point to all training points
    d = sqrt(rowSums(sweep(train, 2, test[i, ])^2))
    # the k smallest distances; equal distances are ranked at random
    idx = rank(d, ties.method = 'random') %in% seq_len(k)
    # majority vote among the neighbours; which.max() takes the first maximum,
    # so a tied vote goes to the class that comes first in the table
    vote = cl[idx]
    res[i] = match(names(which.max(table(vote))), levels(clf))
    # frame 2: show the distances
    pre.plot(i, draw.dist(i), ...)
    ani.pause()
    bd = train[idx, , drop = FALSE]
    # frame 3: show the k nearest neighbours
    pre.plot(
      i, {
        draw.dist(i)
        draw.knn(bd, idx)
      }, ...)
    ani.pause()
    # frame 4: replace the '?' by the symbol of the predicted class
    pre.plot(
      i, {
        draw.dist(i)
        draw.knn(bd, idx)
        points(test[i, 1], test[i, 2], col = tt.col[2], pch = cl.pch[res[i]], cex = 3, lwd = 2)
      }, FALSE, ...)
    ani.pause()
  }
  # translate the integer codes back to class labels
  invisible(levels(clf)[res])
}