This repository has been archived on 2020-04-08. You can view files and clone it, but cannot push or open issues or pull requests.
Files
DSem1/walky/SlimParticleFilter.h
2018-06-05 16:55:45 +02:00

213 lines
5.6 KiB
C++

#ifndef SLIMPARTICLEFILTER_H
#define SLIMPARTICLEFILTER_H
#include <algorithm>
#include <cstddef>
#include <functional>
#include <random>
#include <vector>
namespace SPF {
/**
 * A single hypothesis held by the particle filter.
 * @tparam State type of the hypothesised state
 */
template <class State>
struct Particle {
    /** the hypothesised state */
    State state;
    /** importance weight attached to this hypothesis */
    double weight;
};
template <typename State, typename Control, typename Observation, bool OMP> class Filter {
using _Particle = Particle<State>;
using FuncInitSingle = std::function<void(_Particle&)>;
using FuncInitAll = std::function<void(std::vector<_Particle>&)>;
using FuncEvalSingle = std::function<double(const _Particle&, const Observation& obs)>;
using FuncTransitionSingle = std::function<void(_Particle&, const Control& ctrl)>;
using FuncTransitionAll = std::function<void(std::vector<_Particle>&, const Control& ctrl)>;
std::vector<Particle<State>> particles;
public:
/** initialize the given number of particles */
void initialize(const int num, FuncInitAll fInit) {
particles.resize(num);
fInit(particles);
}
/** initialize the given number of particles */
void initialize(const int num, FuncInitSingle fInit) {
particles.resize(num);
for (_Particle& p : particles) {
fInit(p);
}
}
const std::vector<_Particle>& getParticles() const {
return particles;
}
double lastNEff = 1;
double nEffThresholdPercent = 0.9;
/** perform resampling -> transition -> evaluation -> estimation */
template <typename Transition>
State update(const Control& control, const Observation& observation, Transition fTrans, FuncEvalSingle fEval) {
// if the number of efficient particles is too low, perform resampling
if (lastNEff < particles.size() * nEffThresholdPercent) { resample(particles); }
//resample(particles);
// perform the transition step
transition(fTrans, control);
// perform the evaluation step
#pragma omp parallel for if(OMP)
for (size_t i = 0; i < particles.size(); ++i) {
particles[i].weight *= fEval(particles[i], observation);
}
// normalize the particle weights and thereby calculate N_eff
lastNEff = normalize();
std::cout << "NEff: " << lastNEff << std::endl;
//std::cout << "normalized. n_eff is " << lastNEff << std::endl;
// estimate the current state
const State est = estimate(particles);
// done
return est;
}
private:
/** all particles at once transition */
void transition(FuncTransitionAll fTrans, const Control& control) {
fTrans(particles, control);
}
/** single particle at once transition */
void transition(FuncTransitionSingle fTrans, const Control& control) {
#pragma omp parallel for if(OMP)
for (size_t i = 0; i < particles.size(); ++i) {
fTrans(particles[i], control);
}
}
State estimate(std::vector<_Particle>& particles) const {
State tmp;
// calculate weighted average
double weightSum = 0;
for (const _Particle& p : particles) {
const double weight = p.weight;// * p.weight * p.weight;
tmp += p.state * weight;
weightSum += weight;
}
_assertTrue( (weightSum == weightSum), "the sum of particle weights is NaN!");
_assertTrue( (weightSum != 0), "the sum of particle weights is null!");
// normalize
tmp /= weightSum;
return tmp;
}
/** normalize the weight of all particles to 1.0 and perform some sanity checks */
double normalize() {
// calculate sum(weights)
//double min1 = 9999999;
double weightSum = 0.0;
for (const _Particle& p : particles) {
weightSum += p.weight;
//if (p.weight < min1) {min1 = p.weight;}
}
// sanity check. always!
if (weightSum != weightSum) {
throw Exception("sum of paticle-weights is NaN");
}
if (weightSum == 0) {
throw Exception("sum of paticle-weights is 0.0");
}
// normalize and calculate N_eff
double sum2 = 0.0;
//double min2 = 9999999;
for (_Particle& p : particles) {
p.weight /= weightSum;
//if (p.weight < min2) {min2 = p.weight;}
sum2 += (p.weight * p.weight);
}
// N_eff
return 1.0 / sum2;
}
std::vector<_Particle> particlesCopy;
std::minstd_rand gen;
void resample(std::vector<_Particle>& particles) {
// compile-time sanity checks
// TODO: this solution requires EXPLICIT overloading which is bad...
//static_assert( HasOperatorAssign<State>::value, "your state needs an assignment operator!" );
const uint32_t cnt = (uint32_t) particles.size();
// equal weight for all particles. sums up to 1.0
const double equalWeight = 1.0 / (double) cnt;
// ensure the copy vector has the same size as the real particle vector
particlesCopy.resize(cnt);
// swap both vectors
particlesCopy.swap(particles);
// calculate cumulative weight
double cumWeight = 0;
for (uint32_t i = 0; i < cnt; ++i) {
cumWeight += particlesCopy[i].weight;
particlesCopy[i].weight = cumWeight;
}
// now draw from the copy vector and fill the original one
// with the resampled particle-set
for (uint32_t i = 0; i < cnt; ++i) {
particles[i] = draw(cumWeight);
particles[i].weight = equalWeight;
}
}
/** draw one particle according to its weight from the copy vector */
const _Particle& draw(const double cumWeight) {
// generate random values between [0:cumWeight]
std::uniform_real_distribution<float> dist(0, cumWeight);
// draw a random value between [0:cumWeight]
const float rand = dist(gen);
// search comparator (cumWeight is ordered -> use binary search)
auto comp = [] (const Particle<State>& s, const float d) {return s.weight < d;};
auto it = std::lower_bound(particlesCopy.begin(), particlesCopy.end(), rand, comp);
return *it;
}
};
}
#endif // SLIMPARTICLEFILTER_H