SPH
KdTree.inl.h
Go to the documentation of this file.
1 #include "objects/finders/KdTree.h" // including the header just to make syntax highlighting work and code navigation
2 
4 
/// Selects the child slot of a parent inner node that a newly created node is attached to.
///
/// The build code additionally constructs the out-of-range value KdChild(-1) for the root node,
/// which has no parent (it is passed together with ROOT_PARENT_NODE).
enum class KdChild : int {
    LEFT = 0,  ///< the node is stored in InnerNode::left of its parent
    RIGHT = 1, ///< the node is stored in InnerNode::right of its parent
};
9 
10 template <typename TNode, typename TMetric>
13 
14  static_assert(sizeof(LeafNode<TNode>) == sizeof(InnerNode<TNode>), "Sizes of nodes must match");
15 
16  // clean the current tree
17  const Size currentCnt = nodes.size();
18  this->init();
19 
20  for (const auto& i : iterateWithIndex(points)) {
21  entireBox.extend(i.value());
22  idxs.push(i.index());
23  }
24 
25  if (SPH_UNLIKELY(points.empty())) {
26  return;
27  }
28 
29  const Size nodeCnt = max(2 * points.size() / config.leafSize + 1, currentCnt);
30  nodes.resize(nodeCnt);
31 
32  SharedPtr<ITask> rootTask = scheduler.submit([this, &scheduler, points] {
33  this->buildTree(scheduler, ROOT_PARENT_NODE, KdChild(-1), 0, points.size(), entireBox, 0, 0);
34  });
35  rootTask->wait();
36 
37  // shrink nodes to only the constructed ones
38  nodes.resize(nodeCounter);
39 
40  SPH_ASSERT(this->sanityCheck(), this->sanityCheck().error());
41 }
42 
43 template <typename TNode, typename TMetric>
45  const Size parent,
46  const KdChild child,
47  const Size from,
48  const Size to,
49  const Box& box,
50  const Size slidingCnt,
51  const Size depth) {
52 
53  Box box1, box2;
54  Vector boxSize = box.size();
55 
56  // split by the dimension of largest extent
57  Size splitIdx = argMax(boxSize);
58 
59  bool slidingMidpoint = false;
60  bool degeneratedBox = false;
61 
62  if (to - from <= config.leafSize) {
63  // enough points to fit inside one leaf
64  this->addLeaf(parent, child, from, to);
65  return;
66  } else {
67  // check for singularity of dimensions
68  for (Size dim = 0; dim < 3; ++dim) {
69  if (this->isSingular(from, to, splitIdx)) {
70  boxSize[splitIdx] = 0.f;
71  // find new largest dimension
72  splitIdx = argMax(boxSize);
73 
74  if (boxSize == Vector(0._f)) {
75  // too many overlapping points, just split until they fit within a leaf,
76  // the code can handle this case, but it smells with an error ...
77  SPH_ASSERT(false, "Too many overlapping points, something is probably wrong ...");
78  degeneratedBox = true;
79  break;
80  }
81  } else {
82  break;
83  }
84  }
85 
86  // split around center of the box
87  Float splitPosition = box.center()[splitIdx];
88  std::make_signed_t<Size> n1 = from, n2 = to - 1; // use ints for easier for loop ending with 0
89 
90  if (slidingCnt <= 5 && !degeneratedBox) {
91  for (;; std::swap(idxs[n1], idxs[n2])) {
92  for (; n1 < int(to) && this->values[idxs[n1]][splitIdx] <= splitPosition; ++n1)
93  ;
94  for (; n2 >= int(from) && this->values[idxs[n2]][splitIdx] >= splitPosition; --n2)
95  ;
96  if (n1 >= n2) {
97  break;
98  }
99  }
100 
101  if (n1 == int(from)) {
102  Size idx = from;
103  splitPosition = this->values[idxs[from]][splitIdx];
104  for (Size i = from + 1; i < to; ++i) {
105  const Float x1 = this->values[idxs[i]][splitIdx];
106  if (x1 < splitPosition) {
107  idx = i;
108  splitPosition = x1;
109  }
110  }
111  std::swap(idxs[from], idxs[idx]);
112  n1++;
113  slidingMidpoint = true;
114  } else if (n1 == int(to)) {
115  Size idx = from;
116  splitPosition = this->values[idxs[from]][splitIdx];
117  for (Size i = from + 1; i < to; ++i) {
118  const Float x2 = this->values[idxs[i]][splitIdx];
119  if (x2 > splitPosition) {
120  idx = i;
121  splitPosition = x2;
122  }
123  }
124  std::swap(idxs[to - 1], idxs[idx]);
125  n1--;
126  slidingMidpoint = true;
127  }
128 
129  tie(box1, box2) = box.split(splitIdx, splitPosition);
130  } else {
131  n1 = (from + to) >> 1;
132  // do quick select to sort elements around the midpoint
133  Iterator<Size> iter = idxs.begin();
134  if (!degeneratedBox) {
135  std::nth_element(iter + from, iter + n1, iter + to, [this, splitIdx](Size i1, Size i2) {
136  return this->values[i1][splitIdx] < this->values[i2][splitIdx];
137  });
138  }
139 
140  tie(box1, box2) = box.split(splitIdx, this->values[idxs[n1]][splitIdx]);
141  }
142 
143  // sanity check
144  SPH_ASSERT(this->checkBoxes(from, to, n1, box1, box2));
145 
146  // add inner node and connect it to the parent
147  const Size index = this->addInner(parent, child, splitPosition, splitIdx);
148 
149  // recurse to left and right subtree
150  const Size nextSlidingCnt = slidingMidpoint ? slidingCnt + 1 : 0;
151  auto processRightSubTree = [this, &scheduler, index, to, n1, box2, nextSlidingCnt, depth] {
152  this->buildTree(scheduler, index, KdChild::RIGHT, n1, to, box2, nextSlidingCnt, depth + 1);
153  };
154  if (depth < config.maxParallelDepth) {
155  // ad hoc decision - split the build only for few topmost nodes, there is no point in splitting
156  // the work for child node in the bottom, it would only overburden the ThreadPool.
157  scheduler.submit(processRightSubTree);
158  } else {
159  // otherwise simply process both subtrees in the same thread
160  processRightSubTree();
161  }
162  this->buildTree(scheduler, index, KdChild::LEFT, from, n1, box1, nextSlidingCnt, depth + 1);
163  }
164 }
165 
166 template <typename TNode, typename TMetric>
167 void KdTree<TNode, TMetric>::addLeaf(const Size parent, const KdChild child, const Size from, const Size to) {
168  const Size index = nodeCounter++;
169  if (index >= nodes.size()) {
170  // needs more nodes than estimated; allocate up to 2x more than necessary to avoid frequent
171  // reallocations
172  nodesMutex.lock();
173  nodes.resize(max(2 * index, nodes.size()));
174  nodesMutex.unlock();
175  }
176 
177  nodesMutex.lock_shared();
178  auto releaseLock = finally([this] { nodesMutex.unlock_shared(); });
179 
180  LeafNode<TNode>& node = (LeafNode<TNode>&)nodes[index];
181  node.type = KdNode::Type::LEAF;
182  SPH_ASSERT(node.isLeaf());
183 
184 #ifdef SPH_DEBUG
185  node.from = node.to = -1;
186 #endif
187 
188  node.from = from;
189  node.to = to;
190 
191  // find the bounding box of the leaf
192  Box box;
193  for (Size i = from; i < to; ++i) {
194  box.extend(this->values[idxs[i]]);
195  }
196  node.box = box;
197 
198  if (parent == ROOT_PARENT_NODE) {
199  return;
200  }
201  InnerNode<TNode>& parentNode = (InnerNode<TNode>&)nodes[parent];
202  SPH_ASSERT(!parentNode.isLeaf());
203  if (child == KdChild::LEFT) {
204  // left child
205  parentNode.left = index;
206  } else {
207  SPH_ASSERT(child == KdChild::RIGHT);
208  // right child
209  parentNode.right = index;
210  }
211 }
212 
213 template <typename TNode, typename TMetric>
215  const KdChild child,
216  const Float splitPosition,
217  const Size splitIdx) {
218  static_assert(int(KdNode::Type::X) == 0 && int(KdNode::Type::Y) == 1 && int(KdNode::Type::Z) == 2,
219  "Invalid values of KdNode::Type enum");
220 
221  const Size index = nodeCounter++;
222  if (index >= nodes.size()) {
223  // needs more nodes than estimated; allocate up to 2x more than necessary to avoid frequent
224  // reallocations
225  nodesMutex.lock();
226  nodes.resize(max(2 * index, nodes.size()));
227  nodesMutex.unlock();
228  }
229 
230  nodesMutex.lock_shared();
231  auto releaseLock = finally([this] { nodesMutex.unlock_shared(); });
232  InnerNode<TNode>& node = (InnerNode<TNode>&)nodes[index];
233  node.type = KdNode::Type(splitIdx);
234  SPH_ASSERT(!node.isLeaf());
235 
236 #ifdef SPH_DEBUG
237  node.left = node.right = -1;
238  node.box = Box(); // will be computed later
239 #endif
240 
241  node.splitPosition = float(splitPosition);
242 
243  if (parent == ROOT_PARENT_NODE) {
244  // no need to set up parents
245  return index;
246  }
247  InnerNode<TNode>& parentNode = (InnerNode<TNode>&)nodes[parent];
248  if (child == KdChild::LEFT) {
249  // left child
250  SPH_ASSERT(parentNode.left == Size(-1));
251  parentNode.left = index;
252  } else {
253  SPH_ASSERT(child == KdChild::RIGHT);
254  // right child
255  SPH_ASSERT(parentNode.right == Size(-1));
256  parentNode.right = index;
257  }
258 
259  return index;
260 }
261 
262 template <typename TNode, typename TMetric>
264  entireBox = Box();
265  idxs.clear();
266  nodes.clear();
267  nodeCounter = 0;
268 }
269 
270 template <typename TNode, typename TMetric>
271 bool KdTree<TNode, TMetric>::isSingular(const Size from, const Size to, const Size splitIdx) const {
272  for (Size i = from; i < to; ++i) {
273  if (this->values[idxs[i]][splitIdx] != this->values[idxs[to - 1]][splitIdx]) {
274  return false;
275  }
276  }
277  return true;
278 }
279 
280 template <typename TNode, typename TMetric>
282  const Size to,
283  const Size mid,
284  const Box& box1,
285  const Box& box2) const {
286  for (Size i = from; i < to; ++i) {
287  if (i < mid && !box1.contains(this->values[idxs[i]])) {
288  return false;
289  }
290  if (i >= mid && !box2.contains(this->values[idxs[i]])) {
291  return false;
292  }
293  }
294  return true;
295 }
296 
303 
305 
307 };
308 
312 extern thread_local Array<ProcessedNode> nodeStack;
313 
314 template <typename TNode, typename TMetric>
315 template <bool FindAll>
317  const Size index,
318  const Float radius,
319  Array<NeighbourRecord>& neighbours) const {
320 
321  SPH_ASSERT(neighbours.empty());
322  const Float radiusSqr = sqr(radius);
323  const Vector maxDistSqr = sqr(max(Vector(0._f), entireBox.lower() - r0, r0 - entireBox.upper()));
324 
325  // L1 norm
326  const Float l1 = l1Norm(maxDistSqr);
327  ProcessedNode node{ 0, maxDistSqr, l1 };
328 
329  SPH_ASSERT(nodeStack.empty()); // not sure if there can be some nodes from previous search ...
330 
331  TMetric metric;
332  while (node.distanceSqr < radiusSqr) {
333  if (nodes[node.idx].isLeaf()) {
334  // for leaf just add all
335  const LeafNode<TNode>& leaf = (const LeafNode<TNode>&)nodes[node.idx];
336  if (leaf.size() > 0) {
337  const Float leafDistSqr =
338  metric(max(Vector(0._f), leaf.box.lower() - r0, r0 - leaf.box.upper()));
339  if (leafDistSqr < radiusSqr) {
340  // leaf intersects the sphere
341  for (Size i = leaf.from; i < leaf.to; ++i) {
342  const Size actIndex = idxs[i];
343  const Float distSqr = metric(this->values[actIndex] - r0);
344  if (distSqr < radiusSqr && (FindAll || this->rank[actIndex] < this->rank[index])) {
346  neighbours.push(NeighbourRecord{ actIndex, distSqr });
347  }
348  }
349  }
350  }
351  if (nodeStack.empty()) {
352  break;
353  }
354  node = nodeStack.pop();
355  } else {
356  // inner node
357  const InnerNode<TNode>& inner = (InnerNode<TNode>&)nodes[node.idx];
358  const Size splitDimension = Size(inner.type);
359  SPH_ASSERT(splitDimension < 3);
360  const Float splitPosition = inner.splitPosition;
361  if (r0[splitDimension] < splitPosition) {
362  // process left subtree, put right on stack
363  ProcessedNode right = node;
364  node.idx = inner.left;
365 
366  const Float dx = splitPosition - r0[splitDimension];
367  right.distanceSqr += sqr(dx) - right.sizeSqr[splitDimension];
368  right.sizeSqr[splitDimension] = sqr(dx);
369  if (right.distanceSqr < radiusSqr) {
370  const InnerNode<TNode>& next = (const InnerNode<TNode>&)nodes[right.idx];
371  right.idx = next.right;
372  nodeStack.push(right);
373  }
374  } else {
375  // process right subtree, put left on stack
376  ProcessedNode left = node;
377  node.idx = inner.right;
378  const Float dx = splitPosition - r0[splitDimension];
379  left.distanceSqr += sqr(dx) - left.sizeSqr[splitDimension];
380  left.sizeSqr[splitDimension] = sqr(dx);
381  if (left.distanceSqr < radiusSqr) {
382  const InnerNode<TNode>& next = (const InnerNode<TNode>&)nodes[left.idx];
383  left.idx = next.left;
384  nodeStack.push(left);
385  }
386  }
387  }
388  }
389 
390  return neighbours.size();
391 }
392 
393 template <typename TNode, typename TMetric>
395  if (this->values.size() != idxs.size()) {
396  return makeFailed("Number of values does not match the number of indices");
397  }
398 
399  // check bounding box
400  for (const Vector& v : this->values) {
401  if (!entireBox.contains(v)) {
402  return makeFailed("Points are not strictly within the bounding box");
403  }
404  }
405 
406  // check node connectivity
407  Size counter = 0;
408  std::set<Size> indices;
409 
410  Function<Outcome(const Size idx)> countNodes = [this, &indices, &counter, &countNodes](
411  const Size idx) -> Outcome {
412  // count this
413  counter++;
414 
415  // check index validity
416  if (idx >= nodes.size()) {
417  return makeFailed("Invalid index found: ", idx, " (", nodes.size(), ")");
418  }
419 
420  // if inner node, count children
421  if (!nodes[idx].isLeaf()) {
422  const InnerNode<TNode>& inner = (const InnerNode<TNode>&)nodes[idx];
423  return countNodes(inner.left) && countNodes(inner.right);
424  } else {
425  // check that all points fit inside the bounding box of the leaf
426  const LeafNode<TNode>& leaf = (const LeafNode<TNode>&)nodes[idx];
427  if (leaf.to == leaf.from) {
428  // empty leaf?
429  return makeFailed("Empty leaf: ", leaf.to);
430  }
431  for (Size i = leaf.from; i < leaf.to; ++i) {
432  if (!leaf.box.contains(this->values[idxs[i]])) {
433  return makeFailed("Leaf points do not fit inside the bounding box");
434  }
435  if (indices.find(i) != indices.end()) {
436  // child referenced twice?
437  return makeFailed("Index repeated: ", i);
438  }
439  indices.insert(i);
440  }
441  }
442  return SUCCESS;
443  };
444  const Outcome result = countNodes(0);
445  if (!result) {
446  return result;
447  }
448  // we should count exactly nodes.size()
449  if (counter != nodes.size()) {
450  return makeFailed("Unexpected number of nodes: ", counter, " == ", nodes.size());
451  }
452  // each index should have been inserted exactly once
453  Size i = 0;
454  for (Size idx : indices) {
455  // std::set is sorted, so we can check sequentially
456  if (idx != i) {
457  return makeFailed("Invalid index: ", idx, " == ", i);
458  }
459  ++i;
460  }
461  return SUCCESS;
462 }
463 
464 template <IterateDirection Dir, typename TNode, typename TMetric, typename TFunctor>
466  IScheduler& scheduler,
467  const TFunctor& functor,
468  const Size nodeIdx,
469  const Size depthLimit) {
470  TNode& node = tree.getNode(nodeIdx);
471  if (Dir == IterateDirection::TOP_DOWN) {
472  if (node.isLeaf()) {
473  functor(node, nullptr, nullptr);
474  } else {
475  InnerNode<TNode>& inner = reinterpret_cast<InnerNode<TNode>&>(node);
476  if (!functor(inner, &tree.getNode(inner.left), &tree.getNode(inner.right))) {
477  return;
478  }
479  }
480  }
481  SharedPtr<ITask> task;
482  if (!node.isLeaf()) {
483  InnerNode<TNode>& inner = reinterpret_cast<InnerNode<TNode>&>(node);
484 
485  const Size newDepth = depthLimit == 0 ? 0 : depthLimit - 1;
486  auto iterateRightSubtree = [&tree, &scheduler, &functor, &inner, newDepth] {
487  iterateTree<Dir>(tree, scheduler, functor, inner.right, newDepth);
488  };
489  if (newDepth > 0) {
490  task = scheduler.submit(iterateRightSubtree);
491  } else {
492  iterateRightSubtree();
493  }
494  iterateTree<Dir>(tree, scheduler, functor, inner.left, newDepth);
495  }
496  if (task) {
497  task->wait();
498  }
499  if (Dir == IterateDirection::BOTTOM_UP) {
500  if (node.isLeaf()) {
501  functor(node, nullptr, nullptr);
502  } else {
503  InnerNode<TNode>& inner = reinterpret_cast<InnerNode<TNode>&>(node);
504  functor(inner, &tree.getNode(inner.left), &tree.getNode(inner.right));
505  }
506  }
507 }
508 
510 template <IterateDirection Dir, typename TNode, typename TMetric, typename TFunctor>
512  IScheduler& scheduler,
513  const TFunctor& functor,
514  const Size nodeIdx,
515  const Size depthLimit) {
516  // use non-const overload using const_cast, but call the functor with const reference
517  auto actFunctor = [&functor](TNode& node, TNode* left, TNode* right)
518  INL { return functor(asConst(node), left, right); };
519  iterateTree<Dir>(const_cast<KdTree<TNode, TMetric>&>(tree), scheduler, actFunctor, nodeIdx, depthLimit);
520 }
521 
#define SPH_ASSERT(x,...)
Definition: Assert.h:94
NAMESPACE_SPH_BEGIN
Definition: BarnesHut.cpp:13
const float radius
Definition: CurveDialog.cpp:18
uint32_t Size
Integral type used to index arrays (by default).
Definition: Globals.h:16
double Float
Precision used within the code. Use Float instead of float or double where precision is important.
Definition: Globals.h:13
IndexAdapter< TContainer > iterateWithIndex(TContainer &&container)
K-d tree for efficient search of neighbouring particles.
@ BOTTOM_UP
From leaves to root.
@ TOP_DOWN
From root to leaves.
thread_local Array< ProcessedNode > nodeStack
Cached stack to avoid reallocation.
Definition: KdTree.cpp:5
KdChild
Definition: KdTree.inl.h:5
void iterateTree(KdTree< TNode, TMetric > &tree, IScheduler &scheduler, const TFunctor &functor, const Size nodeIdx, const Size depthLimit)
Calls a functor for every node of the K-d tree `tree` in the specified direction.
Definition: KdTree.inl.h:465
#define VERBOSE_LOG
Helper macro, creating.
Definition: Logger.h:242
constexpr INLINE T max(const T &f1, const T &f2)
Definition: MathBasic.h:20
constexpr INLINE T sqr(const T &f) noexcept
Return a squared value.
Definition: MathUtils.h:67
#define SPH_UNLIKELY(x)
Definition: Object.h:50
#define NAMESPACE_SPH_END
Definition: Object.h:12
#define INL
Definition: Object.h:32
BasicOutcome< std::string > Outcome
Alias for string error message.
Definition: Outcome.h:138
const SuccessTag SUCCESS
Global constant for successful outcome.
Definition: Outcome.h:141
INLINE Outcome makeFailed(TArgs &&... args)
Constructs failed object with error message.
Definition: Outcome.h:157
StaticArray< T0 &, sizeof...(TArgs)+1 > tie(T0 &t0, TArgs &... rest)
Creates a static array from a list of l-value references.
Definition: StaticArray.h:281
INLINE const T & asConst(T &ref)
Converts a non-const reference to const one.
Definition: Traits.h:237
INLINE Size argMax(const Vector &v)
Returns the index of the maximum element.
Definition: Vector.h:697
BasicVector< Float > Vector
Definition: Vector.h:539
INLINE Float l1Norm(const Vector &v)
Returns the L1 norm (sum of absolute values) of the vector.
Definition: Vector.h:729
Object providing safe access to continuous memory of data.
Definition: ArrayView.h:17
INLINE bool empty() const
Definition: ArrayView.h:105
INLINE TCounter size() const
Definition: ArrayView.h:101
Generic dynamically allocated resizable storage.
Definition: Array.h:43
INLINE void push(U &&u)
Adds new element to the end of the array, resizing the array if necessary.
Definition: Array.h:306
INLINE TCounter size() const noexcept
Definition: Array.h:193
INLINE bool empty() const noexcept
Definition: Array.h:201
Helper object defining three-dimensional interval (box).
Definition: Box.h:17
INLINE void extend(const Vector &v)
Enlarges the box to contain the vector.
Definition: Box.h:49
INLINE Vector center() const
Returns the center of the box.
Definition: Box.h:112
INLINE bool contains(const Vector &v) const
Checks if the vector lies inside the box.
Definition: Box.h:66
INLINE Vector size() const
Returns box dimensions.
Definition: Box.h:106
INLINE Pair< Box > split(const Size dim, const Float x) const
Splits the box along given coordinate.
Definition: Box.h:144
Interface that allows unified implementation of sequential and parallelized versions of algorithms.
Definition: Scheduler.h:27
virtual SharedPtr< ITask > submit(const Function< void()> &task)=0
Submits a task to be potentially executed asynchronously.
Simple (forward) iterator over continuous array of objects of type T.
Definition: Iterator.h:18
K-d tree, used for hierarchical clustering of particles and accelerated Kn queries.
Definition: KdTree.h:136
INLINE TNode & getNode(const Size nodeIdx)
Returns the node with given index.
Definition: KdTree.h:169
virtual void buildImpl(IScheduler &scheduler, ArrayView< const Vector > points) override
Builds finder from set of vectors.
Definition: KdTree.inl.h:11
Size find(const Vector &pos, const Size index, const Float radius, Array< NeighbourRecord > &neighs) const
Definition: KdTree.inl.h:316
Outcome sanityCheck() const
Performs some checks of KdTree consistency, returns SUCCESS if everything is OK.
Definition: KdTree.inl.h:394
void swap(Sph::Array< T, TCounter > &ar1, Sph::Array< T, TCounter > &ar2)
Definition: Array.h:580
Inner node of K-d tree.
Definition: KdTree.h:42
Size right
Index of right child node.
Definition: KdTree.h:50
Size left
Index of left child node.
Definition: KdTree.h:47
float splitPosition
Position where the selected dimension is split.
Definition: KdTree.h:44
Type
Here X, Y, Z must be 0, 1, 2.
Definition: KdTree.h:26
Leaf (bucket) node of K-d tree.
Definition: KdTree.h:61
Size to
One-past-last index of particles belonging to the leaf.
Definition: KdTree.h:66
INLINE Size size() const
Returns the number of points in the leaf. Can be zero.
Definition: KdTree.h:75
Size from
First index of particles belonging to the leaf.
Definition: KdTree.h:63
Holds information about a neighbour particles.
Object used during traversal.
Definition: KdTree.inl.h:300
Float distanceSqr
Definition: KdTree.inl.h:306
Size idx
Index of the node in the tree's nodes array. We cannot use pointers because the array might get reallocated.
Definition: KdTree.inl.h:302
Vector sizeSqr
Definition: KdTree.inl.h:304