1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156
|
From b96b199069e844cbb8dc428adee741fa40e76af8 Mon Sep 17 00:00:00 2001
From: Xiang Xiao <xiangxiao@google.com>
Date: Tue, 10 Sep 2024 14:51:46 -0700
Subject: [PATCH] Implement the GetDependencyHeads method for
DependencyParserModel.
This method runs a TFLite dependency parser model on a string and
returns a vector of dependency head for each token in the string.
Bug: 339037155
Change-Id: I4a89ad43849fd3fa85b82d01d2cad915aa78e174
---
third_party/tensorflow-text/BUILD.gn | 2 +
.../shims/tensorflow/core/lib/core/errors.h | 45 +++++++++++++++++++
.../shims/tensorflow/core/lib/core/status.h | 14 ++++++
.../shims/tensorflow/core/platform/logging.h | 8 ++++
.../tensorflow_text/core/kernels/mst_solver.h | 7 ++-
5 files changed, 72 insertions(+), 4 deletions(-)
create mode 100644 third_party/tensorflow-text/shims/tensorflow/core/lib/core/errors.h
create mode 100644 third_party/tensorflow-text/shims/tensorflow/core/lib/core/status.h
create mode 100644 third_party/tensorflow-text/shims/tensorflow/core/platform/logging.h
diff --git a/third_party/tensorflow-text/BUILD.gn b/third_party/tensorflow-text/BUILD.gn
index 9a8d5fef1c5dc..7c5fa1b37c3b4 100644
--- a/third_party/tensorflow-text/BUILD.gn
+++ b/third_party/tensorflow-text/BUILD.gn
@@ -17,6 +17,8 @@ config("tensorflow-text-flags") {
static_library("tensorflow-text") {
sources = [
+ "src/tensorflow_text/core/kernels/disjoint_set_forest.h",
+ "src/tensorflow_text/core/kernels/mst_solver.h",
"src/tensorflow_text/core/kernels/regex_split.cc",
"src/tensorflow_text/core/kernels/regex_split.h",
"src/tensorflow_text/core/kernels/wordpiece_tokenizer.cc",
diff --git a/third_party/tensorflow-text/shims/tensorflow/core/lib/core/errors.h b/third_party/tensorflow-text/shims/tensorflow/core/lib/core/errors.h
new file mode 100644
index 0000000000000..0475248b26b25
--- /dev/null
+++ b/third_party/tensorflow-text/shims/tensorflow/core/lib/core/errors.h
@@ -0,0 +1,45 @@
+// Copyright 2024 The Chromium Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_LIB_CORE_ERRORS_H_
+#define THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_LIB_CORE_ERRORS_H_
+
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+
+#define TF_PREDICT_FALSE(x) (x)
+#define TF_PREDICT_TRUE(x) (x)
+
+// For propagating errors when calling a function.
+#define TF_RETURN_IF_ERROR(...) \
+ do { \
+ ::absl::Status _status = (__VA_ARGS__); \
+ if (TF_PREDICT_FALSE(!_status.ok())) { \
+ return _status; \
+ } \
+ } while (0)
+
+namespace tensorflow::errors {
+
+template <typename Arg1, typename Arg2>
+::absl::Status InvalidArgument(Arg1 arg1, Arg2 arg2) {
+ return absl::Status(absl::StatusCode::kInvalidArgument,
+ absl::StrCat(arg1, arg2));
+}
+
+template <typename Arg1, typename Arg2, typename Arg3, typename Arg4>
+::absl::Status InvalidArgument(Arg1 arg1, Arg2 arg2, Arg3 arg3, Arg4 arg4) {
+ return absl::Status(absl::StatusCode::kInvalidArgument,
+ absl::StrCat(arg1, arg2, arg3, arg4));
+}
+
+template <typename... Args>
+absl::Status FailedPrecondition(Args... args) {
+ return absl::Status(absl::StatusCode::kFailedPrecondition,
+ absl::StrCat(args...));
+}
+
+} // namespace tensorflow::errors
+
+#endif // THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_LIB_CORE_ERRORS_H_
diff --git a/third_party/tensorflow-text/shims/tensorflow/core/lib/core/status.h b/third_party/tensorflow-text/shims/tensorflow/core/lib/core/status.h
new file mode 100644
index 0000000000000..8b16ba1f64c09
--- /dev/null
+++ b/third_party/tensorflow-text/shims/tensorflow/core/lib/core/status.h
@@ -0,0 +1,14 @@
+// Copyright 2024 The Chromium Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_LIB_CORE_STATUS_H_
+#define THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_LIB_CORE_STATUS_H_
+
+#include "absl/status/status.h"
+
+namespace tensorflow {
+using Status = ::absl::Status;
+}
+
+#endif // THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_LIB_CORE_STATUS_H_
diff --git a/third_party/tensorflow-text/shims/tensorflow/core/platform/logging.h b/third_party/tensorflow-text/shims/tensorflow/core/platform/logging.h
new file mode 100644
index 0000000000000..776917efba118
--- /dev/null
+++ b/third_party/tensorflow-text/shims/tensorflow/core/platform/logging.h
@@ -0,0 +1,8 @@
+// Copyright 2024 The Chromium Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_PLATFORM_LOGGING_H_
+#define THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_PLATFORM_LOGGING_H_
+
+#endif // THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_PLATFORM_LOGGING_H_
diff --git a/third_party/tensorflow-text/src/tensorflow_text/core/kernels/mst_solver.h b/third_party/tensorflow-text/src/tensorflow_text/core/kernels/mst_solver.h
index 818518bc2fdef..50372c5853d27 100644
--- a/third_party/tensorflow-text/src/tensorflow_text/core/kernels/mst_solver.h
+++ b/third_party/tensorflow-text/src/tensorflow_text/core/kernels/mst_solver.h
@@ -503,9 +503,9 @@ void MstSolver<Index, Score>::ContractCycle(Index node) {
for (const auto &node_and_arc : cycle_) {
// Set the |score_offset| to the cost of breaking the cycle by replacing the
// arc currently directed into the |cycle_node|.
- const Index cycle_node = node_and_arc.first;
+ const Index current_cycle_node = node_and_arc.first;
const Score score_offset = -node_and_arc.second->score;
- MergeInboundArcs(cycle_node, score_offset, contracted_node);
+ MergeInboundArcs(current_cycle_node, score_offset, contracted_node);
}
}
@@ -586,7 +586,6 @@ tensorflow::Status MstSolver<Index, Score>::ExpansionPhase(
argmax[target] = arc.source - 1;
}
}
- DCHECK_GE(num_roots, 1);
// Even when |forest_| is false, |num_roots| can still be more than 1. While
// the root score penalty discourages structures with multiple root arcs, it
@@ -595,7 +594,7 @@ tensorflow::Status MstSolver<Index, Score>::ExpansionPhase(
// produce an all-root structure in spite of the root score penalty. As this
// example illustrates, however, |num_roots| will be more than 1 if and only
// if the original digraph is infeasible for trees.
- if (!forest_ && num_roots != 1) {
+ if (num_roots < 1 || (!forest_ && num_roots != 1)) {
return tensorflow::errors::FailedPrecondition("Infeasible digraph");
}
--
2.46.0.662.g92d0881bb0-goog
|