added kde resampling with percentage
This commit is contained in:
136
smc/filtering/resampling/ParticleFilterResamplingKDEPercent.h
Normal file
136
smc/filtering/resampling/ParticleFilterResamplingKDEPercent.h
Normal file
@@ -0,0 +1,136 @@
|
||||
#ifndef PARTICLEFILTERRESAMPLINGKDEPERCENT_H
|
||||
#define PARTICLEFILTERRESAMPLINGKDEPERCENT_H
|
||||
|
||||
|
||||
#include <algorithm>
|
||||
#include <random>
|
||||
|
||||
#include "ParticleFilterResampling.h"
|
||||
#include "../../ParticleAssertions.h"
|
||||
|
||||
#include "../../../math/boxkde/benchmark.h"
|
||||
#include "../../../math/boxkde/DataStructures.h"
|
||||
#include "../../../math/boxkde/Image3D.h"
|
||||
#include "../../../math/boxkde/BoxGaus3D.h"
|
||||
#include "../../../math/boxkde/Grid3D.h"
|
||||
#include "../../../math/distribution/Normal.h"
|
||||
|
||||
#include "../../../navMesh/NavMesh.h"
|
||||
#include "../../../floorplan/v2/FloorplanHelper.h"
|
||||
|
||||
namespace SMC {
|
||||
|
||||
/**
|
||||
* Remove the worst <percent> of particles, then calculate the KDE and reweight the remaining particles accordingly.
|
||||
*/
|
||||
template <typename State, typename Tria>
|
||||
class ParticleFilterResamplingKDEPercent : public ParticleFilterResampling<State> {
|
||||
|
||||
private:
|
||||
|
||||
/** this is a copy of the particle-set to draw from it */
|
||||
std::vector<Particle<State>> particlesCopy;
|
||||
|
||||
/** random number generator */
|
||||
std::minstd_rand gen;
|
||||
|
||||
/** boundingBox for the boxKDE */
|
||||
_BBox3<float> bb;
|
||||
|
||||
/** histogram/grid holding the particles*/
|
||||
std::unique_ptr<Grid3D<float>> grid;
|
||||
|
||||
/** bandwith for KDE */
|
||||
Point3 bandwith;
|
||||
|
||||
/** the current mesh */
|
||||
const NM::NavMesh<Tria>* mesh;
|
||||
|
||||
/** percent to remove **/
|
||||
float percent;
|
||||
|
||||
public:
|
||||
|
||||
/** ctor */
|
||||
ParticleFilterResamplingKDEPercent(const NM::NavMesh<Tria>* mesh, const Point3 gridsize_m, const Point3 bandwith, const float percent) {
|
||||
|
||||
this->mesh = mesh;
|
||||
this->bandwith = bandwith;
|
||||
this->bb = mesh->getBBox();
|
||||
this->bb.grow(10);
|
||||
|
||||
// Create histogram
|
||||
size_t nBinsX = (size_t)((this->bb.getMax().x - this->bb.getMin().x) / gridsize_m.x);
|
||||
size_t nBinsY = (size_t)((this->bb.getMax().y - this->bb.getMin().y) / gridsize_m.y);
|
||||
size_t nBinsZ = (size_t)((this->bb.getMax().z - this->bb.getMin().z) / gridsize_m.z);
|
||||
|
||||
this->grid = std::make_unique<Grid3D<float>>(bb, nBinsX, nBinsY, nBinsZ);
|
||||
|
||||
this->percent = percent;
|
||||
|
||||
gen.seed(1234);
|
||||
}
|
||||
|
||||
void resample(std::vector<Particle<State>>& particles) override {
|
||||
|
||||
// compile-time sanity checks
|
||||
static_assert( HasOperatorPlusEq<State>::value, "your state needs a += operator!" );
|
||||
static_assert( HasOperatorDivEq<State>::value, "your state needs a /= operator!" );
|
||||
static_assert( HasOperatorMul<State>::value, "your state needs a * operator!" );
|
||||
//static_assert( std::is_constructible<State, Point3>::value, "your state needs a constructor with Point3!");
|
||||
//todo: static assert for getx, gety, getz, setposition
|
||||
|
||||
|
||||
// comparator (highest first)
|
||||
static auto comp = [] (const Particle<State>& p1, const Particle<State>& p2) {
|
||||
return p1.weight > p2.weight;
|
||||
};
|
||||
|
||||
const uint32_t cnt = (uint32_t) particles.size();
|
||||
|
||||
// sort particles by weight (highest first)
|
||||
std::sort(particles.begin(), particles.end(), comp);
|
||||
|
||||
// to-be-removed region
|
||||
const int start = particles.size() * (1-percent);
|
||||
const int end = particles.size();
|
||||
std::uniform_int_distribution<int> dist(0, start-1);
|
||||
|
||||
// remove by re-drawing
|
||||
for (uint32_t i = start; i < end; ++i) {
|
||||
const int rnd = dist(gen);
|
||||
particles[i] = particles[rnd];
|
||||
}
|
||||
|
||||
grid->clear();
|
||||
for (Particle<State> p : particles){
|
||||
//grid.add receives position in meter!
|
||||
|
||||
//if weight is to low, remove.
|
||||
if((float) p.weight > 0 ){
|
||||
grid->add(p.state.getX(), p.state.getY(), p.state.getZ(), p.weight);
|
||||
}
|
||||
}
|
||||
|
||||
// init KDE
|
||||
int nFilt = 3;
|
||||
float sigmaX = bandwith.x / grid->binSizeX;
|
||||
float sigmaY = bandwith.y / grid->binSizeY;
|
||||
float sigmaZ = bandwith.z / grid->binSizeZ;
|
||||
|
||||
// process KDE
|
||||
BoxGaus3D<float> boxGaus;
|
||||
boxGaus.approxGaus(grid->image(), Point3(sigmaX, sigmaY, sigmaZ), nFilt);
|
||||
|
||||
|
||||
// reweight the particle using the kde. in theory this should be the same weight.
|
||||
// however, as the kde "smoothes" the pdf, this values are lower.
|
||||
for (Particle<State> p : particles){
|
||||
p.weight = grid->fetch(p.state.getX(), p.state.getY(), p.state.getZ());
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
#endif // PARTICLEFILTERRESAMPLINGKDEPERCENT_H
|
||||
Reference in New Issue
Block a user