/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"

#include <chrono>  // NOLINT (required by TF interfaces)
#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_format.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace gpu {
namespace {

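// Issues one ncclAllReduce per buffer, batched into a single NCCL group call
// and enqueued on `stream`.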
Status RunAllReduce(const NcclAllReduceConfig& config,
                    const std::vector<NcclCollectiveThunk::Buffer>& buffers,
                    const BufferAllocations& buffer_allocations,
                    se::Stream& stream, ncclComm_t comm) {
#if XLA_ENABLE_XCCL
  int device_ordinal = stream.parent()->device_ordinal();
  VLOG(3) << "Performing all-reduce from device ordinal: " << device_ordinal;

  ncclRedOp_t reduce_op = ToNcclReduction(config.reduction_kind);

  cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
      stream.implementation()->GpuStreamMemberHack());

  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
  for (size_t i = 0; i < buffers.size(); ++i) {
    const NcclCollectiveThunk::Buffer& buffer = buffers[i];
    const void* send_buffer =
        buffer_allocations.GetDeviceAddress(buffer.source_buffer).opaque();
    void* recv_buffer =
        buffer_allocations.GetDeviceAddress(buffer.destination_buffer).opaque();

    PrimitiveType element_type = config.config.operand_element_type[i];
    TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
                        ToNcclDataTypeAndCountMultiplier(element_type));
    ncclDataType_t dtype = dtype_and_multiplier.first;
    int element_count = buffer.element_count * dtype_and_multiplier.second;

    VLOG(3) << absl::StreamFormat(
        "Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, "
        "comm=%p, stream=%p)",
        send_buffer, recv_buffer, element_count,
        static_cast<const void*>(comm), cu_stream);

    XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer,
                                           element_count, dtype, reduce_op,
                                           comm, *cu_stream));
  }
  return XLA_CUDA_STATUS(ncclGroupEnd());
#else   // XLA_ENABLE_XCCL
  return Unimplemented(
      "NCCL support is not available: this binary was not built with a CUDA "
      "compiler, which is necessary to build the NCCL source library.");
#endif  // XLA_ENABLE_XCCL
}

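// NCCL can only handle operands that are dense arrays with an element type it
// supports.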
bool IsValidOperand(mlir::Value operand) {
  Shape shape = TypeToShape(operand.getType());
  return LayoutUtil::IsDenseArray(shape) &&
         IsTypeSupportedByNccl(shape.element_type());
}

// Generally, the reduction op should be the only operation in the block,
// except the terminator. However, if the type is bf16, the
// `BFloat16Normalization` pass will have converted the op to float32 and added
// type conversions.
// TODO(cjfj): Can we prevent the bf16 conversion for this computation?
StatusOr<mlir::Operation*> FindReductionOp(mlir::Block& block) {
  TF_RET_CHECK(block.getNumArguments() == 2);
  mlir::Operation* terminator = block.getTerminator();
  TF_RET_CHECK(terminator);
  TF_RET_CHECK(terminator->getNumOperands() == 1);
  mlir::Value result = terminator->getOperand(0);
  TF_RET_CHECK(block.getArgument(0).getType() == result.getType());
  TF_RET_CHECK(block.getArgument(1).getType() == result.getType());

  mlir::Operation* result_op = result.getDefiningOp();
  TF_RET_CHECK(result_op);

  // In the bf16 case, the type conversions and op might be fused.
  if (mlir::isa<mlir::mhlo::FusionOp>(result_op)) {
    return FindReductionOp(result_op->getRegion(0).front());
  }

  // Standard case.
  if (absl::c_is_permutation(result_op->getOperands(), block.getArguments())) {
    return result_op;
  }

  // bf16 case.
  TF_RET_CHECK(mlir::isa<mlir::mhlo::ConvertOp>(result_op));
  TF_RET_CHECK(result_op->getNumOperands() == 1);
  mlir::Operation* reduction_op = result_op->getOperand(0).getDefiningOp();
  TF_RET_CHECK(reduction_op);
  TF_RET_CHECK(reduction_op->getNumOperands() == 2);
  mlir::Value operand0 = reduction_op->getOperand(0);
  mlir::Value operand1 = reduction_op->getOperand(1);
  auto operand0_op = operand0.getDefiningOp<mlir::mhlo::ConvertOp>();
  auto operand1_op = operand1.getDefiningOp<mlir::mhlo::ConvertOp>();
  TF_RET_CHECK(operand0_op);
  TF_RET_CHECK(operand1_op);
  TF_RET_CHECK(operand0_op->getNumOperands() == 1);
  TF_RET_CHECK(operand1_op->getNumOperands() == 1);
  std::array<mlir::Value, 2> operands{operand0_op->getOperand(0),
                                      operand1_op->getOperand(0)};
  TF_RET_CHECK(absl::c_is_permutation(operands, block.getArguments()));
  return reduction_op;
}

}  // namespace

namespace impl {

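// Helpers shared by the all-reduce, all-reduce-start and reduce-scatter
// thunks, templated on the corresponding LMHLO / LMHLO-GPU op type.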
template <typename OpT>
bool CanImplement(OpT op) {
  return absl::c_all_of(op.operands(), IsValidOperand) &&
         NcclAllReduceThunkBase::MatchAllReduceComputation(op.computation())
             .has_value();
}

template <typename OpT>
NcclAllReduceConfig GetNcclAllReduceConfig(OpT op) {
  absl::optional<ReductionKind> reduction_kind =
      NcclAllReduceThunkBase::MatchAllReduceComputation(op.computation());
  CHECK(reduction_kind.has_value());

  NcclAllReduceConfig config;
  config.config =
      GetNcclCollectiveConfigForMlir(op, op.use_global_device_ids());
  config.reduction_kind = *reduction_kind;
  return config;
}

template <typename OpT>
bool IsDegenerate(OpT op, int64_t replica_count, int64_t partition_count) {
  return GetNcclCollectiveConfigForMlir(op, op.use_global_device_ids())
      .IsDegenerate(replica_count, partition_count);
}

template <typename OpT>
CollectiveOpGroupMode GetGroupMode(OpT op) {
  return GetNcclAllReduceConfig(op).config.group_mode;
}

}  // namespace impl

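// Maps the op's reduction computation to a ReductionKind that NCCL supports,
// or returns nullopt if the computation cannot be expressed as a single NCCL
// reduction.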
absl::optional<ReductionKind> NcclAllReduceThunkBase::MatchAllReduceComputation(
    mlir::Region& computation) {
  mlir::Block& block = computation.front();
  StatusOr<mlir::Operation*> reduction_op = FindReductionOp(block);
  if (!reduction_op.ok()) return absl::nullopt;
  StatusOr<HloOpcode> opcode = MhloToHloOpcode(*reduction_op);
  if (!opcode.ok()) return absl::nullopt;
  // Match the operation to a reduction kind. We can represent and/or of pred
  // as min/max. This works because pred is stored as an 8-bit int of value 0
  // or 1.
  PrimitiveType type =
      TypeToShape(block.getArgument(0).getType()).element_type();
  if (type == PRED) {
    switch (opcode.ValueOrDie()) {
      case HloOpcode::kAnd:
        return ReductionKind::MIN;
      case HloOpcode::kOr:
        return ReductionKind::MAX;
      default:
        return absl::nullopt;
    }
  } else if (primitive_util::IsComplexType(type)) {
    // Only addition is supported for complex types.
    if (*opcode == HloOpcode::kAdd) {
      return ReductionKind::SUM;
    } else {
      return absl::nullopt;
    }
  } else {
    switch (*opcode) {
      case HloOpcode::kAdd:
        return ReductionKind::SUM;
      case HloOpcode::kMultiply:
        return ReductionKind::PRODUCT;
      case HloOpcode::kMaximum:
        return ReductionKind::MAX;
      case HloOpcode::kMinimum:
        return ReductionKind::MIN;
      default:
        return absl::nullopt;
    }
  }
}

NcclAllReduceThunkBase::NcclAllReduceThunkBase(Thunk::Kind kind,
                                               ThunkInfo thunk_info,
                                               NcclAllReduceConfig config,
                                               std::vector<Buffer> buffers)
    : NcclCollectiveThunk(kind, thunk_info),
      config_(std::move(config)),
      buffers_(std::move(buffers)) {
  CHECK_EQ(config_.config.operand_count, buffers_.size());
}

NcclAllReduceThunk::NcclAllReduceThunk(ThunkInfo thunk_info,
                                       mlir::lmhlo::AllReduceOp op,
                                       std::vector<Buffer> buffers)
    : NcclAllReduceThunkBase(Thunk::kNcclAllReduce, thunk_info,
                             impl::GetNcclAllReduceConfig(op), buffers) {}

bool NcclAllReduceThunk::CanImplement(mlir::lmhlo::AllReduceOp op) {
  return impl::CanImplement(op);
}

bool NcclAllReduceThunk::IsDegenerate(mlir::lmhlo::AllReduceOp op,
                                      int64_t replica_count,
                                      int64_t partition_count) {
  return impl::IsDegenerate(op, replica_count, partition_count);
}

CollectiveOpGroupMode NcclAllReduceThunk::GetGroupMode(
    mlir::lmhlo::AllReduceOp op) {
  return impl::GetGroupMode(op);
}

Status NcclAllReduceThunk::RunNcclCollective(const ExecuteParams& params,
                                             ncclComm_t comm) {
  se::Stream& stream = *params.stream;
  TF_RETURN_IF_ERROR(RunAllReduce(config_, buffers_,
                                  *params.buffer_allocations, stream, comm));

  int device_ordinal = stream.parent()->device_ordinal();
  VLOG(3) << "Done performing all-reduce for ordinal: " << device_ordinal;
  return Status::OK();
}

NcclAllReduceStartThunk::NcclAllReduceStartThunk(
    ThunkInfo thunk_info, mlir::lmhlo_gpu::AllReduceStartOp op,
    std::vector<Buffer> buffers)
    : NcclAllReduceThunkBase(Thunk::kNcclAllReduceStart, thunk_info,
                             impl::GetNcclAllReduceConfig(op), buffers) {}

bool NcclAllReduceStartThunk::CanImplement(
    mlir::lmhlo_gpu::AllReduceStartOp op) {
  return impl::CanImplement(op);
}

bool NcclAllReduceStartThunk::IsDegenerate(mlir::lmhlo_gpu::AllReduceStartOp op,
                                           int64_t replica_count,
                                           int64_t partition_count) {
  return impl::IsDegenerate(op, replica_count, partition_count);
}

CollectiveOpGroupMode NcclAllReduceStartThunk::GetGroupMode(
    mlir::lmhlo_gpu::AllReduceStartOp op) {
  return impl::GetGroupMode(op);
}

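// Starts the all-reduce on the dedicated async communication stream and
// records a per-device event that the matching NcclAllReduceDoneThunk waits
// on.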
Status NcclAllReduceStartThunk::RunNcclCollective(const ExecuteParams& params,
                                                  ncclComm_t comm) {
  se::Stream& async_comms_stream = *params.async_comms_stream;
  // Wait until compute inputs are ready.
  async_comms_stream.ThenWaitFor(params.stream);

  TF_RETURN_IF_ERROR(RunAllReduce(config_, buffers_,
                                  *params.buffer_allocations,
                                  async_comms_stream, comm));

  // Create an event on the async stream for the completion of the all-reduce.
  se::Event done_event(async_comms_stream.parent());
  TF_RET_CHECK(done_event.Init());
  async_comms_stream.ThenRecordEvent(&done_event);

  int device_ordinal = async_comms_stream.parent()->device_ordinal();

  {
    absl::MutexLock lock(&mu_);
    auto result = done_events_.emplace(device_ordinal, std::move(done_event));
    TF_RET_CHECK(result.second) << "done event has not been consumed";
  }

  VLOG(3) << "Done performing all-reduce-start for ordinal: " << device_ordinal;
  return Status::OK();
}

StatusOr<se::Event> NcclAllReduceStartThunk::TakeDoneEvent(int device_ordinal) {
  absl::MutexLock lock(&mu_);
  auto it = done_events_.find(device_ordinal);
  TF_RET_CHECK(it != done_events_.end()) << "done event not found";
  // Take ownership of the event.
  se::Event done_event = std::move(it->second);
  done_events_.erase(it);
  return done_event;
}

NcclAllReduceDoneThunk::NcclAllReduceDoneThunk(
    ThunkInfo thunk_info, NcclAllReduceStartThunk& start_thunk)
    : Thunk(Thunk::kNcclAllReduceDone, thunk_info), start_thunk_(start_thunk) {}

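// Blocks the compute stream on the event recorded by the corresponding start
// thunk, completing the asynchronous all-reduce.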
Status NcclAllReduceDoneThunk::ExecuteOnStream(const ExecuteParams& params) {
  int device_ordinal = params.stream->parent()->device_ordinal();
  TF_ASSIGN_OR_RETURN(se::Event done_event,
                      start_thunk_.TakeDoneEvent(device_ordinal));
  params.stream->ThenWaitFor(&done_event);
  return Status::OK();
}

NcclReduceScatterThunk::NcclReduceScatterThunk(
    ThunkInfo thunk_info, mlir::lmhlo::ReduceScatterOp op,
    std::vector<NcclAllReduceThunk::Buffer> buffers)
    : NcclAllReduceThunkBase(Thunk::kNcclReduceScatter, thunk_info,
                             impl::GetNcclAllReduceConfig(op),
                             std::move(buffers)) {}

/*static*/ bool NcclReduceScatterThunk::CanImplement(
    mlir::lmhlo::ReduceScatterOp op) {
  return impl::CanImplement(op);
}

/*static*/ bool NcclReduceScatterThunk::IsDegenerate(
    mlir::lmhlo::ReduceScatterOp op, int64_t replica_count,
    int64_t partition_count) {
  return impl::IsDegenerate(op, replica_count, partition_count);
}

/*static*/ CollectiveOpGroupMode NcclReduceScatterThunk::GetGroupMode(
    mlir::lmhlo::ReduceScatterOp op) {
  return impl::GetGroupMode(op);
}

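// Like RunAllReduce above, but each participant receives only its share of the
// reduced result, so the receive count is the (type-adjusted) source element
// count divided by the number of participants.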
Status NcclReduceScatterThunk::RunNcclCollective(const ExecuteParams& params,
                                                 ncclComm_t comm) {
#if XLA_ENABLE_XCCL
  int device_ordinal = params.stream->parent()->device_ordinal();
  VLOG(3) << "Performing reduce-scatter from device ordinal: "
          << device_ordinal;

  ncclRedOp_t reduce_op = ToNcclReduction(config_.reduction_kind);

  cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
      params.stream->implementation()->GpuStreamMemberHack());

  int num_participants = 0;
  XLA_CUDA_RETURN_IF_ERROR(ncclCommCount(comm, &num_participants));

  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
  for (size_t i = 0; i < buffers_.size(); ++i) {
    const Buffer& buffer = buffers_[i];
    const void* send_buffer =
        params.buffer_allocations->GetDeviceAddress(buffer.source_buffer)
            .opaque();
    void* recv_buffer =
        params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer)
            .opaque();

    PrimitiveType element_type = config_.config.operand_element_type[i];
    TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
                        ToNcclDataTypeAndCountMultiplier(element_type));
    ncclDataType_t dtype = dtype_and_multiplier.first;
    int element_count = buffer.element_count * dtype_and_multiplier.second;

    // buffer.element_count is the source buffer's element count. For
    // ncclReduceScatter, we need the destination buffer's element count.
    TF_RET_CHECK(element_count % num_participants == 0)
        << "Source buffer was not an exact multiple of the number of "
           "participants.";

    int64_t recv_count = element_count / num_participants;
    VLOG(3) << absl::StreamFormat(
        "Calling ncclReduceScatter(send_buffer=%p, recv_buffer=%p, "
        "recvcount=%d, comm=%p, stream=%p)",
        send_buffer, recv_buffer, recv_count, static_cast<const void*>(comm),
        cu_stream);
    XLA_CUDA_RETURN_IF_ERROR(ncclReduceScatter(send_buffer, recv_buffer,
                                               recv_count, dtype, reduce_op,
                                               comm, *cu_stream));
  }
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());

  VLOG(3) << "Done performing reduce-scatter for ordinal: " << device_ordinal;
  return Status::OK();
#else   // XLA_ENABLE_XCCL
  return Unimplemented(
      "NCCL support is not available: this binary was not built with a CUDA "
      "compiler, which is necessary to build the NCCL source library.");
#endif  // XLA_ENABLE_XCCL
}

}  // namespace gpu
}  // namespace xla