#ifndef KNN_H #define KNN_H #include "../lib/nanoflann/nanoflann.hpp" #include "Debug.h" /** * helper class to extract k-nearest-neighbors * from a given input data structure. * uses nanoflann * * usage: * Grid<30, T> theGrid; * KNN, 3, float> knn(theGrid); * std::vector elems = knn.get({0,0,0}, 10); */ template class KNN { private: static constexpr const char* name = "KNN"; /** type-definition for the nanoflann KD-Tree used for searching */ typedef nanoflann::KDTreeSingleIndexAdaptor, DataStructure, dim> Tree; /** the maximum depth of the tree */ static constexpr int maxLeafs = 10; /** the constructed tree used for searching */ Tree tree; /** the underlying data-structure we want to search within */ DataStructure& data; public: /** ctor */ KNN(DataStructure& data) : tree(dim, data, nanoflann::KDTreeSingleIndexAdaptorParams(maxLeafs)), data(data) { Log::add(name, "building kd-tree for " + std::to_string(data.kdtree_get_point_count()) + " elements", false); Log::tick(); tree.buildIndex(); Log::tock(); } /** get the k-nearest-neighbors for the given input point */ template std::vector get(const Scalar* point, const int numNeighbors, const float maxDistSquared = 99999) const { // buffer for to-be-fetched neighbors size_t indices[numNeighbors]; float distances[numNeighbors]; // find k-nearest-neighbors tree.knnSearch(point, numNeighbors, indices, distances); // construct output std::vector elements; for (int i = 0; i < numNeighbors; ++i) { if (distances[i] > maxDistSquared) {continue;} // too far away? elements.push_back(data[indices[i]]); } return elements; } /** get the k-nearest-neighbors for the given input point */ template std::vector get(std::initializer_list point, const int numNeighbors, const float maxDistSquared = 99999) const { return get(point.begin(), numNeighbors, maxDistSquared); } /** get the nearest neighbor and its distance */ void getNearest(const Scalar* point, size_t& idx, float& distSquared) { // find 1-nearest-neighbors tree.knnSearch(point, 1, &idx, &distSquared); } /** get the index of the element nearest to the given point */ size_t getNearestIndex(const Scalar* point) { size_t idx; float distSquared; tree.knnSearch(point, 1, &idx, &distSquared); return idx; } /** get the index of the element nearest to the given point */ size_t getNearestIndex(const std::initializer_list lst) { size_t idx; float distSquared; tree.knnSearch(lst.begin(), 1, &idx, &distSquared); return idx; } /** get the distance to the element nearest to the given point */ float getNearestDistance(const std::initializer_list lst) { size_t idx; float distSquared; tree.knnSearch(lst.begin(), 1, &idx, &distSquared); return std::sqrt(distSquared); } void get(const Scalar* point, const int numNeighbors, size_t* indices, float* squaredDist) { // find k-nearest-neighbors tree.knnSearch(point, numNeighbors, indices, squaredDist); } }; #endif // KNN_H