File: t_NearestNeighbour1D_std.py

package info (click to toggle)
openturns 1.24-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 66,204 kB
  • sloc: cpp: 256,662; python: 63,381; ansic: 4,414; javascript: 406; sh: 180; xml: 164; yacc: 123; makefile: 98; lex: 55
file content (43 lines) | stat: -rwxr-xr-x 867 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#! /usr/bin/env python

import openturns as ot
import numpy as np
import os

ot.TESTPREAMBLE()

# Reference points indexed by the tree, plus a plain numpy copy used by the
# brute-force helpers below.
sample = ot.Normal().getSample(200)
sample_np = np.array(sample)

tree = ot.NearestNeighbour1D(sample)
print("tree=", tree)

# Query points. They are drawn after the reference sample so the random
# stream is consumed in the same order.
test = ot.Normal().getSample(100)
test_np = np.array(test)


def nearest_debug(x, points=None):
    """Return the index of the point closest to *x*, found by brute force.

    Serves as the reference answer for ``NearestNeighbour1D.query``.

    Parameters
    ----------
    x : array_like
        Query point, subtracted from every row of *points* (broadcasting).
    points : array_like, optional
        Reference points, one per row (2-D). Defaults to the module-level
        ``sample_np``.

    Returns
    -------
    int
        Row index of the nearest point. On ties the lowest index wins
        (``np.argmin`` semantics).
    """
    # No ``global`` statement needed: the module-level array is only read.
    points = np.asarray(sample_np if points is None else points)
    # Squared Euclidean distance is enough to rank the candidates.
    squared_distances = np.sum(np.square(points - x), axis=1)
    return int(np.argmin(squared_distances))


def nearest_debug_indices(x, points=None):
    """Return all point indices ordered by increasing distance to *x*.

    Brute-force reference for ``NearestNeighbour1D.queryK``: the first
    ``k`` entries of the result are the expected k nearest neighbours.

    Parameters
    ----------
    x : array_like
        Query point, subtracted from every row of *points* (broadcasting).
    points : array_like, optional
        Reference points, one per row (2-D). Defaults to the module-level
        ``sample_np``.

    Returns
    -------
    numpy.ndarray
        Integer indices into *points*, nearest first.
    """
    # No ``global`` statement needed: the module-level array is only read.
    points = np.asarray(sample_np if points is None else points)
    squared_distances = np.sum(np.square(points - x), axis=1)
    # A stable sort breaks ties by lowest index. Element 0 therefore always
    # matches nearest_debug's argmin choice.
    return squared_distances.argsort(kind="stable")


# Check ``query`` (single nearest neighbour) against the brute-force reference.
neighbourIndices = tree.query(test)

neighbourIndices_np = [nearest_debug(x) for x in test]

# Compare as plain Python lists so the check does not depend on how the
# returned index container (presumably ot.Indices) compares with a list.
if list(neighbourIndices) != neighbourIndices_np:
    print("Errors found in query")
    # ``os`` has no ``exit`` function. SystemExit(1) is exactly what
    # sys.exit(1) raises, so the script ends with status 1.
    raise SystemExit(1)

# Check ``queryK`` (k nearest, sorted by distance) for every test point.
numNeighbours = 10
for x in test:
    expected = nearest_debug_indices(x)[:numNeighbours]
    actual = list(tree.queryK(x, numNeighbours, True))
    # array_equal also reports a length mismatch as a failure, instead of
    # relying on an elementwise ``!=`` between differently-shaped operands.
    if not np.array_equal(expected, actual):
        print("Errors found in queryK")
        raise SystemExit(1)