• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1Intermittent patch to TFRT to submit a TF/TFRT cross-cutting change.
2This patch will be applied only until TF's TFRT commit is automatically bumped.
3
4---
5
6diff --git a/backends/gpu/include/tfrt/gpu/gpu_types.h b/backends/gpu/include/tfrt/gpu/gpu_types.h
7index 3d311c3..a216716 100644
8--- a/backends/gpu/include/tfrt/gpu/gpu_types.h
9+++ b/backends/gpu/include/tfrt/gpu/gpu_types.h
10@@ -295,11 +295,7 @@
11       wrapper::CurrentContext current, wrapper::Stream stream,
12       wrapper::CclComm comm)>;
13
14-  explicit GpuCclHandle(AsyncValueRef<GpuContext> context,
15-                        wrapper::OwningCclComm comm, int num_ranks);
16-  // TODO(hanbinyoon): Remove after transitioning to the above constructor.
17-  explicit GpuCclHandle(AsyncValueRef<GpuContext> context,
18-                        wrapper::OwningCclComm comm);
19+  GpuCclHandle(AsyncValueRef<GpuContext> context, wrapper::OwningCclComm comm);
20   ~GpuCclHandle();
21
22   GpuCclHandle(GpuCclHandle&&) = default;
23@@ -311,8 +307,6 @@
24   llvm::Error ExecuteCallbacks(wrapper::CurrentContext current,
25                                wrapper::Stream stream);
26
27-  int num_ranks() const { return num_ranks_; }
28-
29   const wrapper::OwningCclComm& operator->() const { return comm_; }
30   wrapper::CclComm get() const { return comm_.get(); }
31   wrapper::CclComm release();
32@@ -322,7 +316,6 @@
33  private:
34   AsyncValueRef<GpuContext> context_;
35   wrapper::OwningCclComm comm_;
36-  int num_ranks_;
37   std::vector<Callback> callbacks_;
38 };
39
40diff --git a/backends/gpu/lib/gpu_types.cc b/backends/gpu/lib/gpu_types.cc
41index 38529bc..01e3dba 100644
42--- a/backends/gpu/lib/gpu_types.cc
43+++ b/backends/gpu/lib/gpu_types.cc
44@@ -214,15 +214,8 @@
45 GpuBlasHandle::~GpuBlasHandle() = default;
46
47 GpuCclHandle::GpuCclHandle(AsyncValueRef<GpuContext> context,
48-                           wrapper::OwningCclComm comm, int num_ranks)
49-    : context_(std::move(context)),
50-      comm_(std::move(comm)),
51-      num_ranks_(num_ranks) {}
52-
53-// TODO(hanbinyoon): Remove after transitioning to the above constructor.
54-GpuCclHandle::GpuCclHandle(AsyncValueRef<GpuContext> context,
55                            wrapper::OwningCclComm comm)
56-    : context_(std::move(context)), comm_(std::move(comm)), num_ranks_(0) {}
57+    : context_(std::move(context)), comm_(std::move(comm)) {}
58
59 GpuCclHandle::~GpuCclHandle() = default;
60
61diff --git a/backends/gpu/lib/kernels/ccl_kernels.cc b/backends/gpu/lib/kernels/ccl_kernels.cc
62index 52ce820..9cfc1de 100644
63--- a/backends/gpu/lib/kernels/ccl_kernels.cc
64+++ b/backends/gpu/lib/kernels/ccl_kernels.cc
65@@ -107,8 +107,6 @@
66   auto width = ToWidthInBytes(type);
67   if (!width) return width.takeError();
68   assert(*width != 0);
69-  if (input->size() != output->size() * handle->num_ranks())
70-    return MakeStringError("Input size must be output size times ranks.");
71
72   handle->AddCallback([input = input.ValueRef(), output = output.ValueRef(),
73                        recvcount = output->size() / *width, type,
74@@ -116,6 +114,10 @@
75                           wrapper::CurrentContext current,
76                           wrapper::Stream stream,
77                           wrapper::CclComm comm) -> llvm::Error {
78+    auto count = wrapper::CclCommCount(comm);
79+    if (!count) return count.takeError();
80+    if (input->size() != output->size() * *count)
81+      return MakeStringError("Input size must be output size times ranks.");
82     return wrapper::CclReduceScatter(current, input->pointer(),
83                                      output->pointer(), recvcount, type, op,
84                                      comm, stream);
85