Files
BeatDetector/lib/BeatDetection2.h

245 lines
5.4 KiB
C++
Executable File

#ifndef BEATDETECTION2_H
#define BEATDETECTION2_H
#include "BiquadFilterGate.h"
#include "MovingAVG.h"
#include "EndlessAVG.h"
#include <cmath>
#include <iostream>
//#define PLOT_ME
#ifdef PLOT_ME
#include <KLib/misc/gnuplot/Gnuplot.h>
#include <KLib/misc/gnuplot/GnuplotPlot.h>
#include <KLib/misc/gnuplot/GnuplotPlotElementLines.h>
#include <KLib/misc/gnuplot/GnuplotSplot.h>
#endif
#define BD2_SHORT 1024
/**
 * One analysis band: a 2-channel filter followed by several running averages
 * of the filtered signal's energy.
 *
 * - avgShort / avgLong:          short- and long-window means of the energy
 * - avgLongSquared:              long-window mean of energy^2 (for getVariance())
 * - avgDiff / avgDiffSquared:    mean and mean-square of (avgShort - avgLong)
 *                                (for getDiffVariance())
 *
 * NOTE(review): window sizes are the values passed to the MovingAVG
 * constructors, presumably counted in samples — confirm against MovingAVG.
 */
template <typename Scalar> struct BeatBand2 {
	BiquadFilterGate<2> filter;
	MovingAVG<float> avgShort;
	MovingAVG<float> avgLong;
	MovingAVG<float> avgLongSquared;
	MovingAVG<float> avgDiff;
	MovingAVG<float> avgDiffSquared;

	BeatBand2() : avgShort(BD2_SHORT), avgLong(1024*64), avgLongSquared(1024*64), avgDiff(1024*80), avgDiffSquared(1024*80) {
		;
	}

	/** filter one (left, right) sample pair and update all running averages */
	void add(Scalar left, Scalar right) {
		// filter each channel (index 0 = left, 1 = right) and rectify
		const Scalar fLeft = std::abs(filter.filter(0, left));
		const Scalar fRight = std::abs(filter.filter(1, right));
		// cubing weights large magnitudes more heavily than |x| or x^2 would
		const Scalar energy = fLeft*fLeft*fLeft + fRight*fRight*fRight;
		avgShort.add(energy);
		avgLong.add(energy);
		avgLongSquared.add(energy*energy);
		// how far the short-term level currently sits above/below the long-term level
		const float diff = avgShort.get() - avgLong.get();
		avgDiff.add(diff);
		avgDiffSquared.add(diff*diff);
	}

	/** avgShort / avgLong; inf or NaN while avgLong is 0 */
	float getRatio() const {
		return avgShort.get() / avgLong.get();
	}

	/** long-window variance of the energy, E[e^2] - E[e]^2, never negative */
	float getVariance() const {
		const float var = avgLongSquared.get() - (avgLong.get() * avgLong.get());
		// the naive formula can come out slightly negative through float
		// cancellation; clamp so getStdDev() never takes sqrt of a negative (NaN)
		return (var > 0.0f) ? var : 0.0f;
	}

	/** square root of getVariance() */
	float getStdDev() const {
		return std::sqrt(getVariance());
	}

	/** variance of (avgShort - avgLong), E[d^2] - E[d]^2, never negative */
	float getDiffVariance() const {
		const float var = avgDiffSquared.get() - (avgDiff.get() * avgDiff.get());
		// same cancellation guard as getVariance(); a NaN here would make the
		// detector's threshold NaN and suppress every beat
		return (var > 0.0f) ? var : 0.0f;
	}

	/** square root of getDiffVariance() */
	float getDiffStdDev() const {
		return std::sqrt(getDiffVariance());
	}
};
/**
 * Preset selector for BeatDetection2: chooses which filter configuration and
 * threshold multiplier the detector uses.
 */
enum class Mode {
	BASS = 0,
	SNARE = 1,
};
/**
 * Single-band beat detector built on one BeatBand2.
 *
 * Every call to add() feeds one stereo sample into the band. Once every
 * BD2_SHORT calls, one "analysis block" is evaluated:
 *     threshold = avgLong + thresholdMul * band0.getDiffStdDev()
 * and the block counts as a beat when avgShort exceeds that threshold.
 * A reported beat arms a refractory counter (`block`) that suppresses further
 * reports for a while (see add()).
 *
 * The filter preset and thresholdMul depend on the Mode given at construction.
 */
