File: tf_runtime.patch

package info (click to toggle)
tensorflow 2.14.1%2Bdfsg-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 359,396 kB
  • sloc: cpp: 2,418,453; python: 736,954; java: 20,254; ansic: 18,962; sh: 9,279; pascal: 7,941; objc: 1,584; xml: 988; ada: 727; cs: 273; perl: 150; makefile: 92
file content (84 lines) | stat: -rw-r--r-- 3,585 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
Intermittent patch to TFRT to submit a TF/TFRT cross-cutting change.
This patch will be applied only until TF's TFRT commit is automatically bumped.

---

diff --git a/backends/gpu/include/tfrt/gpu/gpu_types.h b/backends/gpu/include/tfrt/gpu/gpu_types.h
index 3d311c3..a216716 100644
--- a/backends/gpu/include/tfrt/gpu/gpu_types.h
+++ b/backends/gpu/include/tfrt/gpu/gpu_types.h
@@ -295,11 +295,7 @@
       wrapper::CurrentContext current, wrapper::Stream stream,
       wrapper::CclComm comm)>;
 
-  explicit GpuCclHandle(AsyncValueRef<GpuContext> context,
-                        wrapper::OwningCclComm comm, int num_ranks);
-  // TODO(hanbinyoon): Remove after transitioning to the above constructor.
-  explicit GpuCclHandle(AsyncValueRef<GpuContext> context,
-                        wrapper::OwningCclComm comm);
+  GpuCclHandle(AsyncValueRef<GpuContext> context, wrapper::OwningCclComm comm);
   ~GpuCclHandle();
 
   GpuCclHandle(GpuCclHandle&&) = default;
@@ -311,8 +307,6 @@
   llvm::Error ExecuteCallbacks(wrapper::CurrentContext current,
                                wrapper::Stream stream);
 
-  int num_ranks() const { return num_ranks_; }
-
   const wrapper::OwningCclComm& operator->() const { return comm_; }
   wrapper::CclComm get() const { return comm_.get(); }
   wrapper::CclComm release();
@@ -322,7 +316,6 @@
  private:
   AsyncValueRef<GpuContext> context_;
   wrapper::OwningCclComm comm_;
-  int num_ranks_;
   std::vector<Callback> callbacks_;
 };
 
diff --git a/backends/gpu/lib/gpu_types.cc b/backends/gpu/lib/gpu_types.cc
index 38529bc..01e3dba 100644
--- a/backends/gpu/lib/gpu_types.cc
+++ b/backends/gpu/lib/gpu_types.cc
@@ -214,15 +214,8 @@
 GpuBlasHandle::~GpuBlasHandle() = default;
 
 GpuCclHandle::GpuCclHandle(AsyncValueRef<GpuContext> context,
-                           wrapper::OwningCclComm comm, int num_ranks)
-    : context_(std::move(context)),
-      comm_(std::move(comm)),
-      num_ranks_(num_ranks) {}
-
-// TODO(hanbinyoon): Remove after transitioning to the above constructor.
-GpuCclHandle::GpuCclHandle(AsyncValueRef<GpuContext> context,
                            wrapper::OwningCclComm comm)
-    : context_(std::move(context)), comm_(std::move(comm)), num_ranks_(0) {}
+    : context_(std::move(context)), comm_(std::move(comm)) {}
 
 GpuCclHandle::~GpuCclHandle() = default;
 
diff --git a/backends/gpu/lib/kernels/ccl_kernels.cc b/backends/gpu/lib/kernels/ccl_kernels.cc
index 52ce820..9cfc1de 100644
--- a/backends/gpu/lib/kernels/ccl_kernels.cc
+++ b/backends/gpu/lib/kernels/ccl_kernels.cc
@@ -107,8 +107,6 @@
   auto width = ToWidthInBytes(type);
   if (!width) return width.takeError();
   assert(*width != 0);
-  if (input->size() != output->size() * handle->num_ranks())
-    return MakeStringError("Input size must be output size times ranks.");
 
   handle->AddCallback([input = input.ValueRef(), output = output.ValueRef(),
                        recvcount = output->size() / *width, type,
@@ -116,6 +114,10 @@
                           wrapper::CurrentContext current,
                           wrapper::Stream stream,
                           wrapper::CclComm comm) -> llvm::Error {
+    auto count = wrapper::CclCommCount(comm);
+    if (!count) return count.takeError();
+    if (input->size() != output->size() * *count)
+      return MakeStringError("Input size must be output size times ranks.");
     return wrapper::CclReduceScatter(current, input->pointer(),
                                      output->pointer(), recvcount, type, op,
                                      comm, stream);