/*
 * Copyright 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <vector>
#include <input/Input.h> // for MotionEvent
#include <input/RingBuffer.h>
#include <utils/Timers.h> // for nsecs_t
#include "Eigen/Core"
namespace android {

/**
 * Class to handle computing and reporting metrics for MotionPredictor.
 *
 * The public API provides two methods, `onRecord` and `onPredict`, which expect to receive the
 * MotionEvents from the corresponding methods in MotionPredictor.
 *
 * This class stores AggregatedStrokeMetrics, updating them as new MotionEvents are passed in. When
 * onRecord receives an UP or CANCEL event, this indicates the end of the stroke, and the final
 * AtomFields are computed and reported to the stats library. The number of atoms reported is equal
 * to the value of `maxNumPredictions` passed to the constructor. Each atom corresponds to one
 * "prediction time bucket": the amount of time into the future being predicted.
 *
 * If a custom `reportAtomFunction` is passed to the constructor (for example, in tests), each
 * stroke's batch of AtomFields is passed to that function instead of the stats library.
 */
class MotionPredictorMetricsManager {
public:
    struct AtomFields;

    using ReportAtomFunction = std::function<void(const AtomFields&)>;

    static void defaultReportAtomFunction(const AtomFields& atomFields);

    // Parameters:
    //  • predictionInterval: the time interval between successive prediction target timestamps.
    //    Note: the MetricsManager assumes that the input interval equals the prediction interval.
    //  • maxNumPredictions: the maximum number of distinct target timestamps the prediction model
    //    will generate predictions for. The MetricsManager reports this many atoms per stroke.
    //  • [Optional] reportAtomFunction: the function that will be called to report metrics. If
    //    omitted (or if an empty function is given), the `stats_write(…)` function from the Android
    //    stats library will be used.
    MotionPredictorMetricsManager(
            nsecs_t predictionInterval,
            size_t maxNumPredictions,
            ReportAtomFunction reportAtomFunction = defaultReportAtomFunction);
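
    // Example (illustrative values, not defaults): a test can collect the reported atoms itself
    // instead of sending them to the stats library. It does this by passing its own
    // reportAtomFunction. The 4 ms interval and 5 predictions below are arbitrary.
    //
    //   std::vector<MotionPredictorMetricsManager::AtomFields> reportedAtomFields;
    //   MotionPredictorMetricsManager metricsManager(
    //           /*predictionInterval=*/4'000'000, // nanoseconds
    //           /*maxNumPredictions=*/5,
    //           [&reportedAtomFields](const MotionPredictorMetricsManager::AtomFields& fields) {
    //               reportedAtomFields.push_back(fields);
    //           });
    //
    // Once a stroke ends, reportedAtomFields is expected to hold 5 entries, one per prediction
    // time bucket.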

    // This method should be called once for each call to MotionPredictor::record, receiving the
    // forwarded MotionEvent argument.
    void onRecord(const MotionEvent& inputEvent);

    // This method should be called once for each call to MotionPredictor::predict, receiving the
    // MotionEvent that will be returned by MotionPredictor::predict.
    void onPredict(const MotionEvent& predictionEvent);
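
    // Example (hypothetical call site): `predictor`, `metricsManager`, `event`, and `timestamp` are
    // placeholders. The sketch also assumes that MotionPredictor::predict returns a
    // std::unique_ptr<MotionEvent> that may be null.
    //
    //   predictor.record(event);
    //   metricsManager.onRecord(event);
    //
    //   std::unique_ptr<MotionEvent> prediction = predictor.predict(timestamp);
    //   if (prediction != nullptr) {
    //       metricsManager.onPredict(*prediction);
    //   }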

    // Simple structs to hold relevant touch input information. Public so they can be used in tests.

    struct TouchPoint {
        Eigen::Vector2f position; // (y, x) in pixels
        float pressure;
    };

    struct GroundTruthPoint : TouchPoint {
        nsecs_t timestamp;
    };

    struct PredictionPoint : TouchPoint {
        // The timestamp of the last ground truth point when the prediction was made.
        nsecs_t originTimestamp;

        nsecs_t targetTimestamp;

        // Order by targetTimestamp when sorting.
        bool operator<(const PredictionPoint& other) const {
            return this->targetTimestamp < other.targetTimestamp;
        }
    };

    // Metrics aggregated so far for the current stroke. These are not the final fields to be
    // reported in the atom (see AtomFields below), but rather an intermediate representation of the
    // data that can be conveniently aggregated and from which the atom fields can be derived later.
    //
    // Displacement units are in pixels.
    //
    // "Along-trajectory error" is the dot product of the prediction error with the unit vector
    // pointing towards the ground truth point whose timestamp corresponds to the prediction
    // target timestamp, originating from the preceding ground truth point.
    //
    // "Off-trajectory error" is the component of the prediction error orthogonal to the
    // "along-trajectory" unit vector described above.
    //
    // "High-velocity" errors are errors that are only accumulated when the velocity between the
    // most recent two input events exceeds a certain threshold.
    //
    // "Scale-invariant errors" are the errors produced when the path length of the stroke is
    // scaled to 1. (In other words, the error distances are normalized by the path length.)
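
    // Example (a sketch of the definitions above, not necessarily how updateAggregatedMetrics is
    // implemented). The names `previousGroundTruth`, `currentGroundTruth`, `prediction`, and
    // `metrics` are hypothetical locals. The prediction error is assumed to be prediction minus
    // ground truth, and <cmath> is assumed to be included.
    //
    //   // Unit vector from the preceding ground truth point towards the one matching the
    //   // prediction's targetTimestamp. If the two points coincide, the direction is undefined,
    //   // so a caller would likely skip that sample.
    //   const Eigen::Vector2f unitTrajectory =
    //           (currentGroundTruth.position - previousGroundTruth.position).normalized();
    //   const Eigen::Vector2f predictionError = prediction.position - currentGroundTruth.position;
    //
    //   // Signed projection onto the trajectory: positive means the prediction overshoots.
    //   const float alongTrajectoryError = predictionError.dot(unitTrajectory);
    //   // Length of the orthogonal component (the 2D cross product with a unit vector).
    //   const float offTrajectoryError = std::abs(unitTrajectory.x() * predictionError.y() -
    //                                             unitTrajectory.y() * predictionError.x());
    //   const float pressureError = prediction.pressure - currentGroundTruth.pressure;
    //
    //   metrics.alongTrajectoryErrorSum += alongTrajectoryError;
    //   metrics.alongTrajectorySumSquaredErrors += alongTrajectoryError * alongTrajectoryError;
    //   metrics.offTrajectorySumSquaredErrors += offTrajectoryError * offTrajectoryError;
    //   metrics.pressureSumSquaredErrors += pressureError * pressureError;
    //   ++metrics.generalErrorsCount;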
    struct AggregatedStrokeMetrics {
        // General errors
        float alongTrajectoryErrorSum = 0;
        float alongTrajectorySumSquaredErrors = 0;
        float offTrajectorySumSquaredErrors = 0;
        float pressureSumSquaredErrors = 0;
        size_t generalErrorsCount = 0;

        // High-velocity errors
        float highVelocityAlongTrajectorySse = 0;
        float highVelocityOffTrajectorySse = 0;
        size_t highVelocityErrorsCount = 0;

        // Scale-invariant errors
        float scaleInvariantAlongTrajectorySse = 0;
        float scaleInvariantOffTrajectorySse = 0;
        size_t scaleInvariantErrorsCount = 0;
    };

    // In order to explicitly indicate "no relevant data" for a metric, we report this
    // large-magnitude negative sentinel value. (Most metrics are non-negative, so this value is
    // completely unobtainable. For along-trajectory error mean, which can be negative, the
    // magnitude makes it unobtainable in practice.)
    static const int NO_DATA_SENTINEL = std::numeric_limits<int32_t>::min();

    // Final metric values reported in the atom.
    struct AtomFields {
        int deltaTimeBucketMilliseconds = 0;

        // General errors
        int alongTrajectoryErrorMeanMillipixels = NO_DATA_SENTINEL;
        int alongTrajectoryErrorStdMillipixels = NO_DATA_SENTINEL;
        int offTrajectoryRmseMillipixels = NO_DATA_SENTINEL;
        int pressureRmseMilliunits = NO_DATA_SENTINEL;

        // High-velocity errors
        int highVelocityAlongTrajectoryRmse = NO_DATA_SENTINEL; // millipixels
        int highVelocityOffTrajectoryRmse = NO_DATA_SENTINEL;   // millipixels

        // Scale-invariant errors
        int scaleInvariantAlongTrajectoryRmse = NO_DATA_SENTINEL; // millipixels
        int scaleInvariantOffTrajectoryRmse = NO_DATA_SENTINEL;   // millipixels
    };
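
    // Example (a sketch of how the "General errors" fields could be derived from
    // AggregatedStrokeMetrics). It assumes a population standard deviation and an RMSE over
    // generalErrorsCount samples, each scaled by 1000 and rounded. The names `metrics` and `atom`
    // are hypothetical locals, and <algorithm> and <cmath> are assumed to be included. When no
    // errors were accumulated, the NO_DATA_SENTINEL defaults are kept.
    //
    //   if (metrics.generalErrorsCount > 0) {
    //       const float n = static_cast<float>(metrics.generalErrorsCount);
    //       const float mean = metrics.alongTrajectoryErrorSum / n;
    //       // Var[X] = E[X^2] - E[X]^2, clamped at 0 to absorb floating-point rounding.
    //       const float variance =
    //               std::max(0.0f, metrics.alongTrajectorySumSquaredErrors / n - mean * mean);
    //       atom.alongTrajectoryErrorMeanMillipixels = static_cast<int>(std::round(mean * 1000));
    //       atom.alongTrajectoryErrorStdMillipixels =
    //               static_cast<int>(std::round(std::sqrt(variance) * 1000));
    //       atom.offTrajectoryRmseMillipixels = static_cast<int>(
    //               std::round(std::sqrt(metrics.offTrajectorySumSquaredErrors / n) * 1000));
    //       atom.pressureRmseMilliunits = static_cast<int>(
    //               std::round(std::sqrt(metrics.pressureSumSquaredErrors / n) * 1000));
    //   }
    //
    // For instance, along-trajectory errors of 1 px and 3 px give a sum of 4 and a sum of squares
    // of 10. The mean is therefore 2 px (2000 millipixels), and the standard deviation is
    // sqrt(10 / 2 - 2 * 2) = 1 px (1000 millipixels).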

private:
    // The interval between consecutive predictions' target timestamps. We assume that the input
    // interval also equals this value.
    const nsecs_t mPredictionInterval;

    // The maximum number of input frames into the future the model can predict.
    // Used to perform time-bucketing of metrics.
    const size_t mMaxNumPredictions;

    // History of mMaxNumPredictions + 1 ground truth points, used to compute scale-invariant
    // error. (Also, the last two points are used to compute the ground truth trajectory.)
    RingBuffer<GroundTruthPoint> mRecentGroundTruthPoints;

    // Predictions having a targetTimestamp after the most recent ground truth point's timestamp.
    // Invariant: sorted in ascending order of targetTimestamp.
    std::vector<PredictionPoint> mRecentPredictions;
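
    // Example (a sketch of one way to maintain the invariant above). The names `newPrediction`
    // and `latestGroundTruthTimestamp` are hypothetical locals, and <algorithm> is assumed to be
    // included. std::upper_bound uses PredictionPoint::operator<, so each insertion keeps the
    // vector sorted. Because the vector is sorted, outdated predictions form a prefix, which
    // std::partition_point can locate by binary search.
    //
    //   mRecentPredictions.insert(std::upper_bound(mRecentPredictions.begin(),
    //                                              mRecentPredictions.end(), newPrediction),
    //                             newPrediction);
    //
    //   const auto firstPending = std::partition_point(
    //           mRecentPredictions.begin(), mRecentPredictions.end(),
    //           [&](const PredictionPoint& p) {
    //               return p.targetTimestamp <= latestGroundTruthTimestamp;
    //           });
    //   mRecentPredictions.erase(mRecentPredictions.begin(), firstPending);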

    // Containers for the intermediate representation of stroke metrics and the final atom fields.
    // These are indexed by the number of input frames into the future being predicted minus one,
    // and always have size mMaxNumPredictions.
    std::vector<AggregatedStrokeMetrics> mAggregatedMetrics;
    std::vector<AtomFields> mAtomFields;

    const ReportAtomFunction mReportAtomFunction;

    // Helper methods for the implementation of onRecord and onPredict.

    // Clears stored ground truth and prediction points, as well as all stored metrics for the
    // current stroke.
    void clearStrokeData();

    // Adds the new ground truth point to mRecentGroundTruthPoints, removes outdated predictions
    // from mRecentPredictions, and updates the aggregated metrics to include the recent
    // predictions that fuzzily match the new ground truth point.
    void incorporateNewGroundTruth(const GroundTruthPoint& groundTruthPoint);

    // Given a new prediction with targetTimestamp matching the latest ground truth point's
    // timestamp, computes the corresponding metrics and updates mAggregatedMetrics.
    void updateAggregatedMetrics(const PredictionPoint& predictionPoint);

    // Computes the fields in mAtomFields from the values in mAggregatedMetrics.
    void computeAtomFields();

    // Reports the current data in mAtomFields by calling mReportAtomFunction.
    void reportMetrics();
};

} // namespace android