77 lines
2.0 KiB
C++
77 lines
2.0 KiB
C++
#ifndef KNN_H
|
|
#define KNN_H
|
|
|
|
#include <cstddef>
#include <limits>
#include <vector>

#include "../lib/nanoflann/nanoflann.hpp"
|
|
|
|
/**
|
|
* helper class to extract k-nearest-neighbors
|
|
* from a given input data structure.
|
|
* uses nanoflann
|
|
*
|
|
* usage:
|
|
* KNN<float, Grid<20, T>, T, 3> knn(theGrid);
|
|
* float search[] = {0,0,0};
|
|
* std::vector<T> elems = knn.get(search, 3);
|
|
*/
|
|
template <typename Scalar, typename DataStructure, typename Element, int dim> class KNN {
|
|
|
|
private:
|
|
|
|
/** type-definition for the nanoflann KD-Tree used for searching */
|
|
typedef nanoflann::KDTreeSingleIndexAdaptor<nanoflann::L2_Simple_Adaptor<Scalar, DataStructure>, DataStructure, dim> Tree;
|
|
|
|
/** the maximum depth of the tree */
|
|
static constexpr int maxLeafs = 10;
|
|
|
|
/** the constructed tree used for searching */
|
|
Tree tree;
|
|
|
|
/** the underlying data-structure we want to search within */
|
|
DataStructure& data;
|
|
|
|
public:
|
|
|
|
/** ctor */
|
|
KNN(DataStructure& data) : tree(dim, data, nanoflann::KDTreeSingleIndexAdaptorParams(maxLeafs)), data(data) {
|
|
tree.buildIndex();
|
|
}
|
|
|
|
/** get the k-nearest-neighbors for the given input point */
|
|
std::vector<Element> get(const Scalar* point, const int numNeighbors, const float maxDistSquared = 99999) const {
|
|
|
|
// buffer for to-be-fetched neighbors
|
|
size_t indices[numNeighbors];
|
|
float distances[numNeighbors];
|
|
|
|
// find k-nearest-neighbors
|
|
tree.knnSearch(point, numNeighbors, indices, distances);
|
|
|
|
// construct output
|
|
std::vector<Element> elements;
|
|
for (int i = 0; i < numNeighbors; ++i) {
|
|
if (distances[i] > maxDistSquared) {continue;} // too far away?
|
|
elements.push_back(data[indices[i]]);
|
|
}
|
|
return elements;
|
|
|
|
}
|
|
|
|
/** get the nearest neighbor and its distance */
|
|
void getNearest(const Scalar* point, size_t& idx, float& distSquared) {
|
|
|
|
// find 1-nearest-neighbors
|
|
tree.knnSearch(point, 1, &idx, &distSquared);
|
|
|
|
}
|
|
|
|
void get(const Scalar* point, const int numNeighbors, size_t* indices, float* squaredDist) {
|
|
|
|
// find k-nearest-neighbors
|
|
tree.knnSearch(point, numNeighbors, indices, squaredDist);
|
|
|
|
}
|
|
|
|
};
|
|
|
|
#endif // KNN_H
|