File: KDTree.cpp

package info (click to toggle)
stopt 5.12+dfsg-3
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 8,860 kB
  • sloc: cpp: 70,456; python: 5,950; makefile: 72; sh: 57
file content (201 lines) | stat: -rw-r--r-- 4,788 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
#include <algorithm>
#include <iostream>
#include <iterator>
#include <memory>
#include <utility>
#include <vector>
#include <Eigen/Dense>
#include "KDTree.h"

using namespace std;
using namespace Eigen;


namespace StOpt
{

/// Ordering predicate comparing two (point, index) pairs on a single coordinate.
/// Used to partition points around a median along the splitting coordinate.
class comparePt
{
public:
    size_t m_idx; ///< coordinate (row of the point) that is compared
    explicit comparePt(size_t p_idx): m_idx(p_idx) {}

    /// \return true if coordinate m_idx of a's point is strictly smaller than that of b's point
    /// Const-qualified: the predicate only reads m_idx, and standard algorithms/containers
    /// may invoke it through a const object.
    bool operator()(
        const pair< ArrayXd, size_t > &a,
        const pair< ArrayXd, size_t > &b
    ) const
    {
        return (a.first(m_idx) < b.first(m_idx));
    }
};


inline double dist2(const ArrayXd &a, const ArrayXd &b)
{
    double distc = 0;
    for (int i = 0; i < a.size(); i++)
    {
        double di = a(i) - b(i);
        distc += di * di;
    }
    return distc;
}

inline double dist2(const shared_ptr< KDNode >  &a, const shared_ptr< KDNode > &b)
{
    return dist2(a->getPoint(), b->getPoint());
}


/// Build a KD-tree over a set of points.
/// \param points  one point per column (the number of rows is the dimension); each point is
///                stored in the tree together with its column index
/// If points has no column, the root is left null (empty tree).
KDTree::KDTree(const ArrayXXd &points)
{
    // Single empty node shared by every missing child in the tree.
    m_leaf = make_shared<KDNode>();

    // (point, original column index) pairs; createTree reorders this vector in place.
    vector< pair< ArrayXd, size_t> > vecPoints;
    vecPoints.reserve(static_cast<size_t>(points.cols()));
    for (Index i = 0; i < points.cols(); ++i)
    {
        vecPoints.emplace_back(points.col(i), static_cast<size_t>(i));
    }

    // Split on the first coordinate at the root.
    const size_t level = 0;
    m_root = createTree(vecPoints.begin(), vecPoints.end(), vecPoints.size(), level);
}


/// Recursively build a balanced KD-tree over the range [p_beg, p_end).
///
/// The range is partitioned around its median along coordinate p_level. The median element
/// becomes the node. Elements before it (coordinate not greater) form the left subtree, and
/// elements after it (coordinate not smaller) form the right subtree. The range is reordered
/// in place.
/// \param p_beg       first (point, column index) pair of the range
/// \param p_end       one past the last pair of the range
/// \param p_nbPoints  number of pairs in the range
/// \param p_level     coordinate used to split at this depth
/// \return root of the subtree; a null pointer if the range is empty. Missing children are
///         represented by the shared empty node m_leaf.
shared_ptr<KDNode> KDTree::createTree(const vector<pair< ArrayXd, size_t>>::iterator &p_beg,
                                      const vector<pair< ArrayXd, size_t>>::iterator   &p_end,
                                      const size_t &p_nbPoints,
                                      const size_t &p_level)
{
    if (p_beg == p_end)
    {
        return shared_ptr< KDNode >();  // empty tree
    }
    const size_t dim = p_beg->first.size();

    const auto middle = p_beg + (p_nbPoints / 2);

    // Only the median position matters. nth_element places the median and partitions the
    // other elements around it in linear time, without a full sort at every level.
    // Zero-dimensional points have no coordinate to compare, so they are not partitioned.
    if (p_nbPoints > 1 && dim > 0)
    {
        nth_element(p_beg, middle, p_end, comparePt(p_level));
    }

    const size_t llen = p_nbPoints / 2;
    const size_t rlen = p_nbPoints - llen - 1;
    // Cycle through the coordinates with depth (guarded against division by zero).
    const size_t nextLevel = (dim > 0) ? (p_level + 1) % dim : 0;

    shared_ptr< KDNode > left = m_leaf;
    if (llen > 0 && dim > 0)
    {
        left = createTree(p_beg, middle, llen, nextLevel);
    }
    shared_ptr< KDNode > right = m_leaf;
    if (rlen > 0 && dim > 0)
    {
        right = createTree(middle + 1, p_end, rlen, nextLevel);
    }

    return make_shared< KDNode >(*middle, left, right);
}

/// Recursive nearest-neighbour search in the subtree rooted at p_branch.
///
/// \param p_branch    root of the subtree to search (empty node or null: nothing to search)
/// \param p_pt        query point
/// \param p_level     coordinate on which p_branch splits
/// \param p_best      best node found so far
/// \param p_bestDist  squared distance between p_pt and p_best
/// \return the closest node among p_best and the nodes of the subtree; an empty node when
///         p_branch itself is empty
shared_ptr< KDNode > KDTree::nearest(
    const shared_ptr< KDNode > &p_branch,
    const ArrayXd &p_pt,
    const size_t &p_level,
    const shared_ptr< KDNode > &p_best,
    const double &p_bestDist) const
{
    if (!p_branch || p_branch->isEmpty())
    {
        return make_shared<KDNode>();  // basically, null
    }

    const ArrayXd &branchPt = p_branch->getPoint();
    const size_t dim = branchPt.size();

    shared_ptr< KDNode > bestLoc = p_best;
    double bestDistLoc = p_bestDist;

    // The node itself is a candidate.
    const double d = dist2(branchPt, p_pt);
    if (d < bestDistLoc)
    {
        bestDistLoc = d;
        bestLoc = p_branch;
    }

    // Signed offset of the splitting plane from the query along the split coordinate.
    const double dx = branchPt(p_level) - p_pt(p_level);
    const double dx2 = dx * dx;
    const size_t nextLevel = (p_level + 1) % dim;

    // Search first the side of the plane containing the query point.
    shared_ptr< KDNode > section;
    shared_ptr< KDNode > other;
    if (dx > 0)
    {
        section = p_branch->getLeft();
        other = p_branch->getRight();
    }
    else
    {
        section = p_branch->getRight();
        other = p_branch->getLeft();
    }

    // Keep the node returned by a sub-search if it is strictly closer than the current best.
    const auto keepIfCloser = [&](const shared_ptr< KDNode > &p_found)
    {
        if (!p_found->isEmpty())
        {
            const double dl = dist2(p_found->getPoint(), p_pt);
            if (dl < bestDistLoc)
            {
                bestDistLoc = dl;
                bestLoc = p_found;
            }
        }
    };

    keepIfCloser(nearest(section, p_pt, nextLevel, bestLoc, bestDistLoc));

    // The other side can only hold a closer point if the splitting plane is closer than the
    // current best. This must be tested even when the first side was an empty leaf: otherwise
    // a node with a single child on the far side would never have that child examined.
    if (dx2 < bestDistLoc)
    {
        keepIfCloser(nearest(other, p_pt, nextLevel, bestLoc, bestDistLoc));
    }

    return bestLoc;
}


/// Find the node of the tree closest (Euclidean distance) to a query point.
/// \param p_pt  query point
/// \return the nearest node; an empty node if the tree holds no point
shared_ptr< KDNode > KDTree::nearestNode(const ArrayXd   &p_pt) const
{
    // A tree built from zero points has a null root: report "nothing found" with an empty
    // node, as nearest() does, instead of dereferencing it.
    if (!m_root || m_root->isEmpty())
    {
        return make_shared<KDNode>();
    }

    const size_t level = 0;   // the root splits on the first coordinate
    const double branchDist = dist2(m_root->getPoint(), p_pt);
    return nearest(m_root,      // beginning of tree
                   p_pt,        // point we are querying
                   level,       // start from level 0
                   m_root,      // best is the root
                   branchDist); // best_dist = branch_dist
}

}