#' Demonstration of k-Nearest Neighbour classification
#'
#' Demonstrate the process of k-Nearest Neighbour classification on the 2D
#' plane.
#'
#' For each row of the test set, the \eqn{k} nearest (in Euclidean distance)
#' training set vectors are found, and the classification is decided by majority
#' vote, with ties broken at random. For a single test sample point, the basic
#' steps are:
#'
#' \enumerate{ \item locate the test point \item compute the distances between
#' the test point and all points in the training set \item find the \eqn{k}
#' shortest distances and the corresponding training set points \item vote for
#' the result, i.e. assign the class that gets the most votes among these
#' \eqn{k} neighbours (see the sketch below) }
#'
#' As there are four steps (hence four animation frames) for each test point,
#' the total number of animation frames is
#' \code{4 * min(nrow(test), ani.options('nmax'))}.
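#'
#' Stripped of the animation, the classification of a single test point
#' \code{x0} (a numeric vector of length 2) against the training matrix
#' \code{train} with labels \code{cl} essentially amounts to the sketch below,
#' where the intermediate names \code{x0}, \code{d} and \code{nb} are purely
#' illustrative:
#'
#' \preformatted{d = sqrt(rowSums((train - matrix(x0, nrow(train), 2, byrow = TRUE))^2))
#' nb = rank(d, ties.method = 'random') <= k  # the k nearest neighbours
#' names(which.max(table(cl[nb])))            # majority vote}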
#'
#' @param train matrix or data frame of training set cases containing only 2
#' columns
#' @param test matrix or data frame of test set cases. A vector will be
#' interpreted as a row vector for a single case. It should also contain only
#' 2 columns. This data set will be \emph{ignored} if \code{interact = TRUE};
#' see \code{interact} below.
#' @param cl factor of true classifications of training set
#' @param k number of neighbours considered.
#' @param interact logical. If \code{TRUE}, the user chooses the test set
#' interactively by clicking on the plot; otherwise the kNN classification is
#' computed for the points given in \code{test}.
#' @param tt.col a vector of length 2 specifying the colors for the training
#' data and test data.
#' @param cl.pch a vector specifying symbols for each class
#' @param dist.lty,dist.col the line type and color to annotate the distances
#' @param knn.col the color to annotate the k-nearest neighbour points using a
#' polygon
#' @param ... additional arguments to create the empty frame for the animation
#' (passed to \code{\link{plot.default}})
#' @return A vector of class labels for the test set.
#' @note There is a special restriction (only two columns) on the training and
#' test data sets, purely for the convenience of drawing a scatterplot. This
#' function is only a rough demonstration; for practical applications, please
#' refer to existing kNN implementations such as \code{\link[class]{knn}} in
#' the \pkg{class} package.
#'
#' If either \code{train} or \code{test} is missing, a random matrix will be
#' generated for it (the same holds for \code{cl}); see the examples below.
#' @author Yihui Xie
#' @seealso \code{\link[class]{knn}}
#' @references Examples at \url{https://yihui.org/animation/example/knn-ani/}
#'
#' Venables, W. N. and Ripley, B. D. (2002) \emph{Modern Applied
#' Statistics with S}. Fourth edition. Springer.
#'
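#' @examples
#' ## A minimal usage sketch, not run automatically; the option values and the
#' ## simulated data below are illustrative assumptions, not fixed defaults.
#' \dontrun{
#' oopt = ani.options(interval = 1, nmax = 10)
#'
#' ## let the function simulate both the training and the test set
#' knn.ani(k = 5)
#'
#' ## supply your own 2-column data sets
#' tr = matrix(c(rnorm(40, mean = -1), rnorm(40, mean = 1)), ncol = 2, byrow = TRUE)
#' te = matrix(rnorm(10, sd = 1.2), ncol = 2)
#' knn.ani(train = tr, test = te, cl = rep(c('A', 'B'), each = 20), k = 3)
#'
#' ## choose the test points interactively with the mouse
#' ## knn.ani(train = tr, cl = rep(c('A', 'B'), each = 20), interact = TRUE)
#'
#' ani.options(oopt)
#' }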
#' @export
knn.ani = function(
train, test, cl, k = 10, interact = FALSE, tt.col = c('blue', 'red'),
cl.pch = seq_along(unique(cl)), dist.lty = 2, dist.col = 'gray', knn.col = 'green', ...
) {
nmax = ani.options('nmax')
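# simulate a two-class Gaussian training set (and its labels) if none is given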
if (missing(train)) {
train = matrix(c(rnorm(40, mean = -1), rnorm(40, mean = 1)), ncol = 2, byrow = TRUE)
cl = rep(c('first class', 'second class'), each = 20)
}
if (missing(test))
test = matrix(rnorm(20, mean = 0, sd = 1.2), ncol = 2)
train <- as.matrix(train)
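# in interactive mode, let the user click up to 'nmax' test points on the plot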
if (interact) {
plot(train, main = 'Choose test set points', pch = unclass(as.factor(cl)),
col = tt.col[1])
lct = locator(n = nmax, type = 'p', pch = '?', col = tt.col[2])
test = cbind(lct$x, lct$y)
}
if (is.null(dim(test)))
dim(test) <- c(1, length(test))
test <- as.matrix(test)
if (any(is.na(train)) || any(is.na(test)) || any(is.na(cl)))
stop('no missing values are allowed')
if (ncol(test) != 2 || ncol(train) != 2)
stop("both 'train' and 'test' must have exactly 2 columns")
ntr <- nrow(train)
if (length(cl) != ntr)
stop("'train' and 'class' have different lengths")
if (ntr < k) {
warning(gettextf('k = %d exceeds number %d of patterns', k, ntr), domain = NA)
k <- ntr
}
if (k < 1)
stop(gettextf('k = %d must be at least 1', k), domain = NA)
nte = nrow(test)
clf = as.factor(cl)
res = NULL
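# pre.plot() draws the background of every frame: the training points, the
# test points already classified (with their predicted symbols), the remaining
# test points as '?', and the legends; 'pf' is evaluated as 'panel.first' so
# annotations passed through it appear beneath the points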
pre.plot = function(j, pf = NULL, i.point = TRUE, ...) {
plot(rbind(train, test), type = 'n', xlab = expression(italic(X)[1]),
ylab = expression(italic(X)[2]), panel.first = pf, ...)
points(train, col = tt.col[1], pch = cl.pch[unclass(clf)])
if (j < nte)
points(test[(j + 1):nte, 1], test[(j + 1):nte, 2], col = tt.col[2], pch = '?')
if (j > 1)
points(test[1:(j - 1), 1], test[1:(j - 1), 2], col = tt.col[2],
pch = cl.pch[unclass(res)], cex = 2)
if (i.point)
points(test[j, 1], test[j, 2], col = tt.col[2], pch = '?', cex = 2)
legend('topleft', legend = levels(clf), pch = cl.pch[seq_along(levels(clf))],
bty = 'n', y.intersp = 1.3)
legend('bottomleft', legend = c('training set', 'test set'),
fill = tt.col, bty = 'n', y.intersp = 1.3)
}
nmax = min(nmax, nrow(test))
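# each test point is processed in four frames: (1) the point itself, (2) the
# distances to all training points, (3) the k nearest neighbours, and (4) the
# predicted class label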
for (i in 1:nmax) {
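# frame 1: mark the current test point with a '?'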
dev.hold()
pre.plot(i, ...)
ani.pause()
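# frame 2: draw all distances; the k smallest Euclidean distances (ties broken
# at random) define the neighbourhood, and the majority vote gives the result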
idx = rank(apply(train, 1, function(x) sqrt(sum((x - test[i, ])^2))),
ties.method = 'random') %in% seq(k)
vote = cl[idx]
res = c(res, factor(names(which.max(table(vote))), levels = levels(clf), labels = levels(clf)))
pre.plot(
i, segments(train[, 1], train[, 2], test[i, 1], test[i, 2], lty = dist.lty, col = dist.col), ...
)
ani.pause()
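# frame 3: highlight the k nearest neighbours (a hatched convex hull for k > 1,
# a single emphasised point for k = 1)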
bd = train[idx, 1:2]
pre.plot(
i, {
segments(train[, 1], train[, 2], test[i, 1], test[i,2], lty = dist.lty, col = dist.col)
if (k > 1) polygon(bd[chull(bd), ], density = 10, col = knn.col) else
points(bd[1], bd[2], col = knn.col, pch = cl.pch[unclass(clf)[idx]], cex = 2, lwd = 2)
}, ...)
ani.pause()
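# frame 4: add the predicted class symbol at the test point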
pre.plot(
i, {
segments(train[, 1], train[, 2], test[i, 1], test[i, 2], lty = dist.lty, col = dist.col)
if (k > 1) polygon(bd[chull(bd), ], density = 10, col = knn.col) else
points(bd[1], bd[2], col = knn.col, pch = cl.pch[unclass(clf)[idx]], cex = 2, lwd = 2)
points(test[i, 1], test[i, 2], col = tt.col[2], pch = cl.pch[unclass(res)[i]], cex = 3, lwd = 2)
}, FALSE, ...)
ani.pause()
}
invisible(levels(clf)[res])
}