1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
|
From 150ff65e69474c43fe4fb6d9152492f6d3ec614a Mon Sep 17 00:00:00 2001
From: Robert Ogden <robertogden@chromium.org>
Date: Wed, 30 Nov 2022 10:24:16 -0800
Subject: [PATCH 05/11] check cancel flag before calling invoke
---
.../cc/port/default/tflite_wrapper.cc | 14 ++++++++++----
1 file changed, 10 insertions(+), 4 deletions(-)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc
index dce66868264cb..70fedca9a3f22 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc
@@ -258,8 +258,10 @@ absl::Status TfLiteInterpreterWrapper::InvokeWithFallback(
const std::function<absl::Status(tflite::Interpreter* interpreter)>&
set_inputs) {
RETURN_IF_ERROR(set_inputs(interpreter_.get()));
- // Reset cancel flag before calling `Invoke()`.
- cancel_flag_.Set(false);
+ if (cancel_flag_.Get()) {
+ cancel_flag_.Set(false);
+ return absl::CancelledError("cancelled before Invoke() was called");
+ }
TfLiteStatus status = kTfLiteError;
if (fallback_on_execution_error_) {
status = InterpreterUtils::InvokeWithCPUFallback(interpreter_.get());
@@ -273,6 +275,7 @@ absl::Status TfLiteInterpreterWrapper::InvokeWithFallback(
// Assume the inference is cancelled successfully if Invoke() returns
// kTfLiteError and the cancel flag is `true`.
if (status == kTfLiteError && cancel_flag_.Get()) {
+ cancel_flag_.Set(false);
return absl::CancelledError("Invoke() cancelled.");
}
if (delegate_) {
@@ -289,14 +292,17 @@ absl::Status TfLiteInterpreterWrapper::InvokeWithFallback(
}
absl::Status TfLiteInterpreterWrapper::InvokeWithoutFallback() {
- // Reset cancel flag before calling `Invoke()`.
- cancel_flag_.Set(false);
+ if (cancel_flag_.Get()) {
+ cancel_flag_.Set(false);
+ return absl::CancelledError("cancelled before Invoke() was called");
+ }
TfLiteStatus status = interpreter_->Invoke();
if (status != kTfLiteOk) {
// Assume InvokeWithoutFallback() is guarded under caller's synchronization.
// Assume the inference is cancelled successfully if Invoke() returns
// kTfLiteError and the cancel flag is `true`.
if (status == kTfLiteError && cancel_flag_.Get()) {
+ cancel_flag_.Set(false);
return absl::CancelledError("Invoke() cancelled.");
}
return absl::InternalError("Invoke() failed.");
--
2.49.0.777.g153de2bbd5-goog
|