File: KD_Tree.cpp

package info (click to toggle)
lammps 20220106.git7586adbb6a%2Bds1-2
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 348,064 kB
  • sloc: cpp: 831,421; python: 24,896; xml: 14,949; f90: 10,845; ansic: 7,967; sh: 4,226; perl: 4,064; fortran: 2,424; makefile: 1,501; objc: 238; lisp: 163; csh: 16; awk: 14; tcl: 6
file content (169 lines) | stat: -rw-r--r-- 5,567 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
#include "KD_Tree.h"
#include <cassert>

using std::vector;

// Factory: build a KD-tree from a mesh description.
//
// nNodesPerElem - number of nodes read from `conn` for every element.
// nNodes        - number of nodes; columns 0..nNodes-1 of `nodalCoords`
//                 are read.
// nodalCoords   - coordinate matrix accessed as (row, node); rows 0, 1 and
//                 2 are passed to the Node constructor in that order.
// nElems        - number of elements; columns 0..nElems-1 of `conn` are read.
// conn          - connectivity accessed as (local node, element); each entry
//                 is used as an index into the node list built here.
//
// Returns a heap-allocated tree. The node and element lists allocated here
// are handed to the KD_Tree constructor.
KD_Tree *KD_Tree::create_KD_tree(const int nNodesPerElem, const int nNodes,
                                 const DENS_MAT *nodalCoords, const int nElems,
                                 const Array2D<int> &conn) {
  const DENS_MAT &coords = *nodalCoords;

  // One Node per mesh node, stored at the position given by its index.
  vector<Node> *allNodes = new vector<Node>();
  for (int n = 0; n < nNodes; ++n) {
    allNodes->push_back(Node(n, coords(0, n), coords(1, n), coords(2, n)));
  }

  // One Elem per mesh element, pairing the element id with copies of the
  // nodes listed for it in the connectivity table.
  vector<Elem> *allElems = new vector<Elem>();
  for (int e = 0; e < nElems; ++e) {
    vector<Node> elemNodes;
    for (int k = 0; k < nNodesPerElem; ++k) {
      elemNodes.push_back((*allNodes)[conn(k, e)]);
    }
    allElems->push_back(Elem(e, elemNodes));
  }

  return new KD_Tree(allNodes, allElems);
}

// Destroy this tree node: free both sub-trees (each child's destructor
// frees its own data in turn) and then the point and element lists that
// were handed to the constructor. A child that was never created is
// nullptr, and deleting nullptr is a no-op.
KD_Tree::~KD_Tree() {
  // Sub-trees first; they hold their own separately allocated lists.
  delete leftChild_;
  delete rightChild_;

  // Lists stored by the constructor for this node.
  delete sortedPts_;
  delete candElems_;
}

// Build this tree node and, recursively, its sub-trees.
//
// points    - heap-allocated node list; sorted in place with the predicate
//             for the split axis and stored in sortedPts_ (released by the
//             destructor).
// elements  - heap-allocated candidate elements for this node; stored in
//             candElems_ (released by the destructor).
// dimension - split axis selector: 0 uses Node::compareX, 1 uses
//             Node::compareY, any other value uses Node::compareZ.
//
// The split value is the median entry of the sorted point list. Every
// candidate element is assigned to the left side, the right side, or both,
// depending on where its nodes fall relative to the split value. A child is
// built only when it would be non-empty and would exclude at least 4 of this
// node's elements; otherwise this node is a leaf on that side.
//
// An empty point list produces a leaf with no children instead of reading a
// median that does not exist.
KD_Tree::KD_Tree(vector<Node> *points, vector<Elem> *elements,
                 int dimension)
  : candElems_(elements) {
  // Choose the ordering predicate for the current split axis.
  bool (*compare)(Node, Node);
  if (dimension == 0) {
    compare = Node::compareX;
  } else if (dimension == 1) {
    compare = Node::compareY;
  } else {
    compare = Node::compareZ;
  }

  // Order the points with the predicate chosen above.
  sort(points->begin(), points->end(), compare);
  sortedPts_ = points;

  // Without points there is no median to split on: indexing the empty
  // vector would be undefined behaviour, so stop here as a childless leaf.
  if (sortedPts_->empty()) {
    leftChild_ = nullptr;
    rightChild_ = nullptr;
    return;
  }

  // Pick the median point as the root of the tree
  value_ = (*sortedPts_)[sortedPts_->size() / 2];

  // Points and elements that will seed the left and right sub-trees.
  vector<Node> *leftPts    = new vector<Node>;
  vector<Elem> *leftElems  = new vector<Elem>;
  vector<Node> *rightPts   = new vector<Node>;
  vector<Elem> *rightElems = new vector<Elem>;

  for (vector<Elem>::iterator elit = candElems_->begin();
       elit != candElems_->end(); ++elit) {
    // Identify elements that should be kept on either side
    bool foundElemLeft  = false;
    bool foundElemRight = false;
    for (vector<Node>::iterator ndit = elit->second.begin();
         ndit != elit->second.end(); ++ndit) {
      // Node orders before the split value: it belongs to the left side.
      // Each node is stored once per side, so skip it if already present.
      if (compare(*ndit, value_)) {
        if (find(leftPts->begin(), leftPts->end(), *ndit) == leftPts->end()) {
          leftPts->push_back(*ndit);
        }
        foundElemLeft = true;
      }
      // Node orders after the split value: it belongs to the right side.
      // A node for which neither comparison holds is added to neither list.
      if (compare(value_, *ndit)) {
        if (find(rightPts->begin(), rightPts->end(), *ndit) == rightPts->end()) {
          rightPts->push_back(*ndit);
        }
        foundElemRight = true;
      }
    }
    // An element straddling the split value is kept on both sides.
    if (foundElemLeft) leftElems->push_back(*elit);
    if (foundElemRight) rightElems->push_back(*elit);
  }

  const int nextDimension = (dimension + 1) % 3;

  // Create the left child, or nullptr if there's nothing to create: the
  // split must exclude at least 4 elements and leave the child non-empty.
  if (candElems_->size() - leftElems->size() < 4 || leftElems->size() == 0) {
    leftChild_ = nullptr;
    delete leftPts;
    delete leftElems;
  } else {
    leftChild_ = new KD_Tree(leftPts, leftElems, nextDimension);
  }

  // Create the right child under the same rule.
  if (candElems_->size() - rightElems->size() < 4 || rightElems->size() == 0) {
    rightChild_ = nullptr;
    delete rightPts;
    delete rightElems;
  } else {
    rightChild_ = new KD_Tree(rightPts, rightElems, nextDimension);
  }
}

