From 49042a0cfb6dab977fafa7c8457e6c17d6ac723c Mon Sep 17 00:00:00 2001 From: toni Date: Sun, 27 Jan 2019 10:47:46 +0100 Subject: [PATCH] added ground truth to java method fixed some bugs improved algo and results --- java/src/main/java/Main.java | 97 +++++++++++++++---- .../AccelerometerWindowBuffer.java | 2 +- .../main/java/bpmEstimation/BpmEstimator.java | 77 +++++++++------ .../main/java/bpmEstimation/BpmHistory.java | 8 +- .../bpmEstimation/DistanceCorrelation.java | 3 +- java/src/main/java/bpmEstimation/Peaks.java | 14 +++ java/src/main/java/utilities/Utils.java | 64 +++++++++++- matlab/AutoCorrMethodNew_Watch.m | 27 +++--- 8 files changed, 222 insertions(+), 70 deletions(-) diff --git a/java/src/main/java/Main.java b/java/src/main/java/Main.java index 2bd650b..8d9e979 100644 --- a/java/src/main/java/Main.java +++ b/java/src/main/java/Main.java @@ -8,8 +8,10 @@ import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOException; +import java.text.ParseException; import java.util.Arrays; import java.util.LinkedList; +import java.util.Vector; import java.util.stream.IntStream; @@ -21,37 +23,61 @@ public class Main { public static void main(String [ ] args) { //File folder = new File("/home/toni/Documents/programme/dirigent/measurements/2017.06/lgWear"); //File folder = new File("/home/toni/Documents/programme/dirigent/measurements/peter_failed"); - File folder = new File("/home/toni/Documents/programme/dirigent/measurements/2018.06/leon/mSensor"); + File folder = new File("/home/toni/Documents/programme/dirigent/measurements/2018.06/frank/mSensorTest"); File[] listOfFiles = folder.listFiles(); Arrays.sort(listOfFiles); //calc results - + BpmHistory historyAll = new BpmHistory(); + BpmHistory historyMag = new BpmHistory(); + BpmHistory history3D = new BpmHistory(); // iterate trough files in measurements folder for (File file : listOfFiles) { if (file.isFile() && file.getName().contains(".csv")) { - AccelerometerWindowBuffer accWindowBuffer = new AccelerometerWindowBuffer(6000, 750); - BpmEstimator bpmEstimator = new BpmEstimator(accWindowBuffer, 0, 5000); + AccelerometerWindowBuffer accWindowBuffer = new AccelerometerWindowBuffer(6000, 875); + BpmEstimator bpmEstimator = new BpmEstimator(accWindowBuffer, 4, 50000); //read the file line by line try (BufferedReader br = new BufferedReader(new FileReader(file))) { //read the first three lines and print out what file it is! String comment = br.readLine(); + String groundTruthLine = br.readLine(); br.readLine(); - String groundTruth = br.readLine(); System.out.println(comment); - //long startTs = Long.parseLong(br.readLine().split(";")[0]); + System.out.println(file.getName()); + //load ground truth file + final long startTs = Long.parseLong(br.readLine().split(";")[0]); + String gtFile = groundTruthLine.substring(groundTruthLine.indexOf(':') + 2); + Utils.GroundTruthData gtData = new Utils.GroundTruthData(); + double gtCurValue = 0d; + + if (gtFile.contains(".csv")) { + try (BufferedReader gtBr = new BufferedReader(new FileReader("../../measurements/2018.06/gt_toni/" + gtFile))) { + for (String gtLine; (gtLine = gtBr.readLine()) != null; ) { + gtData.setValuesFromString(gtLine); + } + } catch (IOException e) { + e.printStackTrace(); + } catch (ParseException e) { + e.printStackTrace(); + } + + } else { + gtData.setSingleBPM(Double.valueOf(gtFile)); + } + + //read sensor measurements line by line for (String line; (line = br.readLine()) != null; ) { // process the line. String[] measurement = line.split(";"); //if linear acc long ts = 0; - if(measurement[1].equals("3")){ + if (measurement[1].equals("3")) { ts = Long.parseLong(measurement[0]); double x = Double.parseDouble(measurement[2]); double y = Double.parseDouble(measurement[3]); @@ -60,19 +86,19 @@ public class Main { } //do calculation stuff - if(accWindowBuffer.isNextWindowReady()){ + if (accWindowBuffer.isNextWindowReady()) { LinkedList bpmList = new LinkedList<>(); // Calculate the BPM for different window sizes double bpm60 = bpmEstimator.estimate(); - double bpm85 = bpmEstimator.estimate(3500, 750); - double bpm110 = bpmEstimator.estimate(2600, 750); - double bpm135 = bpmEstimator.estimate(2000, 750); - double bpm160 = bpmEstimator.estimate(1600,750); - double bpm200 = bpmEstimator.estimate(1200, 750); + double bpm85 = bpmEstimator.estimate(3500, 875); + double bpm110 = bpmEstimator.estimate(2600, 875); + double bpm135 = bpmEstimator.estimate(2000, 875); + double bpm160 = bpmEstimator.estimate(1600, 875); + double bpm200 = bpmEstimator.estimate(1200, 875); - System.out.println("--------------------------------------------------"); + //System.out.println("--------------------------------------------------"); bpmList.add(bpm60); bpmList.add(bpm85); @@ -81,7 +107,8 @@ public class Main { bpmList.add(bpm160); bpmList.add(bpm200); - while(bpmList.remove(Double.valueOf(-1))) {} + while (bpmList.remove(Double.valueOf(-1))) { + } Utils.removeOutliersZScore(bpmList, 3.4); double bpmMean = Utils.mean(bpmList); @@ -91,8 +118,30 @@ public class Main { //double bpmSingle = bpmEstimator.getBestSingleAxis(); double bpmAllAverage = bpmEstimator.getAverageOfAllWindows(); - System.out.println( ts + " all: " + Math.round(bpmMean) + " avg_all: " + Math.round(bpmAllAverage) + " 3D: " + Math.round(bpmDist)); - System.out.println(" "); + //System.out.println( ts + " all: " + Math.round(bpmMean) + " avg_all: " + Math.round(bpmAllAverage) + " 3D: " + Math.round(bpmDist)); + //System.out.println(" "); + + //calc error using ground truth + long curTS = accWindowBuffer.getYongest().ts - startTs; + int idx = 0; + while (curTS > gtData.getTimestamp(idx) && idx < gtData.getSize() - 1) { + ++idx; + } + gtCurValue = gtData.getBPM(idx); + + //fill histories + historyAll.add(bpmAllAverage - gtCurValue); + historyMag.add(magMean - gtCurValue); + history3D.add(bpmDist - gtCurValue); + + + if(true){ + System.out.println("all: " + bpmAllAverage); + System.out.println("mag: " + magMean); + System.out.println("3D: " + bpmDist); + System.out.println("GT: " + gtCurValue); + System.out.println(" "); + } int dummyForBreakpoint = 0; } @@ -104,7 +153,7 @@ public class Main { //System.out.println("MEAN BPM: " + Math.round(meanBPM)); //System.out.println("MEDIAN BPM: " + Math.round(medianBPM)); - if(Utils.DEBUG_MODE){ + if (Utils.DEBUG_MODE) { bpmEstimator.closeDebugWindows(); } @@ -113,10 +162,18 @@ public class Main { e.printStackTrace(); } + //print overall stats for a single data series + System.out.println("all: " + historyAll.getMean() + "(" + historyAll.getStd() + ")"); + System.out.println("mag: " + historyMag.getMean() + "(" + historyMag.getStd() + ")"); + System.out.println(" 3D: " + history3D.getMean() + "(" + history3D.getStd() + ")"); + System.out.println(" "); + + history3D.clear(); + historyMag.clear(); + historyAll.clear(); + } - - // try { // System.in.read(); // } catch (IOException e) { diff --git a/java/src/main/java/bpmEstimation/AccelerometerWindowBuffer.java b/java/src/main/java/bpmEstimation/AccelerometerWindowBuffer.java index f97ea03..85ecfb9 100644 --- a/java/src/main/java/bpmEstimation/AccelerometerWindowBuffer.java +++ b/java/src/main/java/bpmEstimation/AccelerometerWindowBuffer.java @@ -67,7 +67,7 @@ public class AccelerometerWindowBuffer extends ArrayList { public boolean isNextWindowReady(){ if(!isEmpty()){ - if(((getYongest().ts - getOldest().ts) > mWindowSize / 4) && mOverlapCounter > mOverlapSize){ + if(((getYongest().ts - getOldest().ts) > mWindowSize) && mOverlapCounter > mOverlapSize){ mOverlapCounter = 0; return true; diff --git a/java/src/main/java/bpmEstimation/BpmEstimator.java b/java/src/main/java/bpmEstimation/BpmEstimator.java index 7531e6e..928e6e8 100644 --- a/java/src/main/java/bpmEstimation/BpmEstimator.java +++ b/java/src/main/java/bpmEstimation/BpmEstimator.java @@ -118,7 +118,7 @@ public class BpmEstimator { double[] magAutoCorr = new AutoCorrelation(magButter, tmpBuffer.size()).getCorr(); //dist correlation - double[] distCorr = new DistanceCorrelation(interp, (int) (interp.size() * 0.8)).getCorr(); + double[] distCorr = new DistanceCorrelation(interp, (int) (interp.size() / 2)).getCorr(); //find a peak within range of 250 ms int peakWidth = (int) Math.round(250 / sampleRate); @@ -137,15 +137,15 @@ public class BpmEstimator { if(DEBUG_MODE){ - //plotter.setPlotRawX(interp.getTs(), interp.getX()); - //plotter.setPlotRawY(interp.getTs(), interp.getY()); - //plotter.setPlotRawZ(interp.getTs(), interp.getZ()); - //plotter.setPlotRawMag(interp.getTs(), magRaw); + plotter.setPlotRawX(interp.getTs(), interp.getX()); + plotter.setPlotRawY(interp.getTs(), interp.getY()); + plotter.setPlotRawZ(interp.getTs(), interp.getZ()); + plotter.setPlotRawMag(interp.getTs(), magRaw); - //plotter.setPlotButterX(interp.getTs(), xButter); - //plotter.setPlotButterY(interp.getTs(), yButter); - //plotter.setPlotButterZ(interp.getTs(), zButter); - //plotter.setPlotButterMag(interp.getTs(), magButter); + plotter.setPlotButterX(interp.getTs(), xButter); + plotter.setPlotButterY(interp.getTs(), yButter); + plotter.setPlotButterZ(interp.getTs(), zButter); + plotter.setPlotButterMag(interp.getTs(), magButter); plotter.setPlotCorrX(xAutoCorr, xPeaks); plotter.setPlotCorrY(yAutoCorr, yPeaks); @@ -153,16 +153,14 @@ public class BpmEstimator { plotter.setPlotCorrMag(magAutoCorr, magPeaks); plotter.setPlotCorr3D(distCorr, distPeaks); + //printout the current BPM + System.out.println(length_ms + "; x: " + Math.round(xPeaks.getBPM(sampleRate)) + + "; y: " + Math.round(yPeaks.getBPM(sampleRate)) + + "; z: " + Math.round(zPeaks.getBPM(sampleRate)) + + "; mag: " + Math.round(magPeaks.getBPM(sampleRate)) + + "; 3D: " + Math.round(distPeaks.getBPM(sampleRate))); } - //printout the current BPM - System.out.println(length_ms + "; x: " + Math.round(xPeaks.getBPM(sampleRate)) - + "; y: " + Math.round(yPeaks.getBPM(sampleRate)) - + "; z: " + Math.round(zPeaks.getBPM(sampleRate)) - + "; mag: " + Math.round(magPeaks.getBPM(sampleRate)) - + "; 3D: " + Math.round(distPeaks.getBPM(sampleRate))); - - double estimatedBPM = getBestBpmEstimation(xPeaks, yPeaks, zPeaks, magPeaks); if(estimatedBPM != -1){ @@ -198,9 +196,14 @@ public class BpmEstimator { } public double getDistEstimation(){ - BpmHistory tmp = (BpmHistory) mBpmHistory_Dist.clone(); - tmp.removeOutliers(); - return tmp.getMean(); + + if(!mBpmHistory_Dist.isEmpty()){ + BpmHistory tmp = (BpmHistory) mBpmHistory_Dist.clone(); + tmp.removeOutliers(); + return tmp.getMean(); + } else { + return -1; + } } public double getMagnitudeMean(){ @@ -208,9 +211,15 @@ public class BpmEstimator { //Utils.removeOutliersZScore(mBpmHistory_Mag, 3.4); //double mean = Utils.mean(mBpmHistory_Mag); //mBpmHistory_Mag.clear(); - BpmHistory tmp = (BpmHistory) mBpmHistory_Mag.clone(); - tmp.removeOutliers(); - return tmp.getMean(); + + if(!mBpmHistory_Mag.isEmpty()){ + BpmHistory tmp = (BpmHistory) mBpmHistory_Mag.clone(); + tmp.removeOutliers(); + return tmp.getMean(); + } else { + return -1; + } + } public double getMeanBpm(){ @@ -282,17 +291,21 @@ public class BpmEstimator { tmpHistory.add(mBpmHistory_Z); tmpHistory.add(mBpmHistory_Mag); - //remove outliers again - tmpHistory.removeOutliers(); + if(!tmpHistory.isEmpty()){ + //remove outliers again + tmpHistory.removeOutliers(); - //clear - mBpmHistory_X.clear(); - mBpmHistory_Y.clear(); - mBpmHistory_Z.clear(); - mBpmHistory_Mag.clear(); - mBpmHistory_Dist.clear(); + //clear + mBpmHistory_X.clear(); + mBpmHistory_Y.clear(); + mBpmHistory_Z.clear(); + mBpmHistory_Mag.clear(); + mBpmHistory_Dist.clear(); - return tmpHistory.getMean(); + return tmpHistory.getMean(); + } else { + return -1; + } } diff --git a/java/src/main/java/bpmEstimation/BpmHistory.java b/java/src/main/java/bpmEstimation/BpmHistory.java index 12896a7..b3b0e12 100644 --- a/java/src/main/java/bpmEstimation/BpmHistory.java +++ b/java/src/main/java/bpmEstimation/BpmHistory.java @@ -35,7 +35,7 @@ public class BpmHistory extends LinkedList { if(this.size() > 2){ return Utils.mean(this); } else { - return 333; //TODO: das ist natürlich quatsch und faulheit. mal schaun wie man das am besten löst. + return this.getFirst(); //TODO: das ist natürlich quatsch und faulheit. mal schaun wie man das am besten löst. } } @@ -43,10 +43,14 @@ public class BpmHistory extends LinkedList { if(this.size() > 2){ return Utils.var(this); } else { - return 666; //TODO: das ist natürlich quatsch und faulheit. mal schaun wie man das am besten löst. + return 0; //TODO: das ist natürlich quatsch und faulheit. mal schaun wie man das am besten löst. } } + public double getStd(){ + return Utils.stdDev(this); + } + public void removeOutliers(){ Utils.removeOutliersZScore(this, 3.4); } diff --git a/java/src/main/java/bpmEstimation/DistanceCorrelation.java b/java/src/main/java/bpmEstimation/DistanceCorrelation.java index 3b84419..3d49c06 100644 --- a/java/src/main/java/bpmEstimation/DistanceCorrelation.java +++ b/java/src/main/java/bpmEstimation/DistanceCorrelation.java @@ -3,7 +3,6 @@ package bpmEstimation; import utilities.Utils; import java.util.Arrays; -import java.util.Collections; import java.util.DoubleSummaryStatistics; /** @@ -11,6 +10,8 @@ import java.util.DoubleSummaryStatistics; */ public class DistanceCorrelation { + //TODO: remove bad peaks found at the very beginning and end of the signal + private static int mMaxLag; private double[] mCorr; diff --git a/java/src/main/java/bpmEstimation/Peaks.java b/java/src/main/java/bpmEstimation/Peaks.java index 6076749..2aeac8a 100644 --- a/java/src/main/java/bpmEstimation/Peaks.java +++ b/java/src/main/java/bpmEstimation/Peaks.java @@ -125,6 +125,7 @@ public class Peaks { */ public double getBPM(double sampleRate_ms){ + //todo: rückweisungsklasse kann auch hier mit rein. if(hasPeaks()){ @@ -148,6 +149,19 @@ public class Peaks { mPeaksValue.add(mData[idx]); } + + /* + if(hasPeaks()) { + //wir entfernen den ersten und den letzten peak weil die dist correlation + //am anfang und ende oft peaks erkennt, die käse sind und so den fehler in die höhe + //treiben können... + mPeaksPos.removeFirst(); + mPeaksIdx.removeFirst(); + mPeaksValue.removeFirst(); + mPeaksPos.removeLast(); + mPeaksIdx.removeLast(); + mPeaksValue.removeLast(); + }*/ } //TODO: findPeaks method identical to Matlab... with PeakProminence diff --git a/java/src/main/java/utilities/Utils.java b/java/src/main/java/utilities/Utils.java index 190197e..f1058a5 100644 --- a/java/src/main/java/utilities/Utils.java +++ b/java/src/main/java/utilities/Utils.java @@ -7,6 +7,9 @@ import javax.swing.*; import java.awt.*; import java.awt.event.WindowEvent; import java.awt.image.BufferedImage; +import java.text.DateFormat; +import java.text.ParseException; +import java.text.SimpleDateFormat; import java.util.*; import java.util.List; import java.util.stream.IntStream; @@ -14,7 +17,7 @@ import java.util.stream.IntStream; //TODO: change from double to generic type public class Utils { - public static final boolean DEBUG_MODE = true; + public static final boolean DEBUG_MODE = false; public static double getDistance(double x1, double y1, double x2, double y2) { return (double) Math.sqrt((x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2)); @@ -96,7 +99,7 @@ public class Utils { double[] diff = new double[data.length - 1]; int i=0; for(int j = 1; j < data.length; ++j){ - diff[i] = data[j] - data[i]; + diff[i] = Math.abs(data[j] - data[i]); ++i; } return diff; @@ -443,4 +446,61 @@ public class Utils { } + public static class GroundTruthData { + + public static class GroundTruthValue { + public Calendar timestamp; + public double bpm; + + public GroundTruthValue(double bpm, Calendar ts){ + this.timestamp = ts; + this.bpm = bpm; + } + } + + private DateFormat format = new SimpleDateFormat("HH:mm:ss.SSS"); + private Vector gtData; + private boolean isFile = true; + + public GroundTruthData(){ + this.gtData = new Vector<>(); + } + + public void setValuesFromString(String val) throws ParseException { + + String[] measurement = val.split(" "); + if(measurement.length < 2){ + throw new RuntimeException("broken Ground Truth format"); + } + + double bpm = Double.valueOf(measurement[0]); + + Calendar date = Calendar.getInstance(); + date.setTime(format.parse(measurement[1])); + + this.gtData.add(new GroundTruthValue(bpm, date)); + } + + public void setSingleBPM(double bpm){ + isFile = true; + this.gtData.add(new GroundTruthValue(bpm, Calendar.getInstance())); + } + + public double getBPM(int idx){ + if(isFile) { + return this.gtData.get(idx).bpm; + } else return this.gtData.get(0).bpm; + } + + public long getTimestamp(int idx){ + if(isFile) { + return 1000 * (60 * this.gtData.get(idx).timestamp.get(Calendar.MINUTE) + this.gtData.get(idx).timestamp.get(Calendar.SECOND)) + this.gtData.get(idx).timestamp.get(Calendar.MILLISECOND); + } else return 0L; + } + + public int getSize(){ + return this.gtData.size(); + } + } + } diff --git a/matlab/AutoCorrMethodNew_Watch.m b/matlab/AutoCorrMethodNew_Watch.m index f5002b5..25043ba 100644 --- a/matlab/AutoCorrMethodNew_Watch.m +++ b/matlab/AutoCorrMethodNew_Watch.m @@ -8,8 +8,8 @@ %files = dir(fullfile('../../measurements/2018.06/manfred/LGWatchR/', '*.csv')); %files = dir(fullfile('../../measurements/2018.06/peter/Huawai/', '*.csv')); %files = dir(fullfile('../../measurements/2018.06/peter/mSensor/', '*.csv')); -%files = dir(fullfile('../../measurements/2018.06/frank/mSensor/', '*.csv')); -files = dir(fullfile('../../measurements/2018.06/leon/mSensor/', '*.csv')); +files = dir(fullfile('../../measurements/2018.06/frank/mSensorTest/', '*.csv')); +%files = dir(fullfile('../../measurements/2018.06/leon/mSensor/', '*.csv')); %files_sorted = natsortfiles({files.name}); for file = files' @@ -98,7 +98,7 @@ for file = files' %set cur ground truth if(length(gtData) > 1) - curTimestamp = timestamps(i); + curTimestamp = timestamps(i) - timestamps(1); while(curTimestamp > gtData(gtIdx,1) && gtIdx < length(gtData)) curGtBpm = gtData(gtIdx,2); gtIdx = gtIdx + 1; @@ -106,6 +106,8 @@ for file = files' else curGtBpm = gtData; end + + %measure periodicity of window and use axis with best periodicity [corr_x, lag_x] = xcov(m(i-window_size:i,3), (window_size/2), "coeff"); [corr_y, lag_y] = xcov(m(i-window_size:i,4), (window_size/2), "coeff"); @@ -121,7 +123,7 @@ for file = files' %distanz zwischen den vektoren nehmen und in eine normale autocorrelation zu packen %aufpassen wegen der norm, dass die richtung quasi nicht verloren geht. %https://en.wikipedia.org/wiki/Lp_space - [corr_3D, lag_3D] = distCorr(m(i-window_size:i, 3:5), (round(window_size * 0.8))); + [corr_3D, lag_3D] = distCorr(m(i-window_size:i, 3:5), (window_size/2)); corr_x_pos = corr_x; corr_y_pos = corr_y; @@ -150,19 +152,20 @@ for file = files' idx_y = findFalseDetectedPeaks(idx_y_raw, lag_y, corr_y); idx_z = findFalseDetectedPeaks(idx_z_raw, lag_z, corr_z); idx_mag = findFalseDetectedPeaks(idx_mag_raw, lag_mag, corr_mag); - idx_3D = findFalseDetectedPeaks(idx_3D_raw, lag_3D', corr_3D); + %idx_3D = findFalseDetectedPeaks(idx_3D_raw, lag_3D', corr_3D); + idx_3D = idx_3D_raw; Dwindow = m(i-window_size:i,3); Dwindow_mean_ts_diff = mean(diff(lag_3D(idx_3D) * sample_rate_ms)); %2.5 ms is the time between two samples at 400hz Dwindow_mean_bpm = (60000 / (Dwindow_mean_ts_diff)); - figure(10); - plot(lag_3D, corr_3D, lag_3D(idx_3D), corr_3D(idx_3D), 'r*', lag_3D(idx_3D_raw), corr_3D(idx_3D_raw), 'g*') - hold ("on") - m_label_ms = strcat(" mean ms: ", num2str(Dwindow_mean_ts_diff)); - m_label_bpm = strcat(" mean bpm: ", num2str(Dwindow_mean_bpm)); - title(strcat(" ", m_label_ms, " ", m_label_bpm)); - hold ("off"); +% figure(10); +% plot(lag_3D, corr_3D, lag_3D(idx_3D), corr_3D(idx_3D), 'r*', lag_3D(idx_3D_raw), corr_3D(idx_3D_raw), 'g*') +% hold ("on") +% m_label_ms = strcat(" mean ms: ", num2str(Dwindow_mean_ts_diff)); +% m_label_bpm = strcat(" mean bpm: ", num2str(Dwindow_mean_bpm)); +% title(strcat(" ", m_label_ms, " ", m_label_bpm)); +% hold ("off"); Xwindow = m(i-window_size:i,3); Xwindow_mean_ts_diff = mean(diff(lag_x(idx_x) * sample_rate_ms)); %2.5 ms is the time between two samples at 400hz