From b96b199069e844cbb8dc428adee741fa40e76af8 Mon Sep 17 00:00:00 2001
From: Xiang Xiao <xiangxiao@google.com>
Date: Tue, 10 Sep 2024 14:51:46 -0700
Subject: [PATCH] Implement the GetDependencyHeads method for
 DependencyParserModel.

This method runs a TFLite dependency parser model on a string and
returns a vector of dependency head for each token in the string.

Bug: 339037155
Change-Id: I4a89ad43849fd3fa85b82d01d2cad915aa78e174
---
 third_party/tensorflow-text/BUILD.gn          |  2 +
 .../shims/tensorflow/core/lib/core/errors.h   | 45 +++++++++++++++++++
 .../shims/tensorflow/core/lib/core/status.h   | 14 ++++++
 .../shims/tensorflow/core/platform/logging.h  |  8 ++++
 .../tensorflow_text/core/kernels/mst_solver.h |  7 ++-
 5 files changed, 72 insertions(+), 4 deletions(-)
 create mode 100644 third_party/tensorflow-text/shims/tensorflow/core/lib/core/errors.h
 create mode 100644 third_party/tensorflow-text/shims/tensorflow/core/lib/core/status.h
 create mode 100644 third_party/tensorflow-text/shims/tensorflow/core/platform/logging.h

diff --git a/third_party/tensorflow-text/BUILD.gn b/third_party/tensorflow-text/BUILD.gn
index 9a8d5fef1c5dc..7c5fa1b37c3b4 100644
--- a/third_party/tensorflow-text/BUILD.gn
+++ b/third_party/tensorflow-text/BUILD.gn
@@ -17,6 +17,8 @@ config("tensorflow-text-flags") {
 
 static_library("tensorflow-text") {
   sources = [
+    "src/tensorflow_text/core/kernels/disjoint_set_forest.h",
+    "src/tensorflow_text/core/kernels/mst_solver.h",
     "src/tensorflow_text/core/kernels/regex_split.cc",
     "src/tensorflow_text/core/kernels/regex_split.h",
     "src/tensorflow_text/core/kernels/wordpiece_tokenizer.cc",
diff --git a/third_party/tensorflow-text/shims/tensorflow/core/lib/core/errors.h b/third_party/tensorflow-text/shims/tensorflow/core/lib/core/errors.h
new file mode 100644
index 0000000000000..0475248b26b25
--- /dev/null
+++ b/third_party/tensorflow-text/shims/tensorflow/core/lib/core/errors.h
@@ -0,0 +1,45 @@
+// Copyright 2024 The Chromium Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_LIB_CORE_ERRORS_H_
+#define THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_LIB_CORE_ERRORS_H_
+
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+
+#define TF_PREDICT_FALSE(x) (x)
+#define TF_PREDICT_TRUE(x) (x)
+
+// For propagating errors when calling a function.
+#define TF_RETURN_IF_ERROR(...)             \
+  do {                                      \
+    ::absl::Status _status = (__VA_ARGS__); \
+    if (TF_PREDICT_FALSE(!_status.ok())) {  \
+      return _status;                       \
+    }                                       \
+  } while (0)
+
+namespace tensorflow::errors {
+
+template <typename Arg1, typename Arg2>
+::absl::Status InvalidArgument(Arg1 arg1, Arg2 arg2) {
+  return absl::Status(absl::StatusCode::kInvalidArgument,
+                      absl::StrCat(arg1, arg2));
+}
+
+template <typename Arg1, typename Arg2, typename Arg3, typename Arg4>
+::absl::Status InvalidArgument(Arg1 arg1, Arg2 arg2, Arg3 arg3, Arg4 arg4) {
+  return absl::Status(absl::StatusCode::kInvalidArgument,
+                      absl::StrCat(arg1, arg2, arg3, arg4));
+}
+
+template <typename... Args>
+absl::Status FailedPrecondition(Args... args) {
+  return absl::Status(absl::StatusCode::kFailedPrecondition,
+                      absl::StrCat(args...));
+}
+
+}  // namespace tensorflow::errors
+
+#endif  // THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_LIB_CORE_ERRORS_H_
diff --git a/third_party/tensorflow-text/shims/tensorflow/core/lib/core/status.h b/third_party/tensorflow-text/shims/tensorflow/core/lib/core/status.h
new file mode 100644
index 0000000000000..8b16ba1f64c09
--- /dev/null
+++ b/third_party/tensorflow-text/shims/tensorflow/core/lib/core/status.h
@@ -0,0 +1,14 @@
+// Copyright 2024 The Chromium Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_LIB_CORE_STATUS_H_
+#define THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_LIB_CORE_STATUS_H_
+
+#include "absl/status/status.h"
+
+namespace tensorflow {
+using Status = ::absl::Status;
+}
+
+#endif  // THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_LIB_CORE_STATUS_H_
diff --git a/third_party/tensorflow-text/shims/tensorflow/core/platform/logging.h b/third_party/tensorflow-text/shims/tensorflow/core/platform/logging.h
new file mode 100644
index 0000000000000..776917efba118
--- /dev/null
+++ b/third_party/tensorflow-text/shims/tensorflow/core/platform/logging.h
@@ -0,0 +1,8 @@
+// Copyright 2024 The Chromium Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_PLATFORM_LOGGING_H_
+#define THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_PLATFORM_LOGGING_H_
+
+#endif  // THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_PLATFORM_LOGGING_H_
diff --git a/third_party/tensorflow-text/src/tensorflow_text/core/kernels/mst_solver.h b/third_party/tensorflow-text/src/tensorflow_text/core/kernels/mst_solver.h
index 818518bc2fdef..50372c5853d27 100644
--- a/third_party/tensorflow-text/src/tensorflow_text/core/kernels/mst_solver.h
+++ b/third_party/tensorflow-text/src/tensorflow_text/core/kernels/mst_solver.h
@@ -503,9 +503,9 @@ void MstSolver<Index, Score>::ContractCycle(Index node) {
   for (const auto &node_and_arc : cycle_) {
     // Set the |score_offset| to the cost of breaking the cycle by replacing the
     // arc currently directed into the |cycle_node|.
-    const Index cycle_node = node_and_arc.first;
+    const Index current_cycle_node = node_and_arc.first;
     const Score score_offset = -node_and_arc.second->score;
-    MergeInboundArcs(cycle_node, score_offset, contracted_node);
+    MergeInboundArcs(current_cycle_node, score_offset, contracted_node);
   }
 }
 
@@ -586,7 +586,6 @@ tensorflow::Status MstSolver<Index, Score>::ExpansionPhase(
       argmax[target] = arc.source - 1;
     }
   }
-  DCHECK_GE(num_roots, 1);
 
   // Even when |forest_| is false, |num_roots| can still be more than 1.  While
   // the root score penalty discourages structures with multiple root arcs, it
@@ -595,7 +594,7 @@ tensorflow::Status MstSolver<Index, Score>::ExpansionPhase(
   // produce an all-root structure in spite of the root score penalty.  As this
   // example illustrates, however, |num_roots| will be more than 1 if and only
   // if the original digraph is infeasible for trees.
-  if (!forest_ && num_roots != 1) {
+  if (num_roots < 1 || (!forest_ && num_roots != 1)) {
     return tensorflow::errors::FailedPrecondition("Infeasible digraph");
   }
 
-- 
2.46.0.662.g92d0881bb0-goog

