Alexandria  2.25.0
SDC-CH common library for the Euclid project
KdTree.icpp
Go to the documentation of this file.
1 /** Copyright © 2021 Université de Genève, LMU Munich - Faculty of Physics, IAP-CNRS/Sorbonne Université
2  *
3  * This library is free software; you can redistribute it and/or modify it under
4  * the terms of the GNU Lesser General Public License as published by the Free
5  * Software Foundation; either version 3.0 of the License, or (at your option)
6  * any later version.
7  *
8  * This library is distributed in the hope that it will be useful, but WITHOUT
9  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
10  * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
11  * details.
12  *
13  * You should have received a copy of the GNU Lesser General Public License
14  * along with this library; if not, write to the Free Software Foundation, Inc.,
15  * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16  */
17 
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>
19 
20 namespace KdTree {
21 
22 template <typename T, typename DistanceMethod>
23 class KdTree<T, DistanceMethod>::Node {
24 public:
25  virtual void findPointsWithinRadius(const T& coord, double radius, std::vector<T>& selection) const = 0;
26  virtual std::size_t countPointsWithinRadius(const T& coord, double radius) const = 0;
27  virtual ~Node() = default;
28 };
29 
30 template <typename T, typename DistanceMethod>
31 class KdTree<T, DistanceMethod>::Leaf : public KdTree::Node {
32 public:
33  explicit Leaf(const std::vector<T>&& data) : m_data(data) {}
34  virtual ~Leaf() = default;
35 
36  void findPointsWithinRadius(const T& coord, double radius, std::vector<T>& selection) const override {
37  selection.reserve(selection.size() + m_data.size());
38  for (auto& entry : m_data) {
39  if (DistanceMethod::isCloserThan(entry, coord, radius)) {
40  selection.emplace_back(entry);
41  }
42  }
43  }
44 
45  std::size_t countPointsWithinRadius(const T& coord, double radius) const override {
46  std::size_t count = 0;
47  for (auto& entry : m_data) {
48  if (DistanceMethod::isCloserThan(entry, coord, radius)) {
49  ++count;
50  }
51  }
52  return count;
53  };
54 
55 private:
56  const std::vector<T> m_data;
57 };
58 
59 template <typename T, typename DistanceMethod>
60 class KdTree<T, DistanceMethod>::Split : public KdTree::Node {
61 public:
62  virtual ~Split() = default;
63  explicit Split(std::size_t dimensionality, std::size_t leaf_size, std::vector<T> data, size_t axis) : m_axis(axis) {
64  std::sort(data.begin(), data.end(),
65  [axis](const T& a, const T& b) -> bool { return Traits::getCoord(a, axis) < Traits::getCoord(b, axis); });
66 
67  double a = Traits::getCoord(data.at(data.size() / 2 - 1), axis);
68  double b = Traits::getCoord(data.at(data.size() / 2), axis);
69 
70  if (a == b) {
71  // avoid a possible rounding issue
72  m_split_value = a;
73  } else {
74  m_split_value = (a + b) / 2.0;
75  }
76 
77  std::vector<T> left(data.begin(), data.begin() + data.size() / 2);
78  std::vector<T> right(data.begin() + data.size() / 2, data.end());
79 
80  if (left.size() > leaf_size) {
81  m_left_child = std::make_shared<Split>(dimensionality, leaf_size, std::move(left), (axis + 1) % dimensionality);
82  } else {
83  m_left_child = std::make_shared<Leaf>(std::move(left));
84  }
85  if (right.size() > leaf_size) {
86  m_right_child = std::make_shared<Split>(dimensionality, leaf_size, std::move(right), (axis + 1) % dimensionality);
87  } else {
88  m_right_child = std::make_shared<Leaf>(std::move(right));
89  }
90  }
91 
92  void findPointsWithinRadius(const T& coord, double radius, std::vector<T>& selection) const override {
93  if (Traits::getCoord(coord, m_axis) + radius < m_split_value) {
94  m_left_child->findPointsWithinRadius(coord, radius, selection);
95  } else if (Traits::getCoord(coord, m_axis) - radius > m_split_value) {
96  m_right_child->findPointsWithinRadius(coord, radius, selection);
97  } else {
98  m_left_child->findPointsWithinRadius(coord, radius, selection);
99  m_right_child->findPointsWithinRadius(coord, radius, selection);
100  }
101  }
102 
103  std::size_t countPointsWithinRadius(const T& coord, double radius) const override {
104  if (Traits::getCoord(coord, m_axis) + radius < m_split_value) {
105  return m_left_child->countPointsWithinRadius(coord, radius);
106  } else if (Traits::getCoord(coord, m_axis) - radius > m_split_value) {
107  return m_right_child->countPointsWithinRadius(coord, radius);
108  } else {
109  return m_left_child->countPointsWithinRadius(coord, radius) +
110  m_right_child->countPointsWithinRadius(coord, radius);
111  }
112  }
113 
114 private:
115  size_t m_axis;
116  double m_split_value;
117 
118  std::shared_ptr<Node> m_left_child;
119  std::shared_ptr<Node> m_right_child;
120 };
121 
122 template <typename T, typename DistanceMethod>
123 KdTree<T, DistanceMethod>::KdTree(const std::vector<T>& data, std::size_t leaf_size) {
124  if (!data.empty()) {
125  m_dimensionality = Traits::getDimensions(data.front());
126  } else {
127  m_dimensionality = 0;
128  }
129 
130  if (data.size() > leaf_size) {
131  m_root = std::make_shared<Split>(m_dimensionality, leaf_size, data, 0);
132  } else {
133  std::vector<T> data_copy(data);
134  m_root = std::make_shared<Leaf>(std::move(data_copy));
135  }
136 }
137 
138 template <typename T, typename DistanceMethod>
139 std::vector<T> KdTree<T, DistanceMethod>::findPointsWithinRadius(const T& coord, double radius) const {
140  std::vector<T> output;
141  m_root->findPointsWithinRadius(coord, radius, output);
142  return output;
143 }
144 
145 template <typename T, typename DistanceMethod>
146 std::size_t KdTree<T, DistanceMethod>::countPointsWithinRadius(const T& coord, double radius) const {
147  return m_root->countPointsWithinRadius(coord, radius);
148 }
149 
150 template <typename T>
151 bool EuclideanDistance<T>::isCloserThan(const T& a, const T& b, double distance) {
152  using Traits = KdTreeTraits<T>;
153  double square_dist = 0.0;
154  const std::size_t dim = Traits::getDimensions(a);
155  for (std::size_t i = 0; i < dim; i++) {
156  double delta = Traits::getCoord(a, i) - Traits::getCoord(b, i);
157  square_dist += delta * delta;
158  }
159  return square_dist < distance * distance;
160 }
161 
162 template <typename T>
163 bool ChebyshevDistance<T>::isCloserThan(const T& a, const T& b, double distance) {
164  using Traits = KdTreeTraits<T>;
165  double max_d = 0.;
166  const std::size_t dim = Traits::getDimensions(a);
167  for (std::size_t i = 0; i < dim; ++i) {
168  double delta = std::abs(Traits::getCoord(a, i) - Traits::getCoord(b, i));
169  if (delta > max_d) {
170  max_d = delta;
171  }
172  }
173  return max_d <= distance;
174 }
175 
176 } // namespace KdTree