Intermittent patch to TFRT to submit a TF/TFRT cross-cutting change. This patch will be applied only until TF's TFRT commit is automatically bumped. --- diff --git a/backends/gpu/include/tfrt/gpu/gpu_types.h b/backends/gpu/include/tfrt/gpu/gpu_types.h index 3d311c3..a216716 100644 --- a/backends/gpu/include/tfrt/gpu/gpu_types.h +++ b/backends/gpu/include/tfrt/gpu/gpu_types.h @@ -295,11 +295,7 @@ wrapper::CurrentContext current, wrapper::Stream stream, wrapper::CclComm comm)>; - explicit GpuCclHandle(AsyncValueRef context, - wrapper::OwningCclComm comm, int num_ranks); - // TODO(hanbinyoon): Remove after transitioning to the above constructor. - explicit GpuCclHandle(AsyncValueRef context, - wrapper::OwningCclComm comm); + GpuCclHandle(AsyncValueRef context, wrapper::OwningCclComm comm); ~GpuCclHandle(); GpuCclHandle(GpuCclHandle&&) = default; @@ -311,8 +307,6 @@ llvm::Error ExecuteCallbacks(wrapper::CurrentContext current, wrapper::Stream stream); - int num_ranks() const { return num_ranks_; } - const wrapper::OwningCclComm& operator->() const { return comm_; } wrapper::CclComm get() const { return comm_.get(); } wrapper::CclComm release(); @@ -322,7 +316,6 @@ private: AsyncValueRef context_; wrapper::OwningCclComm comm_; - int num_ranks_; std::vector callbacks_; }; diff --git a/backends/gpu/lib/gpu_types.cc b/backends/gpu/lib/gpu_types.cc index 38529bc..01e3dba 100644 --- a/backends/gpu/lib/gpu_types.cc +++ b/backends/gpu/lib/gpu_types.cc @@ -214,15 +214,8 @@ GpuBlasHandle::~GpuBlasHandle() = default; GpuCclHandle::GpuCclHandle(AsyncValueRef context, - wrapper::OwningCclComm comm, int num_ranks) - : context_(std::move(context)), - comm_(std::move(comm)), - num_ranks_(num_ranks) {} - -// TODO(hanbinyoon): Remove after transitioning to the above constructor. -GpuCclHandle::GpuCclHandle(AsyncValueRef context, wrapper::OwningCclComm comm) - : context_(std::move(context)), comm_(std::move(comm)), num_ranks_(0) {} + : context_(std::move(context)), comm_(std::move(comm)) {} GpuCclHandle::~GpuCclHandle() = default; diff --git a/backends/gpu/lib/kernels/ccl_kernels.cc b/backends/gpu/lib/kernels/ccl_kernels.cc index 52ce820..9cfc1de 100644 --- a/backends/gpu/lib/kernels/ccl_kernels.cc +++ b/backends/gpu/lib/kernels/ccl_kernels.cc @@ -107,8 +107,6 @@ auto width = ToWidthInBytes(type); if (!width) return width.takeError(); assert(*width != 0); - if (input->size() != output->size() * handle->num_ranks()) - return MakeStringError("Input size must be output size times ranks."); handle->AddCallback([input = input.ValueRef(), output = output.ValueRef(), recvcount = output->size() / *width, type, @@ -116,6 +114,10 @@ wrapper::CurrentContext current, wrapper::Stream stream, wrapper::CclComm comm) -> llvm::Error { + auto count = wrapper::CclCommCount(comm); + if (!count) return count.takeError(); + if (input->size() != output->size() * *count) + return MakeStringError("Input size must be output size times ranks."); return wrapper::CclReduceScatter(current, input->pointer(), output->pointer(), recvcount, type, op, comm, stream);