diff --git a/main.cpp b/main.cpp index 52a00bc..b103b91 100755 --- a/main.cpp +++ b/main.cpp @@ -29,7 +29,7 @@ int main(int argc, char** argv) { //::testing::GTEST_FLAG(filter) = "*WiFiOptimizer*"; - ::testing::GTEST_FLAG(filter) = "*FloorplanCeilings*"; + ::testing::GTEST_FLAG(filter) = "*KullbackLeibler*"; //::testing::GTEST_FLAG(filter) = "*Barometer*"; //::testing::GTEST_FLAG(filter) = "*GridWalk2RelPressure*"; diff --git a/math/divergence/KullbackLeibler.h b/math/divergence/KullbackLeibler.h index 357b1c3..1ae4433 100644 --- a/math/divergence/KullbackLeibler.h +++ b/math/divergence/KullbackLeibler.h @@ -53,6 +53,10 @@ namespace Divergence { auto log1 = std::log(det1); auto log2 = std::log(det2); + //determinate shouldn't be 0! + Assert::isNot0(det1, "Determinat of the first Gauss is Zero! Check the Cov Matrix."); + Assert::isNot0(det2, "Determinat of the second Gauss is Zero! Check the Cov Matrix."); + //trace Eigen::MatrixXd toTrace(norm1.getSigma().rows(),norm1.getSigma().cols()); toTrace = norm2.getSigmaInv() * norm1.getSigma(); @@ -74,6 +78,8 @@ namespace Divergence { if(klb < 0.0){ Assert::doThrow("The Kullback Leibler Distance is < 0! Thats not possible"); } + Assert::isNotNaN(klb, "The Kullback Leibler Distance is NaN!"); + return klb; } diff --git a/tests/math/divergence/TestKullbackLeibler.cpp b/tests/math/divergence/TestKullbackLeibler.cpp index 0a14f92..0a899f1 100644 --- a/tests/math/divergence/TestKullbackLeibler.cpp +++ b/tests/math/divergence/TestKullbackLeibler.cpp @@ -201,7 +201,7 @@ TEST(KullbackLeibler, multivariateGaussGeCov) { double kld12 = Divergence::KullbackLeibler::getMultivariateGauss(norm1, norm2); double kld34 = Divergence::KullbackLeibler::getMultivariateGauss(norm3, norm4); - std::cout << kld34 << " >" << kld12 << std::endl; + std::cout << kld34 << " > " << kld12 << std::endl; double kld12sym = Divergence::KullbackLeibler::getMultivariateGaussSymmetric(norm1, norm2); double kld34sym = Divergence::KullbackLeibler::getMultivariateGaussSymmetric(norm3, norm4);