From d8ed662f7f2e1de66a98e1f8d34ccda73c092317 Mon Sep 17 00:00:00 2001
From: Robert Ogden <robertogden@chromium.org>
Date: Fri, 14 Jun 2024 12:56:17 -0700
Subject: [PATCH] add int8 quantization support to seq_flow_lite ops
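
Sequence string projection gains int8 overloads of WordNoveltyFeature and
DocSizeFeature, an int8 output mapping table, and an updated error message.
QRNN pooling is templated over the tensor type (uint8/int8, plus a float
specialization) and dispatches on the multiplier tensor type.

For context, below is a minimal sketch of the assumed PodQuantize<T>
semantics (scale, round, shift by zero_point, clamp to T's range), showing
how the uint8 and int8 parameter pairs used in this patch map the [-1, 1]
feature range. SketchPodQuantize is illustrative only, not the actual
seq_flow_lite helper, and the real rounding/clamping details may differ:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

// Assumed behavior: q = clamp(round(value * inverse_scale) + zero_point).
template <typename T>
T SketchPodQuantize(float value, int32_t zero_point, float inverse_scale) {
  const float q = std::round(value * inverse_scale) + zero_point;
  const float lo = std::numeric_limits<T>::min();
  const float hi = std::numeric_limits<T>::max();
  return static_cast<T>(std::min(hi, std::max(lo, q)));
}

int main() {
  // Both parameter pairs target (nearly) the full range of the output type:
  //   uint8: zero_point=127, inverse_scale=127.0  -> -1.0 => 0,    1.0 => 254
  //   int8:  zero_point=0,   inverse_scale=127.5  -> -1.0 => -128, 1.0 => 127
  const float values[] = {-1.0f, 0.0f, 1.0f};
  for (float v : values) {
    std::printf(
        "v=%+.1f  uint8 -> %3u  int8 -> %4d\n", v,
        static_cast<unsigned>(SketchPodQuantize<uint8_t>(v, 127, 127.0f)),
        static_cast<int>(SketchPodQuantize<int8_t>(v, 0, 127.5f)));
  }
  return 0;
}
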
---
.../tflite_ops/sequence_string_projection.cc | 32 +++++-
.../tflite_ops/tflite_qrnn_pooling.cc | 99 +++++++++++++++----
2 files changed, 107 insertions(+), 24 deletions(-)
diff --git a/third_party/tensorflow_models/src/research/seq_flow_lite/tflite_ops/sequence_string_projection.cc b/third_party/tensorflow_models/src/research/seq_flow_lite/tflite_ops/sequence_string_projection.cc
index f53cf650f16f8..22f4c609e86c9 100644
--- a/third_party/tensorflow_models/src/research/seq_flow_lite/tflite_ops/sequence_string_projection.cc
+++ b/third_party/tensorflow_models/src/research/seq_flow_lite/tflite_ops/sequence_string_projection.cc
@@ -28,8 +28,8 @@ limitations under the License.
#include "flatbuffers/flexbuffers.h" // flatbuffer
#include "tensorflow/lite/string_util.h"
#include "tf_ops/projection_normalizer_util.h" // seq_flow_lite
-#include "tf_ops/projection_util.h" // seq_flow_lite
-#include "tflite_ops/quantization_util.h" // seq_flow_lite
+#include "tf_ops/projection_util.h" // seq_flow_lite
+#include "tflite_ops/quantization_util.h" // seq_flow_lite
namespace seq_flow_lite {
namespace ops {
@@ -144,8 +144,17 @@ class ProjectionParams {
void WordNoveltyFeature(uint8_t* data, int word_count) const {
float word_novelty_feature;
WordNoveltyFeature(&word_novelty_feature, word_count);
- *data = PodQuantize<uint8_t>(word_novelty_feature, 127.0f, 127);
+ *data = PodQuantize<uint8_t>(word_novelty_feature, /*zero_point=*/127,
+ /*inverse_scale=*/127.0f);
}
+
+ void WordNoveltyFeature(int8_t* data, int word_count) const {
+ float word_novelty_feature;
+ WordNoveltyFeature(&word_novelty_feature, word_count);
+ *data = PodQuantize<int8_t>(word_novelty_feature, /*zero_point=*/0,
+ /*inverse_scale=*/127.5f);
+ }
+
bool DocSizeFeatureEnabled() const { return (doc_size_levels_ != 0); }
bool FirstCap() const { return add_first_cap_feature_; }
bool AllCaps() const { return add_all_caps_feature_; }
@@ -161,8 +170,17 @@ class ProjectionParams {
void DocSizeFeature(uint8_t* data, int num_tokens) {
float doc_size_feature;
DocSizeFeature(&doc_size_feature, num_tokens);
- *data = PodQuantize<uint8_t>(doc_size_feature, 127.0f, 127);
+ *data = PodQuantize<uint8_t>(doc_size_feature, /*zero_point=*/127,
+ /*inverse_scale=*/127.0f);
}
+
+ void DocSizeFeature(int8_t* data, int num_tokens) {
+ float doc_size_feature;
+ DocSizeFeature(&doc_size_feature, num_tokens);
+ *data = PodQuantize<int8_t>(doc_size_feature, /*zero_point=*/0,
+ /*inverse_scale=*/127.5f);
+ }
+
void Hash(const std::string& word, std::vector<uint64_t>& hash_codes) {
hasher_->GetHashCodes(word, hash_codes);
}
@@ -473,11 +491,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (output->type == kTfLiteUInt8) {
const uint8_t kMappingTable[1 << kMapBits] = {127, 255, 0, 127};
TypedEval(kMappingTable, params, output->data.uint8);
+ } else if (output->type == kTfLiteInt8) {
+ const int8_t kMappingTable[1 << kMapBits] = {0, 127, -128, 0};
+ TypedEval(kMappingTable, params, output->data.int8);
} else if (output->type == kTfLiteFloat32) {
const float kMappingTable[1 << kMapBits] = {0.0, 1.0, -1.0, 0.0};
TypedEval(kMappingTable, params, output->data.f);
} else {
- context->ReportError(context, "Output type must be UInt8 or Float32.");
+ context->ReportError(context,
+ "Output type must be Int8, UInt8, or Float32.");
return kTfLiteError;
}
diff --git a/third_party/tensorflow_models/src/research/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.cc b/third_party/tensorflow_models/src/research/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.cc
index 6641234c53d4b..205e8a619e5c2 100644
--- a/third_party/tensorflow_models/src/research/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.cc
+++ b/third_party/tensorflow_models/src/research/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.cc
@@ -22,7 +22,9 @@ namespace custom {
namespace {
-const uint8_t kPoolingForward = 255;
+const uint8_t kPoolingUInt8Forward = 255;
+const int8_t kPoolingInt8Forward = 127;
+const float kPoolingFloatForward = 1.0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, node->inputs->size, 3);
@@ -34,9 +36,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* constant = &context->tensors[node->inputs->data[1]];
TfLiteTensor* direction = &context->tensors[node->inputs->data[2]];
- TF_LITE_ENSURE_EQ(context, multiplier->type, kTfLiteUInt8);
- TF_LITE_ENSURE_EQ(context, constant->type, kTfLiteUInt8);
- TF_LITE_ENSURE_EQ(context, direction->type, kTfLiteUInt8);
+ TF_LITE_ENSURE(context, (constant->type == kTfLiteUInt8) ||
+ (constant->type == kTfLiteInt8) ||
+ (constant->type == kTfLiteFloat32));
+ TF_LITE_ENSURE_TYPES_EQ(context, multiplier->type, constant->type);
+ TF_LITE_ENSURE_TYPES_EQ(context, direction->type, constant->type);
TF_LITE_ENSURE_EQ(context, multiplier->dims->size, 3);
TF_LITE_ENSURE_EQ(context, multiplier->dims->data[0], 1);
@@ -71,9 +75,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
-TfLiteStatus QRNNPooling(TfLiteContext* context, TfLiteTensor* multiplier,
- TfLiteTensor* constant, TfLiteTensor* outputs,
- TfLiteTensor* final_state, bool forward) {
+template <typename T>
+TfLiteStatus QRNNPooling(TfLiteContext* context,
+ TfLiteTensor* multiplier,
+ TfLiteTensor* constant,
+ TfLiteTensor* outputs,
+ TfLiteTensor* final_state,
+ bool forward) {
const int time_steps = multiplier->dims->data[1];
const int state_size = multiplier->dims->data[2];
@@ -82,28 +90,67 @@ TfLiteStatus QRNNPooling(TfLiteContext* context, TfLiteTensor* multiplier,
const int32_t out_zero_point = outputs ? outputs->params.zero_point : 0;
const float out_inverse_scale = outputs ? 1.0f / outputs->params.scale : 1.0f;
- uint8_t* out_ptr = outputs ? outputs->data.uint8 : nullptr;
+ T* outputs_ptr = outputs ? tflite::GetTensorData<T>(outputs) : nullptr;
for (int i = 0; i < time_steps; ++i) {
for (int j = 0; j < state_size; ++j) {
const int time_index = forward ? i : time_steps - (i + 1);
const int index = time_index * state_size + j;
- float multiplier_value = PodDequantize<uint8_t>(*multiplier, index);
- float constant_vale = PodDequantize<uint8_t>(*constant, index);
- state[j] = state[j] * multiplier_value + constant_vale;
+ float multiplier_value = PodDequantize<T>(*multiplier, index);
+ float constant_value = PodDequantize<T>(*constant, index);
+ state[j] = state[j] * multiplier_value + constant_value;
if (outputs) {
- out_ptr[index] =
- PodQuantize<uint8_t>(state[j], out_zero_point, out_inverse_scale);
+ outputs_ptr[index] =
+ PodQuantize<T>(state[j], out_zero_point, out_inverse_scale);
}
}
}
if (final_state) {
- uint8_t* final_state_ptr = final_state->data.uint8;
+ T* final_state_ptr = tflite::GetTensorData<T>(final_state);
const int32_t zero_point = final_state->params.zero_point;
const float inverse_scale = 1.0f / final_state->params.scale;
for (int j = 0; j < state_size; ++j) {
- final_state_ptr[j] =
- PodQuantize<uint8_t>(state[j], zero_point, inverse_scale);
+ final_state_ptr[j] = PodQuantize<T>(state[j], zero_point, inverse_scale);
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+template <>
+TfLiteStatus QRNNPooling<float>(TfLiteContext* context,
+ TfLiteTensor* multiplier,
+ TfLiteTensor* constant,
+ TfLiteTensor* outputs,
+ TfLiteTensor* final_state,
+ bool forward) {
+ const int time_steps = multiplier->dims->data[1];
+ const int state_size = multiplier->dims->data[2];
+
+ auto state = std::make_unique<float[]>(state_size);
+ memset(state.get(), 0, sizeof(float) * state_size);
+
+ float* multiplier_ptr = tflite::GetTensorData<float>(multiplier);
+ float* constant_ptr = tflite::GetTensorData<float>(constant);
+ float* outputs_ptr =
+ outputs ? tflite::GetTensorData<float>(outputs) : nullptr;
+ for (int i = 0; i < time_steps; ++i) {
+ for (int j = 0; j < state_size; ++j) {
+ const int time_index = forward ? i : time_steps - (i + 1);
+ const int index = time_index * state_size + j;
+ float multiplier_value = multiplier_ptr[index];
+ float constant_value = constant_ptr[index];
+ state[j] = state[j] * multiplier_value + constant_value;
+ if (outputs) {
+ outputs_ptr[index] = state[j];
+ }
+ }
+ }
+
+ if (final_state) {
+ float* final_state_ptr = tflite::GetTensorData<float>(final_state);
+ for (int j = 0; j < state_size; ++j) {
+ final_state_ptr[j] = state[j];
}
}
@@ -126,10 +173,24 @@ TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) {
// When pooling forward the direction parameter is expected to be
// kPoolingForward.
- return QRNNPooling(context, multiplier, constant, outputs, final_state,
- (direction->data.uint8[0] == kPoolingForward));
+ switch (multiplier->type) {
+ case kTfLiteUInt8:
+ return QRNNPooling<uint8_t>(
+ context, multiplier, constant, outputs, final_state,
+ (tflite::GetTensorData<uint8_t>(direction)[0] ==
+ kPoolingUInt8Forward));
+ case kTfLiteInt8:
+ return QRNNPooling<int8_t>(
+ context, multiplier, constant, outputs, final_state,
+ (tflite::GetTensorData<int8_t>(direction)[0] == kPoolingInt8Forward));
+ case kTfLiteFloat32:
+ return QRNNPooling<float>(
+ context, multiplier, constant, outputs, final_state,
+ (tflite::GetTensorData<float>(direction)[0] == kPoolingFloatForward));
+ default:
+ return kTfLiteError;
+ }
}
-
} // namespace
const char kPoolingOp[] = "PoolingOp";
--
2.45.2.627.g7a2c4fd464-goog