1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431
|
# NeighborSearch tutorial (k-nearest-neighbors)
Nearest-neighbors search is a common machine learning task. In this setting, we
have a *query* and a *reference* dataset. For each point in the *query*
dataset, we wish to know the `k` points in the *reference* dataset which are
closest to the given query point.
Alternately, if the query and reference datasets are the same, the problem can
be stated more simply: for each point in the dataset, we wish to know the `k`
nearest points to that point.
mlpack provides:
- a simple command-line executable to run nearest-neighbors search (and
furthest-neighbors search)
- a simple C++ interface to perform nearest-neighbors search (and
furthest-neighbors search)
- a generic, extensible, and powerful C++ class (`NeighborSearch`) for complex
usage
## Command-line `mlpack_knn`
The simplest way to perform nearest-neighbors search in mlpack is to use the
`mlpack_knn` executable. *(Note that mlpack also provides bindings to other
languages, so, e.g., the `knn()` function is available in Python and Julia and
has the same options. So, any example here can be readily adapted to another
language that mlpack provides bindings for.)*
The `mlpack_knn` program will perform nearest-neighbors search and place the
resultant neighbors into one file and the resultant distances into another. The
output files are organized such that the first row corresponds to the nearest
neighbors of the first query point, with the first column corresponding to the
nearest neighbor, and so forth.
Below are several examples of simple usage (and the resultant output). The `-v`
option is used so that output is given. Further documentation on each
individual option can be found by typing
```sh
$ mlpack_knn --help
```
### One dataset, 5 nearest neighbors
```sh
$ mlpack_knn -r dataset.csv -n neighbors_out.csv -d distances_out.csv -k 5 -v
[INFO ] Loading 'dataset.csv' as CSV data. Size is 3 x 1000.
[INFO ] Loaded reference data from 'dataset.csv' (3 x 1000).
[INFO ] Building reference tree...
[INFO ] Tree built.
[INFO ] Searching for 5 nearest neighbors with dual-tree kd-tree search...
[INFO ] 18412 node combinations were scored.
[INFO ] 54543 base cases were calculated.
[INFO ] Search complete.
[INFO ] Saving CSV data to 'neighbors_out.csv'.
[INFO ] Saving CSV data to 'distances_out.csv'.
[INFO ]
[INFO ] Execution parameters:
[INFO ] distances_file: distances_out.csv
[INFO ] help: false
[INFO ] info: ""
[INFO ] input_model_file: ""
[INFO ] k: 5
[INFO ] leaf_size: 20
[INFO ] naive: false
[INFO ] neighbors_file: neighbors_out.csv
[INFO ] output_model_file: ""
[INFO ] query_file: ""
[INFO ] random_basis: false
[INFO ] reference_file: dataset.csv
[INFO ] seed: 0
[INFO ] single_mode: false
[INFO ] tree_type: kd
[INFO ] verbose: true
[INFO ] version: false
[INFO ]
[INFO ] Program timers:
[INFO ] computing_neighbors: 0.108968s
[INFO ] loading_data: 0.006495s
[INFO ] saving_data: 0.003843s
[INFO ] total_time: 0.126036s
[INFO ] tree_building: 0.003442s
```
Convenient program timers are given for different parts of the calculation at
the bottom of the output, as well as the parameters the simulation was run with.
Now, if we look at the output files:
```sh
$ head neighbors_out.csv
862,344,224,43,885
703,499,805,639,450
867,472,972,380,601
397,319,277,443,323
840,827,865,38,438
732,876,751,492,616
563,222,569,985,940
361,97,928,437,79
547,695,419,961,716
982,113,689,843,634
$ head distances_out.csv
5.986076164057e-02,7.664920518084e-02,1.116050961847e-01,1.155595474371e-01,1.169810085522e-01
7.532635022982e-02,1.012564715841e-01,1.127846944644e-01,1.209584396720e-01,1.216543647014e-01
7.659571546879e-02,1.014588981948e-01,1.025114621511e-01,1.128082429187e-01,1.131659758673e-01
2.079405647909e-02,4.710724516732e-02,7.597622408419e-02,9.171977778898e-02,1.037033340864e-01
7.082206779700e-02,9.002355499742e-02,1.044181406406e-01,1.093149568834e-01,1.139700558608e-01
5.688056488896e-02,9.478072514474e-02,1.085637706630e-01,1.114177921451e-01,1.139370265105e-01
7.882260880455e-02,9.454474078041e-02,9.724494179950e-02,1.023829575445e-01,1.066927013814e-01
7.005321598247e-02,9.131417221561e-02,9.498248889074e-02,9.897964162308e-02,1.121202216165e-01
5.295654132754e-02,5.509877761894e-02,8.108227366619e-02,9.785461174861e-02,1.043968140367e-01
3.992859920333e-02,4.471418646159e-02,7.346053904990e-02,9.181982339584e-02,9.843075910782e-02
```
So, the nearest neighbor to point 0 is point 862, with a distance of
`5.986076164057e-02`. The second nearest neighbor to point 0 is point 344, with
a distance of `7.664920518084e-02`. The third nearest neighbor to point 5 is
point 751, with a distance of `1.085637706630e-01`.
### Query and reference dataset, 10 nearest neighbors
```sh
$ mlpack_knn -q query_dataset.csv -r reference_dataset.csv \
> -n neighbors_out.csv -d distances_out.csv -k 10 -v
[INFO ] Loading 'reference_dataset.csv' as CSV data. Size is 3 x 1000.
[INFO ] Loaded reference data from 'reference_dataset.csv' (3 x 1000).
[INFO ] Building reference tree...
[INFO ] Tree built.
[INFO ] Loading 'query_dataset.csv' as CSV data. Size is 3 x 50.
[INFO ] Loaded query data from 'query_dataset.csv' (3x50).
[INFO ] Searching for 10 nearest neighbors with dual-tree kd-tree search...
[INFO ] Building query tree...
[INFO ] Tree built.
[INFO ] Search complete.
[INFO ] Saving CSV data to 'neighbors_out.csv'.
[INFO ] Saving CSV data to 'distances_out.csv'.
[INFO ]
[INFO ] Execution parameters:
[INFO ] distances_file: distances_out.csv
[INFO ] help: false
[INFO ] info: ""
[INFO ] input_model_file: ""
[INFO ] k: 10
[INFO ] leaf_size: 20
[INFO ] naive: false
[INFO ] neighbors_file: neighbors_out.csv
[INFO ] output_model_file: ""
[INFO ] query_file: query_dataset.csv
[INFO ] random_basis: false
[INFO ] reference_file: reference_dataset.csv
[INFO ] seed: 0
[INFO ] single_mode: false
[INFO ] tree_type: kd
[INFO ] verbose: true
[INFO ] version: false
[INFO ]
[INFO ] Program timers:
[INFO ] computing_neighbors: 0.022589s
[INFO ] loading_data: 0.003572s
[INFO ] saving_data: 0.000755s
[INFO ] total_time: 0.032197s
[INFO ] tree_building: 0.002590s
```
### One dataset, 3 nearest neighbors, leaf size of 15 points
```sh
$ mlpack_knn -r dataset.csv -n neighbors_out.csv -d distances_out.csv -k 3 -l 15 -v
[INFO ] Loading 'dataset.csv' as CSV data. Size is 3 x 1000.
[INFO ] Loaded reference data from 'dataset.csv' (3 x 1000).
[INFO ] Building reference tree...
[INFO ] Tree built.
[INFO ] Searching for 3 nearest neighbors with dual-tree kd-tree search...
[INFO ] 19692 node combinations were scored.
[INFO ] 36263 base cases were calculated.
[INFO ] Search complete.
[INFO ] Saving CSV data to 'neighbors_out.csv'.
[INFO ] Saving CSV data to 'distances_out.csv'.
[INFO ]
[INFO ] Execution parameters:
[INFO ] distances_file: distances_out.csv
[INFO ] help: false
[INFO ] info: ""
[INFO ] input_model_file: ""
[INFO ] k: 3
[INFO ] leaf_size: 15
[INFO ] naive: false
[INFO ] neighbors_file: neighbors_out.csv
[INFO ] output_model_file: ""
[INFO ] query_file: ""
[INFO ] random_basis: false
[INFO ] reference_file: dataset.csv
[INFO ] seed: 0
[INFO ] single_mode: false
[INFO ] tree_type: kd
[INFO ] verbose: true
[INFO ] version: false
[INFO ]
[INFO ] Program timers:
[INFO ] computing_neighbors: 0.059020s
[INFO ] loading_data: 0.002791s
[INFO ] saving_data: 0.002369s
[INFO ] total_time: 0.069277s
[INFO ] tree_building: 0.002713s
```
Further documentation on options should be found by using the `--help` option.
## The `KNN` class
The `KNN` class is, specifically, a typedef of the more extensible
`NeighborSearch` class, querying for nearest neighbors using the Euclidean
distance.
```c++
using KNN = NeighborSearch<NearestNeighborSort, EuclideanDistance>;
```
Using the `KNN` class is particularly simple; first, the object must be
constructed and given a dataset. Then, the method is run, and two matrices are
returned: one which holds the indices of the nearest neighbors, and one which
holds the distances of the nearest neighbors. These are of the same structure
as the output `--neighbors_file` and `--distances_file` for the command-line
program (see above). A handful of examples of simple usage of the KNN class are
given below.
### 5 nearest neighbors on a single dataset
```c++
#include <mlpack.hpp>
using namespace mlpack;
// Our dataset matrix, which is column-major.
extern arma::mat data;
KNN a(data);
// The matrices we will store output in.
arma::Mat<size_t> resultingNeighbors;
arma::mat resultingDistances;
a.Search(5, resultingNeighbors, resultingDistances);
```
The output of the search is stored in `resultingNeighbors` and
`resultingDistances`.
### 10 nearest neighbors on a query and reference dataset
```c++
#include <mlpack.hpp>
using namespace mlpack;
// Our dataset matrices, which are column-major.
extern arma::mat queryData, referenceData;
KNN a(referenceData);
// The matrices we will store output in.
arma::Mat<size_t> resultingNeighbors;
arma::mat resultingDistances;
a.Search(queryData, 10, resultingNeighbors, resultingDistances);
```
### Naive (exhaustive) search for 6 nearest neighbors on one dataset
This example uses the `O(n^2)` naive search (not the tree-based search).
```c++
#include <mlpack.hpp>
using namespace mlpack;
// Our dataset matrix, which is column-major.
extern arma::mat dataset;
KNN a(dataset, true);
// The matrices we will store output in.
arma::Mat<size_t> resultingNeighbors;
arma::mat resultingDistances;
a.Search(6, resultingNeighbors, resultingDistances);
```
Needless to say, naive search can be very slow...
## The extensible `NeighborSearch` class
The `NeighborSearch` class is very extensible, having the following template
arguments:
```c++
template<
typename SortPolicy = NearestNeighborSort,
typename DistanceType = EuclideanDistance,
typename MatType = arma::mat,
template<typename TreeDistanceType,
typename TreeStatType,
typename TreeMatType> class TreeType = KDTree,
template<typename RuleType> class TraversalType =
TreeType<DistanceType, NeighborSearchStat<SortPolicy>,
MatType>::template DualTreeTraverser>
>
class NeighborSearch;
```
By choosing different components for each of these template classes, a very
arbitrary neighbor searching object can be constructed. Note that each of these
template parameters have defaults, so it is not necessary to specify each one.
### `SortPolicy` policy class
The `SortPolicy` template parameter allows specification of how the
NeighborSearch object will decide which points are to be searched for. The
`NearestNeighborSort` class is a well-documented example. A custom `SortPolicy`
class must implement the same methods which `NearestNeighborSort` does:
```c++
static size_t SortDistance(const arma::vec& list, double newDistance);
static bool IsBetter(const double value, const double ref);
template<typename TreeType>
static double BestNodeToNodeDistance(const TreeType* queryNode,
const TreeType* referenceNode);
template<typename TreeType>
static double BestPointToNodeDistance(const arma::vec& queryPoint,
const TreeType* referenceNode);
static const double WorstDistance();
static const double BestDistance();
```
The `FurthestNeighborSort` class is another implementation, which is used to
create the `KFN` typedef class, which finds the furthest neighbors, as opposed
to the nearest neighbors.
## `DistanceType` policy class
The `DistanceType` policy class allows the neighbor search to take place in any
arbitrary metric space. The `LMetric` class is a good example implementation.
A `DistanceType` class must provide the following functions:
```c++
// Empty constructor is required.
DistanceType();
// Compute the distance between two points.
template<typename VecType>
double Evaluate(const VecType& a, const VecType& b);
```
Internally, the `NeighborSearch` class keeps an instantiated `DistanceType` class
(which can be given in the constructor). This is useful for a distance metric
like the Mahalanobis distance (`MahalanobisDistance`), which must store state
(the covariance matrix). Therefore, you can write a non-static DistanceType
class and use it seamlessly with `NeighborSearch`.
For more information on the `DistanceType` policy, see the [documentation for
`DistanceType`](../developer/distances.md).
### `MatType` policy class
The `MatType` template parameter specifies the type of data matrix used. This
type must implement the same operations as an Armadillo matrix, and so standard
choices are `arma::mat` and `arma::sp_mat`.
### `TreeType` policy class
The NeighborSearch class allows great extensibility in the selection of the type
of tree used for search. This type must follow the typical mlpack TreeType
policy, documented [here](../developer/trees.md).
Typical choices might include `KDTree`, `BallTree`, `StandardCoverTree`,
`RTree`, or `RStarTree`. It is easily possible to make your own tree type for
use with NeighborSearch; consult the [TreeType
documentation](../developer/trees.md) for more details.
An example of using the `NeighborSearch` class with a ball tree is given below.
```c++
// Construct a NeighborSearch object with ball bounds.
NeighborSearch<
NearestNeighborSort,
EuclideanDistance,
arma::mat,
BallTree
> neighborSearch(dataset);
```
### `TraverserType` policy class
The last template parameter the `NeighborSearch` class offers is the
`TraverserType` class. The `TraverserType` class holds the strategy used to
traverse the trees in either single-tree or dual-tree search mode. By default,
it is set to use the default traverser of the given `TreeType` (which is the
member `TreeType::DualTreeTraverser`).
This class must implement the following two methods:
```c++
// Instantiate with a given RuleType.
TraverserType(RuleType& rule);
// Traverse with two trees.
void Traverse(TreeType& queryNode, TreeType& referenceNode);
```
The `RuleType` class provides the following functions for use in the traverser:
```c++
// Evaluate the base case between two points.
double BaseCase(const size_t queryIndex, const size_t referenceIndex);
// Score the two nodes to see if they can be pruned, returning DBL_MAX if so.
double Score(TreeType& queryNode, TreeType& referenceNode);
```
Note also that any traverser given must satisfy the definition of a pruning
dual-tree traversal given in the paper "Tree-independent dual-tree algorithms".
## Further documentation
For further documentation on the NeighborSearch class, consult the comments in
the source code, found in `mlpack/methods/neighbor_search/`.
|