File: 0001-Implement-the-GetDependencyHeads-method-for-Dependen.patch

package info (click to toggle)
chromium 139.0.7258.127-1
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 6,122,068 kB
  • sloc: cpp: 35,100,771; ansic: 7,163,530; javascript: 4,103,002; python: 1,436,920; asm: 946,517; xml: 746,709; pascal: 187,653; perl: 88,691; sh: 88,436; objc: 79,953; sql: 51,488; cs: 44,583; fortran: 24,137; makefile: 22,147; tcl: 15,277; php: 13,980; yacc: 8,984; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (156 lines) | stat: -rw-r--r-- 6,875 bytes parent folder | download | duplicates (9)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
From b96b199069e844cbb8dc428adee741fa40e76af8 Mon Sep 17 00:00:00 2001
From: Xiang Xiao <xiangxiao@google.com>
Date: Tue, 10 Sep 2024 14:51:46 -0700
Subject: [PATCH] Implement the GetDependencyHeads method for
 DependencyParserModel.

This method runs a TFLite dependency parser model on a string and
returns a vector of dependency head for each token in the string.

Bug: 339037155
Change-Id: I4a89ad43849fd3fa85b82d01d2cad915aa78e174
---
 third_party/tensorflow-text/BUILD.gn          |  2 +
 .../shims/tensorflow/core/lib/core/errors.h   | 45 +++++++++++++++++++
 .../shims/tensorflow/core/lib/core/status.h   | 14 ++++++
 .../shims/tensorflow/core/platform/logging.h  |  8 ++++
 .../tensorflow_text/core/kernels/mst_solver.h |  7 ++-
 5 files changed, 72 insertions(+), 4 deletions(-)
 create mode 100644 third_party/tensorflow-text/shims/tensorflow/core/lib/core/errors.h
 create mode 100644 third_party/tensorflow-text/shims/tensorflow/core/lib/core/status.h
 create mode 100644 third_party/tensorflow-text/shims/tensorflow/core/platform/logging.h