// Descend the tree towards `query` and return the ids of the candidate
// elements stored at the node where the descent stops.
//
// query     - node to locate.
// dimension - axis passed to Node::lessThanInDimension at this level; the
//             recursion advances it cyclically through 0, 1, 2.
//
// At each level the left child is chosen when
// query.lessThanInDimension(value_, dimension) holds, the right child
// otherwise. If the chosen child does not exist, the ids (pair `first`
// members) of this node's candidate elements are returned in stored order.
vector<int> KD_Tree::find_nearest_elements(Node query, int dimension) {
  KD_Tree *next = query.lessThanInDimension(value_, dimension) ? leftChild_
                                                               : rightChild_;

  // Keep descending while there is a sub-tree on the query's side.
  if (next != nullptr) {
    return next->find_nearest_elements(query, (dimension+1) % 3);
  }

  // No sub-tree on that side: this node's candidates are the answer.
  vector<int> ids;
  for (vector<Elem>::iterator it = candElems_->begin();
       it != candElems_->end(); it++) {
    ids.push_back(it->first);
  }
  return ids;
}

// Collect the candidate element ids held `depth` levels below this node.
//
// depth - number of levels to descend; must be >= 0 (asserted).
//
// Returns exactly 2^depth lists:
//  - depth == 0: one list with this node's element ids, sorted ascending.
//  - a child is missing before `depth` is reached: this node's sorted id
//    list followed by 2^depth - 1 empty lists, so the result has the same
//    number of entries as a full sub-tree would produce.
//  - otherwise: the left child's lists for depth - 1 followed by the right
//    child's lists for depth - 1.
vector<vector<int> > KD_Tree::getElemIDs(int depth) {
  assert(depth >= 0);

  vector<vector<int> > result;

  if (depth == 0) {
    vector<int> candElemIDs;
    candElemIDs.reserve(candElems_->size());
    for (vector<Elem>::iterator it = candElems_->begin();
         it != candElems_->end(); ++it) {
      candElemIDs.push_back(it->first);
    }

    sort(candElemIDs.begin(), candElemIDs.end());
    result.push_back(candElemIDs);

  } else if (leftChild_ == nullptr || rightChild_ == nullptr) {
    // Emit this node's ids once, then pad with empty lists up to 2^depth
    // entries. The count is an exact integer power of two, so compute it
    // with a shift rather than a floating-point pow().
    result = getElemIDs(0);
    const size_t numRequested = static_cast<size_t>(1) << depth;
    result.resize(numRequested);

  } else {
    // Both children exist: concatenate their results one level down,
    // left before right.
    result = leftChild_->getElemIDs(depth - 1);
    const vector<vector<int> > rightIDs = rightChild_->getElemIDs(depth - 1);
    result.insert(result.end(), rightIDs.begin(), rightIDs.end());
  }

  return result;
}