added ground truth to java method

fixed some bugs
improved algo and results
This commit is contained in:
toni
2019-01-27 10:47:46 +01:00
parent 6bb8bb6b4f
commit 49042a0cfb
8 changed files with 222 additions and 70 deletions

View File

@@ -8,8 +8,10 @@ import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.text.ParseException;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.Vector;
import java.util.stream.IntStream;
@@ -21,30 +23,54 @@ public class Main {
public static void main(String [ ] args) {
//File folder = new File("/home/toni/Documents/programme/dirigent/measurements/2017.06/lgWear");
//File folder = new File("/home/toni/Documents/programme/dirigent/measurements/peter_failed");
File folder = new File("/home/toni/Documents/programme/dirigent/measurements/2018.06/leon/mSensor");
File folder = new File("/home/toni/Documents/programme/dirigent/measurements/2018.06/frank/mSensorTest");
File[] listOfFiles = folder.listFiles();
Arrays.sort(listOfFiles);
//calc results
BpmHistory historyAll = new BpmHistory();
BpmHistory historyMag = new BpmHistory();
BpmHistory history3D = new BpmHistory();
// iterate trough files in measurements folder
for (File file : listOfFiles) {
if (file.isFile() && file.getName().contains(".csv")) {
AccelerometerWindowBuffer accWindowBuffer = new AccelerometerWindowBuffer(6000, 750);
BpmEstimator bpmEstimator = new BpmEstimator(accWindowBuffer, 0, 5000);
AccelerometerWindowBuffer accWindowBuffer = new AccelerometerWindowBuffer(6000, 875);
BpmEstimator bpmEstimator = new BpmEstimator(accWindowBuffer, 4, 50000);
//read the file line by line
try (BufferedReader br = new BufferedReader(new FileReader(file))) {
//read the first three lines and print out what file it is!
String comment = br.readLine();
String groundTruthLine = br.readLine();
br.readLine();
String groundTruth = br.readLine();
System.out.println(comment);
//long startTs = Long.parseLong(br.readLine().split(";")[0]);
System.out.println(file.getName());
//load ground truth file
final long startTs = Long.parseLong(br.readLine().split(";")[0]);
String gtFile = groundTruthLine.substring(groundTruthLine.indexOf(':') + 2);
Utils.GroundTruthData gtData = new Utils.GroundTruthData();
double gtCurValue = 0d;
if (gtFile.contains(".csv")) {
try (BufferedReader gtBr = new BufferedReader(new FileReader("../../measurements/2018.06/gt_toni/" + gtFile))) {
for (String gtLine; (gtLine = gtBr.readLine()) != null; ) {
gtData.setValuesFromString(gtLine);
}
} catch (IOException e) {
e.printStackTrace();
} catch (ParseException e) {
e.printStackTrace();
}
} else {
gtData.setSingleBPM(Double.valueOf(gtFile));
}
//read sensor measurements line by line
for (String line; (line = br.readLine()) != null; ) {
// process the line.
String[] measurement = line.split(";");
@@ -66,13 +92,13 @@ public class Main {
// Calculate the BPM for different window sizes
double bpm60 = bpmEstimator.estimate();
double bpm85 = bpmEstimator.estimate(3500, 750);
double bpm110 = bpmEstimator.estimate(2600, 750);
double bpm135 = bpmEstimator.estimate(2000, 750);
double bpm160 = bpmEstimator.estimate(1600,750);
double bpm200 = bpmEstimator.estimate(1200, 750);
double bpm85 = bpmEstimator.estimate(3500, 875);
double bpm110 = bpmEstimator.estimate(2600, 875);
double bpm135 = bpmEstimator.estimate(2000, 875);
double bpm160 = bpmEstimator.estimate(1600, 875);
double bpm200 = bpmEstimator.estimate(1200, 875);
System.out.println("--------------------------------------------------");
//System.out.println("--------------------------------------------------");
bpmList.add(bpm60);
bpmList.add(bpm85);
@@ -81,7 +107,8 @@ public class Main {
bpmList.add(bpm160);
bpmList.add(bpm200);
while(bpmList.remove(Double.valueOf(-1))) {}
while (bpmList.remove(Double.valueOf(-1))) {
}
Utils.removeOutliersZScore(bpmList, 3.4);
double bpmMean = Utils.mean(bpmList);
@@ -91,8 +118,30 @@ public class Main {
//double bpmSingle = bpmEstimator.getBestSingleAxis();
double bpmAllAverage = bpmEstimator.getAverageOfAllWindows();
System.out.println( ts + " all: " + Math.round(bpmMean) + " avg_all: " + Math.round(bpmAllAverage) + " 3D: " + Math.round(bpmDist));
//System.out.println( ts + " all: " + Math.round(bpmMean) + " avg_all: " + Math.round(bpmAllAverage) + " 3D: " + Math.round(bpmDist));
//System.out.println(" ");
//calc error using ground truth
long curTS = accWindowBuffer.getYongest().ts - startTs;
int idx = 0;
while (curTS > gtData.getTimestamp(idx) && idx < gtData.getSize() - 1) {
++idx;
}
gtCurValue = gtData.getBPM(idx);
//fill histories
historyAll.add(bpmAllAverage - gtCurValue);
historyMag.add(magMean - gtCurValue);
history3D.add(bpmDist - gtCurValue);
if(true){
System.out.println("all: " + bpmAllAverage);
System.out.println("mag: " + magMean);
System.out.println("3D: " + bpmDist);
System.out.println("GT: " + gtCurValue);
System.out.println(" ");
}
int dummyForBreakpoint = 0;
}
@@ -113,10 +162,18 @@ public class Main {
e.printStackTrace();
}
//print overall stats for a single data series
System.out.println("all: " + historyAll.getMean() + "(" + historyAll.getStd() + ")");
System.out.println("mag: " + historyMag.getMean() + "(" + historyMag.getStd() + ")");
System.out.println(" 3D: " + history3D.getMean() + "(" + history3D.getStd() + ")");
System.out.println(" ");
history3D.clear();
historyMag.clear();
historyAll.clear();
}
// try {
// System.in.read();
// } catch (IOException e) {

View File

@@ -67,7 +67,7 @@ public class AccelerometerWindowBuffer extends ArrayList<AccelerometerData> {
public boolean isNextWindowReady(){
if(!isEmpty()){
if(((getYongest().ts - getOldest().ts) > mWindowSize / 4) && mOverlapCounter > mOverlapSize){
if(((getYongest().ts - getOldest().ts) > mWindowSize) && mOverlapCounter > mOverlapSize){
mOverlapCounter = 0;
return true;

View File

@@ -118,7 +118,7 @@ public class BpmEstimator {
double[] magAutoCorr = new AutoCorrelation(magButter, tmpBuffer.size()).getCorr();
//dist correlation
double[] distCorr = new DistanceCorrelation(interp, (int) (interp.size() * 0.8)).getCorr();
double[] distCorr = new DistanceCorrelation(interp, (int) (interp.size() / 2)).getCorr();
//find a peak within range of 250 ms
int peakWidth = (int) Math.round(250 / sampleRate);
@@ -137,15 +137,15 @@ public class BpmEstimator {
if(DEBUG_MODE){
//plotter.setPlotRawX(interp.getTs(), interp.getX());
//plotter.setPlotRawY(interp.getTs(), interp.getY());
//plotter.setPlotRawZ(interp.getTs(), interp.getZ());
//plotter.setPlotRawMag(interp.getTs(), magRaw);
plotter.setPlotRawX(interp.getTs(), interp.getX());
plotter.setPlotRawY(interp.getTs(), interp.getY());
plotter.setPlotRawZ(interp.getTs(), interp.getZ());
plotter.setPlotRawMag(interp.getTs(), magRaw);
//plotter.setPlotButterX(interp.getTs(), xButter);
//plotter.setPlotButterY(interp.getTs(), yButter);
//plotter.setPlotButterZ(interp.getTs(), zButter);
//plotter.setPlotButterMag(interp.getTs(), magButter);
plotter.setPlotButterX(interp.getTs(), xButter);
plotter.setPlotButterY(interp.getTs(), yButter);
plotter.setPlotButterZ(interp.getTs(), zButter);
plotter.setPlotButterMag(interp.getTs(), magButter);
plotter.setPlotCorrX(xAutoCorr, xPeaks);
plotter.setPlotCorrY(yAutoCorr, yPeaks);
@@ -153,15 +153,13 @@ public class BpmEstimator {
plotter.setPlotCorrMag(magAutoCorr, magPeaks);
plotter.setPlotCorr3D(distCorr, distPeaks);
}
//printout the current BPM
System.out.println(length_ms + "; x: " + Math.round(xPeaks.getBPM(sampleRate))
+ "; y: " + Math.round(yPeaks.getBPM(sampleRate))
+ "; z: " + Math.round(zPeaks.getBPM(sampleRate))
+ "; mag: " + Math.round(magPeaks.getBPM(sampleRate))
+ "; 3D: " + Math.round(distPeaks.getBPM(sampleRate)));
}
double estimatedBPM = getBestBpmEstimation(xPeaks, yPeaks, zPeaks, magPeaks);
@@ -198,9 +196,14 @@ public class BpmEstimator {
}
public double getDistEstimation(){
if(!mBpmHistory_Dist.isEmpty()){
BpmHistory tmp = (BpmHistory) mBpmHistory_Dist.clone();
tmp.removeOutliers();
return tmp.getMean();
} else {
return -1;
}
}
public double getMagnitudeMean(){
@@ -208,9 +211,15 @@ public class BpmEstimator {
//Utils.removeOutliersZScore(mBpmHistory_Mag, 3.4);
//double mean = Utils.mean(mBpmHistory_Mag);
//mBpmHistory_Mag.clear();
if(!mBpmHistory_Mag.isEmpty()){
BpmHistory tmp = (BpmHistory) mBpmHistory_Mag.clone();
tmp.removeOutliers();
return tmp.getMean();
} else {
return -1;
}
}
public double getMeanBpm(){
@@ -282,6 +291,7 @@ public class BpmEstimator {
tmpHistory.add(mBpmHistory_Z);
tmpHistory.add(mBpmHistory_Mag);
if(!tmpHistory.isEmpty()){
//remove outliers again
tmpHistory.removeOutliers();
@@ -293,6 +303,9 @@ public class BpmEstimator {
mBpmHistory_Dist.clear();
return tmpHistory.getMean();
} else {
return -1;
}
}

View File

@@ -35,7 +35,7 @@ public class BpmHistory extends LinkedList<Double> {
if(this.size() > 2){
return Utils.mean(this);
} else {
return 333; //TODO: das ist natürlich quatsch und faulheit. mal schaun wie man das am besten löst.
return this.getFirst(); //TODO: das ist natürlich quatsch und faulheit. mal schaun wie man das am besten löst.
}
}
@@ -43,10 +43,14 @@ public class BpmHistory extends LinkedList<Double> {
if(this.size() > 2){
return Utils.var(this);
} else {
return 666; //TODO: das ist natürlich quatsch und faulheit. mal schaun wie man das am besten löst.
return 0; //TODO: das ist natürlich quatsch und faulheit. mal schaun wie man das am besten löst.
}
}
public double getStd(){
return Utils.stdDev(this);
}
public void removeOutliers(){
Utils.removeOutliersZScore(this, 3.4);
}

View File

@@ -3,7 +3,6 @@ package bpmEstimation;
import utilities.Utils;
import java.util.Arrays;
import java.util.Collections;
import java.util.DoubleSummaryStatistics;
/**
@@ -11,6 +10,8 @@ import java.util.DoubleSummaryStatistics;
*/
public class DistanceCorrelation {
//TODO: remove bad peaks found at the very beginning and end of the signal
private static int mMaxLag;
private double[] mCorr;

View File

@@ -125,6 +125,7 @@ public class Peaks {
*/
public double getBPM(double sampleRate_ms){
//todo: rückweisungsklasse kann auch hier mit rein.
if(hasPeaks()){
@@ -148,6 +149,19 @@ public class Peaks {
mPeaksValue.add(mData[idx]);
}
/*
if(hasPeaks()) {
//wir entfernen den ersten und den letzten peak weil die dist correlation
//am anfang und ende oft peaks erkennt, die käse sind und so den fehler in die höhe
//treiben können...
mPeaksPos.removeFirst();
mPeaksIdx.removeFirst();
mPeaksValue.removeFirst();
mPeaksPos.removeLast();
mPeaksIdx.removeLast();
mPeaksValue.removeLast();
}*/
}
//TODO: findPeaks method identical to Matlab... with PeakProminence

View File

@@ -7,6 +7,9 @@ import javax.swing.*;
import java.awt.*;
import java.awt.event.WindowEvent;
import java.awt.image.BufferedImage;
import java.text.DateFormat;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.List;
import java.util.stream.IntStream;
@@ -14,7 +17,7 @@ import java.util.stream.IntStream;
//TODO: change from double to generic type
public class Utils {
public static final boolean DEBUG_MODE = true;
public static final boolean DEBUG_MODE = false;
public static double getDistance(double x1, double y1, double x2, double y2) {
return (double) Math.sqrt((x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2));
@@ -96,7 +99,7 @@ public class Utils {
double[] diff = new double[data.length - 1];
int i=0;
for(int j = 1; j < data.length; ++j){
diff[i] = data[j] - data[i];
diff[i] = Math.abs(data[j] - data[i]);
++i;
}
return diff;
@@ -443,4 +446,61 @@ public class Utils {
}
public static class GroundTruthData {
public static class GroundTruthValue {
public Calendar timestamp;
public double bpm;
public GroundTruthValue(double bpm, Calendar ts){
this.timestamp = ts;
this.bpm = bpm;
}
}
private DateFormat format = new SimpleDateFormat("HH:mm:ss.SSS");
private Vector<GroundTruthValue> gtData;
private boolean isFile = true;
public GroundTruthData(){
this.gtData = new Vector<>();
}
public void setValuesFromString(String val) throws ParseException {
String[] measurement = val.split(" ");
if(measurement.length < 2){
throw new RuntimeException("broken Ground Truth format");
}
double bpm = Double.valueOf(measurement[0]);
Calendar date = Calendar.getInstance();
date.setTime(format.parse(measurement[1]));
this.gtData.add(new GroundTruthValue(bpm, date));
}
public void setSingleBPM(double bpm){
isFile = true;
this.gtData.add(new GroundTruthValue(bpm, Calendar.getInstance()));
}
public double getBPM(int idx){
if(isFile) {
return this.gtData.get(idx).bpm;
} else return this.gtData.get(0).bpm;
}
public long getTimestamp(int idx){
if(isFile) {
return 1000 * (60 * this.gtData.get(idx).timestamp.get(Calendar.MINUTE) + this.gtData.get(idx).timestamp.get(Calendar.SECOND)) + this.gtData.get(idx).timestamp.get(Calendar.MILLISECOND);
} else return 0L;
}
public int getSize(){
return this.gtData.size();
}
}
}

View File

@@ -8,8 +8,8 @@
%files = dir(fullfile('../../measurements/2018.06/manfred/LGWatchR/', '*.csv'));
%files = dir(fullfile('../../measurements/2018.06/peter/Huawai/', '*.csv'));
%files = dir(fullfile('../../measurements/2018.06/peter/mSensor/', '*.csv'));
%files = dir(fullfile('../../measurements/2018.06/frank/mSensor/', '*.csv'));
files = dir(fullfile('../../measurements/2018.06/leon/mSensor/', '*.csv'));
files = dir(fullfile('../../measurements/2018.06/frank/mSensorTest/', '*.csv'));
%files = dir(fullfile('../../measurements/2018.06/leon/mSensor/', '*.csv'));
%files_sorted = natsortfiles({files.name});
for file = files'
@@ -98,7 +98,7 @@ for file = files'
%set cur ground truth
if(length(gtData) > 1)
curTimestamp = timestamps(i);
curTimestamp = timestamps(i) - timestamps(1);
while(curTimestamp > gtData(gtIdx,1) && gtIdx < length(gtData))
curGtBpm = gtData(gtIdx,2);
gtIdx = gtIdx + 1;
@@ -106,6 +106,8 @@ for file = files'
else
curGtBpm = gtData;
end
%measure periodicity of window and use axis with best periodicity
[corr_x, lag_x] = xcov(m(i-window_size:i,3), (window_size/2), "coeff");
[corr_y, lag_y] = xcov(m(i-window_size:i,4), (window_size/2), "coeff");
@@ -121,7 +123,7 @@ for file = files'
%distanz zwischen den vektoren nehmen und in eine normale autocorrelation zu packen
%aufpassen wegen der norm, dass die richtung quasi nicht verloren geht.
%https://en.wikipedia.org/wiki/Lp_space
[corr_3D, lag_3D] = distCorr(m(i-window_size:i, 3:5), (round(window_size * 0.8)));
[corr_3D, lag_3D] = distCorr(m(i-window_size:i, 3:5), (window_size/2));
corr_x_pos = corr_x;
corr_y_pos = corr_y;
@@ -150,19 +152,20 @@ for file = files'
idx_y = findFalseDetectedPeaks(idx_y_raw, lag_y, corr_y);
idx_z = findFalseDetectedPeaks(idx_z_raw, lag_z, corr_z);
idx_mag = findFalseDetectedPeaks(idx_mag_raw, lag_mag, corr_mag);
idx_3D = findFalseDetectedPeaks(idx_3D_raw, lag_3D', corr_3D);
%idx_3D = findFalseDetectedPeaks(idx_3D_raw, lag_3D', corr_3D);
idx_3D = idx_3D_raw;
Dwindow = m(i-window_size:i,3);
Dwindow_mean_ts_diff = mean(diff(lag_3D(idx_3D) * sample_rate_ms)); %2.5 ms is the time between two samples at 400hz
Dwindow_mean_bpm = (60000 / (Dwindow_mean_ts_diff));
figure(10);
plot(lag_3D, corr_3D, lag_3D(idx_3D), corr_3D(idx_3D), 'r*', lag_3D(idx_3D_raw), corr_3D(idx_3D_raw), 'g*')
hold ("on")
m_label_ms = strcat(" mean ms: ", num2str(Dwindow_mean_ts_diff));
m_label_bpm = strcat(" mean bpm: ", num2str(Dwindow_mean_bpm));
title(strcat(" ", m_label_ms, " ", m_label_bpm));
hold ("off");
% figure(10);
% plot(lag_3D, corr_3D, lag_3D(idx_3D), corr_3D(idx_3D), 'r*', lag_3D(idx_3D_raw), corr_3D(idx_3D_raw), 'g*')
% hold ("on")
% m_label_ms = strcat(" mean ms: ", num2str(Dwindow_mean_ts_diff));
% m_label_bpm = strcat(" mean bpm: ", num2str(Dwindow_mean_bpm));
% title(strcat(" ", m_label_ms, " ", m_label_bpm));
% hold ("off");
Xwindow = m(i-window_size:i,3);
Xwindow_mean_ts_diff = mean(diff(lag_x(idx_x) * sample_rate_ms)); %2.5 ms is the time between two samples at 400hz