/*
 * Copyright (C) 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <array>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <span>

#include <android-base/mapped_file.h>
#include <input/RingBuffer.h>
#include <utils/Timers.h>

#include <tensorflow/lite/core/api/error_reporter.h>
#include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/model.h>
#include <tensorflow/lite/signature_runner.h>

namespace android {

struct TfLiteMotionPredictorSample {
    // The untransformed AMOTION_EVENT_AXIS_X and AMOTION_EVENT_AXIS_Y of the sample.
    struct Point {
        float x;
        float y;
    } position;
    // The AMOTION_EVENT_AXIS_PRESSURE, _TILT, and _ORIENTATION.
    float pressure;
    float tilt;
    float orientation;
};

inline TfLiteMotionPredictorSample::Point operator-(const TfLiteMotionPredictorSample::Point& lhs,
                                                    const TfLiteMotionPredictorSample::Point& rhs) {
    return {.x = lhs.x - rhs.x, .y = lhs.y - rhs.y};
}

class TfLiteMotionPredictorModel;

// Buffer storage for a TfLiteMotionPredictorModel.
class TfLiteMotionPredictorBuffers {
public:
    // Creates buffer storage for a model with the given input length.
    TfLiteMotionPredictorBuffers(size_t inputLength);

    // Adds a motion sample to the buffers.
    void pushSample(int64_t timestamp, TfLiteMotionPredictorSample sample);

    // Returns true if the buffers are complete enough to generate a prediction.
    bool isReady() const {
        // Predictions can't be applied unless there are at least two points to determine
        // the direction to apply them in.
        return mAxisFrom && mAxisTo;
    }

    // Resets all buffers to their initial state.
    void reset();

    // Copies the buffers to those of a model for prediction.
    void copyTo(TfLiteMotionPredictorModel& model) const;

    // Returns the current axis of the buffer's samples. Only valid if isReady().
    TfLiteMotionPredictorSample axisFrom() const { return *mAxisFrom; }
    TfLiteMotionPredictorSample axisTo() const { return *mAxisTo; }

    // Returns the timestamp of the last sample.
    int64_t lastTimestamp() const { return mTimestamp; }

private:
    int64_t mTimestamp = 0;

    RingBuffer<float> mInputR;
    RingBuffer<float> mInputPhi;
    RingBuffer<float> mInputPressure;
    RingBuffer<float> mInputTilt;
    RingBuffer<float> mInputOrientation;

    // The samples defining the current polar axis.
    std::optional<TfLiteMotionPredictorSample> mAxisFrom;
    std::optional<TfLiteMotionPredictorSample> mAxisTo;
};

// A TFLite model for generating motion predictions.
class TfLiteMotionPredictorModel {
public:
    struct Config {
        // The time between predictions.
        nsecs_t predictionInterval = 0;
        // The noise floor for predictions.
        // Distances (r) less than this should be discarded as noise.
        float distanceNoiseFloor = 0;
    };

    // Creates a model from an encoded Flatbuffer model.
    static std::unique_ptr<TfLiteMotionPredictorModel> create();

    ~TfLiteMotionPredictorModel();

    // Returns the length of the model's input buffers.
    size_t inputLength() const;

    // Returns the length of the model's output buffers.
    size_t outputLength() const;

    const Config& config() const { return mConfig; }

    // Executes the model.
    // Returns true if the model successfully executed and the output tensors can be read.
    bool invoke();

    // Returns mutable buffers to the input tensors of inputLength() elements.
    std::span<float> inputR();
    std::span<float> inputPhi();
    std::span<float> inputPressure();
    std::span<float> inputOrientation();
    std::span<float> inputTilt();

    // Returns immutable buffers to the output tensors of identical length. Only valid after a
    // successful call to invoke().
    std::span<const float> outputR() const;
    std::span<const float> outputPhi() const;
    std::span<const float> outputPressure() const;

private:
    explicit TfLiteMotionPredictorModel(std::unique_ptr<android::base::MappedFile> model,
                                        Config config);

    void allocateTensors();
    void attachInputTensors();
    void attachOutputTensors();

    TfLiteTensor* mInputR = nullptr;
    TfLiteTensor* mInputPhi = nullptr;
    TfLiteTensor* mInputPressure = nullptr;
    TfLiteTensor* mInputTilt = nullptr;
    TfLiteTensor* mInputOrientation = nullptr;

    const TfLiteTensor* mOutputR = nullptr;
    const TfLiteTensor* mOutputPhi = nullptr;
    const TfLiteTensor* mOutputPressure = nullptr;

    std::unique_ptr<android::base::MappedFile> mFlatBuffer;
    std::unique_ptr<tflite::ErrorReporter> mErrorReporter;
    std::unique_ptr<tflite::FlatBufferModel> mModel;
    std::unique_ptr<tflite::Interpreter> mInterpreter;
    tflite::SignatureRunner* mRunner = nullptr;

    const Config mConfig = {};
};

} // namespace android
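
// ---------------------------------------------------------------------------
// Illustrative sketch only; not part of the original header.
//
// The buffers above hold their inputs as polar values (r, phi) measured against
// an axis defined by two samples, and Config::distanceNoiseFloor says that
// distances (r) below the floor should be discarded as noise. The code below
// shows one way such a conversion and noise check might look, assuming the axis
// runs from axisFrom to axisTo and the displacement is measured from axisTo.
//
// Everything in namespace motion_predictor_sketch is hypothetical: Point2,
// PolarDelta, toPolar, and isBelowNoiseFloor are names made up for this
// example, and the conversion the predictor actually performs may differ.
// ---------------------------------------------------------------------------
#include <cmath>

namespace motion_predictor_sketch {

// Hypothetical stand-in for TfLiteMotionPredictorSample::Point, so that the
// sketch does not depend on the declarations above.
struct Point2 {
    float x;
    float y;
};

// A displacement expressed relative to an axis.
struct PolarDelta {
    float r;   // Length of the displacement from axisTo to the new point.
    float phi; // Signed angle, in radians, from the axis direction to the displacement.
};

// Expresses the displacement (axisTo -> point) in polar form relative to the
// direction of the axis (axisFrom -> axisTo).
inline PolarDelta toPolar(Point2 axisFrom, Point2 axisTo, Point2 point) {
    const float axisX = axisTo.x - axisFrom.x;
    const float axisY = axisTo.y - axisFrom.y;
    const float dx = point.x - axisTo.x;
    const float dy = point.y - axisTo.y;

    const float r = std::hypot(dx, dy);
    // atan2(cross, dot) yields the signed angle between the two vectors. It
    // evaluates to 0 when either vector has zero length.
    const float cross = axisX * dy - axisY * dx;
    const float dot = axisX * dx + axisY * dy;
    const float phi = std::atan2(cross, dot);
    return {r, phi};
}

// Returns true if the displacement is shorter than the given noise floor and
// would therefore be discarded under the rule described for distanceNoiseFloor.
inline bool isBelowNoiseFloor(const PolarDelta& delta, float distanceNoiseFloor) {
    return delta.r < distanceNoiseFloor;
}

} // namespace motion_predictor_sketch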