template <typename Scalar> class BeatDetection2 {
private:
#ifdef PLOT_ME
	K::Gnuplot gp;
	K::GnuplotPlot plot;
	K::GnuplotPlotElementLines lines0;
	K::GnuplotPlotElementLines linesVar;
	K::GnuplotPlotElementLines linesAvgLong;
	// plot x-axis / refresh counters; per instance, so several detectors do not
	// share function-local statics
	int plotX = 0;
	int plotBlocks = 0;
#endif
	/** samples fed since the last analysis block */
	int cnt = 0;
	BeatBand2<float> band0;
	Mode mode;
	/** how many diff-std-devs above avgLong the short average must be for a beat */
	float thresholdMul;
public:
	/** default setup: Mode::BASS */
	BeatDetection2() : BeatDetection2(Mode::BASS) {
		;
	}

	/**
	 * setup for the given mode; the filter is configured for 44100 Hz
	 * (call setSampleRate() for other rates).
	 * @throws const char* ("invalid mode") for an unknown Mode value
	 */
	BeatDetection2(Mode mode) : mode(mode) {
		setSampleRate(44100);
		switch(mode) {
			case Mode::BASS:	thresholdMul = 0.9f; break;
			case Mode::SNARE:	thresholdMul = 1.5f; break;
			default:			throw "invalid mode";
		}
#ifdef PLOT_ME
		gp.setTerminal("wxt", K::GnuplotSize(20, 12));
		plot.add(&lines0);
		plot.add(&linesVar); linesVar.getStroke().getColor().setHexStr("#0000ff");
		plot.add(&linesAvgLong); linesAvgLong.getStroke().getColor().setHexStr("#00aa00");
#endif
	}

	/**
	 * (re)configure the band's filter for the given sample rate, using the
	 * preset belonging to the current mode. Logs the new rate to std::cout.
	 * @throws const char* ("invalid mode") for an unknown Mode value
	 */
	void setSampleRate(int srate) {
		std::cout << "setting sample-rate to " << srate << std::endl;
		switch(mode) {
			case Mode::BASS:	band0.filter.setBandPass(40, 0.90f, srate); break;
			case Mode::SNARE:	band0.filter.setBandPass(200, 0.20f, srate); break;
			default:			throw "invalid mode";
		}
	}

	/** refractory counter: analysis blocks left before another beat may be reported */
	int block = 0;
	/** not used by this class (kept for compatibility) */
	bool lastWasBeat = false;
	/** set once a non-beat block has been seen since the last reported beat */
	bool gapFound = false;

	/**
	 * add one stereo sample.
	 * @return true exactly when this sample completes an analysis block that is
	 *         reported as a new beat; false otherwise
	 */
	bool add(Scalar left, Scalar right) {
#ifdef PLOT_ME
		++plotX;
#endif
		// scale to roughly [-1, 1), presumably for 16-bit PCM input (2^15 = 32768;
		// the previous divisor 36768 was a typo). Convert to float *before*
		// dividing: the old `left /= 36768.0f` stored the quotient back into
		// Scalar, which truncates every sample to 0 when Scalar is an integer type.
		const float l = static_cast<float>(left) / 32768.0f;
		const float r = static_cast<float>(right) / 32768.0f;
		band0.add(l, r);

		if (++cnt != BD2_SHORT) {return false;}
		cnt = 0;

		// evaluate one analysis block
		const float avgShort0 = band0.avgShort.get();
		const float threshold0 = band0.avgLong.get() + band0.getDiffStdDev() * thresholdMul;
		const bool curIsBeat = (avgShort0 > threshold0);

#ifdef PLOT_ME
		lines0.add(K::GnuplotPoint2(plotX, avgShort0));
		linesVar.add(K::GnuplotPoint2(plotX, threshold0));
		if (++plotBlocks % 5 == 0) {show();}
#endif

		// refractory period: after a reported beat `block` is 8 and gapFound is
		// false. The countdown only starts at the first non-beat block; from then
		// on every block decrements it, and a beat can be reported once it is 0.
		if (!curIsBeat) {gapFound = true;}
		if (block > 0 && gapFound) {--block;}
		if (block == 0 && curIsBeat) {
			block = 8;
			gapFound = false;
			return true;
		}
		return false;
	}

#ifdef PLOT_ME
	/** redraw the plot, keeping only the most recent 200 points of lines0 */
	void show() {
		const int limit = 200;
		// cast first so the subtraction cannot wrap if size() is unsigned
		const int d0 = static_cast<int>(lines0.size()) - limit;
		if (d0 > 0) {
			lines0.remove(0, d0);
		}
		const int x0 = lines0[0].x;
		const int x1 = lines0[lines0.size()-1].x;
		plot.getAxisX().setRange(x0, x1);
		gp.draw(plot);
		gp.flush();
	}
#endif
};
#endif // BEATDETECTION2_H