diff --git a/third_party/tensorflow-text/BUILD.gn b/third_party/tensorflow-text/BUILD.gn
index 9a8d5fef1c5dc..7c5fa1b37c3b4 100644
--- a/third_party/tensorflow-text/BUILD.gn
+++ b/third_party/tensorflow-text/BUILD.gn
@@ -17,6 +17,8 @@ config("tensorflow-text-flags") {
 
 static_library("tensorflow-text") {
   sources = [
+    "src/tensorflow_text/core/kernels/disjoint_set_forest.h",
+    "src/tensorflow_text/core/kernels/mst_solver.h",
     "src/tensorflow_text/core/kernels/regex_split.cc",
     "src/tensorflow_text/core/kernels/regex_split.h",
     "src/tensorflow_text/core/kernels/wordpiece_tokenizer.cc",
diff --git a/third_party/tensorflow-text/shims/tensorflow/core/lib/core/errors.h b/third_party/tensorflow-text/shims/tensorflow/core/lib/core/errors.h
new file mode 100644
index 0000000000000..0475248b26b25
--- /dev/null
+++ b/third_party/tensorflow-text/shims/tensorflow/core/lib/core/errors.h
@@ -0,0 +1,45 @@
+// Copyright 2024 The Chromium Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_LIB_CORE_ERRORS_H_
+#define THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_LIB_CORE_ERRORS_H_
+
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+
+#define TF_PREDICT_FALSE(x) (x)
+#define TF_PREDICT_TRUE(x) (x)
+
+// For propagating errors when calling a function.
+#define TF_RETURN_IF_ERROR(...)             \
+  do {                                      \
+    ::absl::Status _status = (__VA_ARGS__); \
+    if (TF_PREDICT_FALSE(!_status.ok())) {  \
+      return _status;                       \
+    }                                       \
+  } while (0)
+
+namespace tensorflow::errors {
+
+template <typename Arg1, typename Arg2>
+::absl::Status InvalidArgument(Arg1 arg1, Arg2 arg2) {
+  return absl::Status(absl::StatusCode::kInvalidArgument,
+                      absl::StrCat(arg1, arg2));
+}
+
+template <typename Arg1, typename Arg2, typename Arg3, typename Arg4>
+::absl::Status InvalidArgument(Arg1 arg1, Arg2 arg2, Arg3 arg3, Arg4 arg4) {
+  return absl::Status(absl::StatusCode::kInvalidArgument,
+                      absl::StrCat(arg1, arg2, arg3, arg4));
+}
+
+template <typename... Args>
+absl::Status FailedPrecondition(Args... args) {
+  return absl::Status(absl::StatusCode::kFailedPrecondition,
+                      absl::StrCat(args...));
+}
+
+}  // namespace tensorflow::errors
+
+#endif  // THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_LIB_CORE_ERRORS_H_
diff --git a/third_party/tensorflow-text/shims/tensorflow/core/lib/core/status.h b/third_party/tensorflow-text/shims/tensorflow/core/lib/core/status.h
new file mode 100644
index 0000000000000..8b16ba1f64c09
--- /dev/null
+++ b/third_party/tensorflow-text/shims/tensorflow/core/lib/core/status.h
@@ -0,0 +1,14 @@
+// Copyright 2024 The Chromium Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_LIB_CORE_STATUS_H_
+#define THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_LIB_CORE_STATUS_H_
+
+#include "absl/status/status.h"
+
+namespace tensorflow {
+using Status = ::absl::Status;
+}
+
+#endif  // THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_LIB_CORE_STATUS_H_
diff --git a/third_party/tensorflow-text/shims/tensorflow/core/platform/logging.h b/third_party/tensorflow-text/shims/tensorflow/core/platform/logging.h
new file mode 100644
index 0000000000000..776917efba118
--- /dev/null
+++ b/third_party/tensorflow-text/shims/tensorflow/core/platform/logging.h
@@ -0,0 +1,8 @@
+// Copyright 2024 The Chromium Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_PLATFORM_LOGGING_H_
+#define THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_PLATFORM_LOGGING_H_
+
+#endif  // THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_PLATFORM_LOGGING_H_
diff --git a/third_party/tensorflow-text/src/tensorflow_text/core/kernels/mst_solver.h b/third_party/tensorflow-text/src/tensorflow_text/core/kernels/mst_solver.h
index 818518bc2fdef..50372c5853d27 100644
--- a/third_party/tensorflow-text/src/tensorflow_text/core/kernels/mst_solver.h
+++ b/third_party/tensorflow-text/src/tensorflow_text/core/kernels/mst_solver.h
@@ -503,9 +503,9 @@ void MstSolver<Index, Score>::ContractCycle(Index node) {
   for (const auto &node_and_arc : cycle_) {
     // Set the |score_offset| to the cost of breaking the cycle by replacing the
     // arc currently directed into the |cycle_node|.
-    const Index cycle_node = node_and_arc.first;
+    const Index current_cycle_node = node_and_arc.first;
     const Score score_offset = -node_and_arc.second->score;
-    MergeInboundArcs(cycle_node, score_offset, contracted_node);
+    MergeInboundArcs(current_cycle_node, score_offset, contracted_node);
   }
 }
 
@@ -586,7 +586,6 @@ tensorflow::Status MstSolver<Index, Score>::ExpansionPhase(
       argmax[target] = arc.source - 1;
     }
   }
-  DCHECK_GE(num_roots, 1);
 
   // Even when |forest_| is false, |num_roots| can still be more than 1.  While
   // the root score penalty discourages structures with multiple root arcs, it
@@ -595,7 +594,7 @@ tensorflow::Status MstSolver<Index, Score>::ExpansionPhase(
   // produce an all-root structure in spite of the root score penalty.  As this
   // example illustrates, however, |num_roots| will be more than 1 if and only
   // if the original digraph is infeasible for trees.
-  if (!forest_ && num_roots != 1) {
+  if (num_roots < 1 || (!forest_ && num_roots != 1)) {
     return tensorflow::errors::FailedPrecondition("Infeasible digraph");
   }
 
-- 
2.46.0.662.g92d0881bb0-goog