141 lines
3.1 KiB
C++
141 lines
3.1 KiB
C++
/*
|
||
* © Copyright 2014 – Urheberrechtshinweis
|
||
* Alle Rechte vorbehalten / All Rights Reserved
|
||
*
|
||
* Programmcode ist urheberrechtlich geschuetzt.
|
||
* Das Urheberrecht liegt, soweit nicht ausdruecklich anders gekennzeichnet, bei Frank Ebner.
|
||
* Keine Verwendung ohne explizite Genehmigung.
|
||
* (vgl. § 106 ff UrhG / § 97 UrhG)
|
||
*/
|
||
|
||
#ifndef DRAWLIST_H
|
||
#define DRAWLIST_H
|
||
|
||
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <random>
#include <utility>
#include <vector>
|
||
|
||
#include "random/RandomGenerator.h"
|
||
#include "../Assertions.h"
|
||
|
||
/**
|
||
* add elements of a certain probability
|
||
* and randomly draw from them
|
||
*/
|
||
template <typename T> class DrawList {
|
||
|
||
/** one entry */
|
||
struct Entry {
|
||
|
||
/** the user element */
|
||
T element;
|
||
|
||
/** the cumulative probability up to this element */
|
||
double cumProbability;
|
||
|
||
/** the element's own probability */
|
||
double probability;
|
||
|
||
/** ctor */
|
||
Entry(T element, const double cumProbability, const double probability) : element(element), cumProbability(cumProbability), probability(probability) {;}
|
||
|
||
/** compare for searches */
|
||
bool operator < (const double val) const {return cumProbability < val;}
|
||
|
||
};
|
||
|
||
private:
|
||
|
||
/** current cumulative probability */
|
||
double cumProbability;
|
||
|
||
/** all contained elements */
|
||
std::vector<Entry> elements;
|
||
|
||
/** the used random number generator */
|
||
Random::RandomGenerator& gen;
|
||
|
||
|
||
private:
|
||
|
||
/** default random generator. fallback */
|
||
Random::RandomGenerator defRndGen;
|
||
|
||
|
||
public:
|
||
|
||
/** ctor with random seed */
|
||
DrawList() : cumProbability(0), gen(defRndGen) {
|
||
;
|
||
}
|
||
|
||
/** ctor with custom seed */
|
||
DrawList(const uint32_t seed) : cumProbability(0), defRndGen(seed), gen(defRndGen) {
|
||
;
|
||
}
|
||
|
||
/** ctor with custom RandomNumberGenerator */
|
||
DrawList(Random::RandomGenerator& gen) : cumProbability(0), gen(gen) {
|
||
;
|
||
}
|
||
|
||
/** change the seed */
|
||
void setSeed(const uint64_t seed) {
|
||
gen.seed(seed);
|
||
}
|
||
|
||
/** reset */
|
||
void reset() {
|
||
cumProbability = 0;
|
||
elements.clear();
|
||
}
|
||
|
||
/** adjust the reserved list size */
|
||
void reserve(const size_t numElements) {
|
||
elements.reserve(numElements);
|
||
}
|
||
|
||
/** add a new user-element and its probability */
|
||
void add(T element, const double probability) {
|
||
Assert::isTrue(probability >= 0, "probability must not be negative!");
|
||
cumProbability += probability;
|
||
elements.push_back(Entry(element, cumProbability, probability));
|
||
}
|
||
|
||
/** get a random element based on its probability */
|
||
T get() {
|
||
double tmp; // ignored
|
||
return get(tmp);
|
||
}
|
||
|
||
/**
|
||
* get a random element based on its probability.
|
||
* the probability of the picked element is returned using the out parameter
|
||
*/
|
||
T get(double& elemProbability) {
|
||
|
||
// generate random number between [0:cumProbability]
|
||
std::uniform_real_distribution<> dist(0, cumProbability);
|
||
|
||
// get a random value
|
||
const double rndVal = dist(gen);
|
||
|
||
// binary search for the matching entry O(log(n))
|
||
const auto tmp = std::lower_bound(elements.begin(), elements.end(), rndVal);
|
||
|
||
// sanity check
|
||
Assert::isFalse(tmp == elements.end(), "draw() did not find a valid element");
|
||
|
||
// done
|
||
elemProbability = (*tmp).probability;
|
||
return (*tmp).element;
|
||
|
||
}
|
||
|
||
/** get the current, cumulative probability */
|
||
double getCumProbability() const {
|
||
return cumProbability;
|
||
}
|
||
|
||
};
|
||
|
||
#endif // DRAWLIST_H
|