/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"

#include <chrono>  // NOLINT (required by TF interfaces)
#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_format.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace gpu {
namespace {

Status RunAllReduce(const NcclAllReduceConfig& config,
                    const std::vector<NcclCollectiveThunk::Buffer>& buffers,
                    const BufferAllocations& buffer_allocations,
                    se::Stream& stream, ncclComm_t comm) {
#if XLA_ENABLE_XCCL
  int device_ordinal = stream.parent()->device_ordinal();
  VLOG(3) << "Performing all-reduce from device ordinal: " << device_ordinal;

  ncclRedOp_t reduce_op = ToNcclReduction(config.reduction_kind);

  cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
      stream.implementation()->GpuStreamMemberHack());

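  // ncclGroupStart()/ncclGroupEnd() batch the per-buffer ncclAllReduce calls
  // below so NCCL can aggregate them instead of launching them one by one.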
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
  for (size_t i = 0; i < buffers.size(); ++i) {
    const NcclCollectiveThunk::Buffer& buffer = buffers[i];
    const void* send_buffer =
        buffer_allocations.GetDeviceAddress(buffer.source_buffer).opaque();
    void* recv_buffer =
        buffer_allocations.GetDeviceAddress(buffer.destination_buffer).opaque();

    PrimitiveType element_type = config.config.operand_element_type[i];
    TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
                        ToNcclDataTypeAndCountMultiplier(element_type));
    ncclDataType_t dtype = dtype_and_multiplier.first;
    int element_count = buffer.element_count * dtype_and_multiplier.second;
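    // The count multiplier accounts for element types that NCCL has no direct
    // representation for (e.g. complex values, which are reduced as pairs of
    // real components), so `element_count` is expressed in NCCL elements.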

    VLOG(3) << absl::StreamFormat(
        "Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, "
        "comm=%p, stream=%p)",
        send_buffer, recv_buffer, element_count, static_cast<const void*>(comm),
        cu_stream);

    XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer,
                                           element_count, dtype, reduce_op,
                                           comm, *cu_stream));
  }
  return XLA_CUDA_STATUS(ncclGroupEnd());
#else   // XLA_ENABLE_XCCL
  return Unimplemented(
      "NCCL support is not available: this binary was not built with a CUDA "
      "compiler, which is necessary to build the NCCL source library.");
#endif  // XLA_ENABLE_XCCL
}

bool IsValidOperand(mlir::Value operand) {
  Shape shape = TypeToShape(operand.getType());
  return LayoutUtil::IsDenseArray(shape) &&
         IsTypeSupportedByNccl(shape.element_type());
}

// Generally, the reduction op should be the only operation in the block, except
// the terminator. However, if the type is bf16, the `BFloat16Normalization`
// pass will have converted the op to float32 and added type conversions.
// TODO(cjfj): Can we prevent the bf16 conversion for this computation?
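//
// For illustration only (exact MLIR syntax may differ), the two region shapes
// this function expects look roughly like:
//
//   Standard case:
//     ^bb0(%a: tensor<f32>, %b: tensor<f32>):
//       %sum = mhlo.add %a, %b : tensor<f32>
//       mhlo.return %sum : tensor<f32>
//
//   bf16 case (after `BFloat16Normalization`):
//     ^bb0(%a: tensor<bf16>, %b: tensor<bf16>):
//       %a32 = mhlo.convert %a : tensor<bf16> to tensor<f32>
//       %b32 = mhlo.convert %b : tensor<bf16> to tensor<f32>
//       %sum = mhlo.add %a32, %b32 : tensor<f32>
//       %out = mhlo.convert %sum : tensor<f32> to tensor<bf16>
//       mhlo.return %out : tensor<bf16>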
StatusOr<mlir::Operation*> FindReductionOp(mlir::Block& block) {
  TF_RET_CHECK(block.getNumArguments() == 2);
  mlir::Operation* terminator = block.getTerminator();
  TF_RET_CHECK(terminator);
  TF_RET_CHECK(terminator->getNumOperands() == 1);
  mlir::Value result = terminator->getOperand(0);
  TF_RET_CHECK(block.getArgument(0).getType() == result.getType());
  TF_RET_CHECK(block.getArgument(1).getType() == result.getType());

  mlir::Operation* result_op = result.getDefiningOp();
  TF_RET_CHECK(result_op);

  // In the bf16 case, the type conversions and op might be fused.
  if (mlir::isa<mlir::mhlo::FusionOp>(result_op)) {
    return FindReductionOp(result_op->getRegion(0).front());
  }

  // Standard case.
  if (absl::c_is_permutation(result_op->getOperands(), block.getArguments())) {
    return result_op;
  }

  // bf16 case.
  TF_RET_CHECK(mlir::isa<mlir::mhlo::ConvertOp>(result_op));
  TF_RET_CHECK(result_op->getNumOperands() == 1);
  mlir::Operation* reduction_op = result_op->getOperand(0).getDefiningOp();
  TF_RET_CHECK(reduction_op);
  TF_RET_CHECK(reduction_op->getNumOperands() == 2);
  mlir::Value operand0 = reduction_op->getOperand(0);
  mlir::Value operand1 = reduction_op->getOperand(1);
  auto operand0_op = operand0.getDefiningOp<mlir::mhlo::ConvertOp>();
  auto operand1_op = operand1.getDefiningOp<mlir::mhlo::ConvertOp>();
  TF_RET_CHECK(operand0_op);
  TF_RET_CHECK(operand1_op);
  TF_RET_CHECK(operand0_op->getNumOperands() == 1);
  TF_RET_CHECK(operand1_op->getNumOperands() == 1);
  std::array<mlir::Value, 2> operands{operand0_op->getOperand(0),
                                      operand1_op->getOperand(0)};
  TF_RET_CHECK(absl::c_is_permutation(operands, block.getArguments()));
  return reduction_op;
}

}  // namespace

namespace impl {

template <typename OpT>
bool CanImplement(OpT op) {
  return absl::c_all_of(op.operands(), IsValidOperand) &&
         NcclAllReduceThunkBase::MatchAllReduceComputation(op.computation())
             .has_value();
}

template <typename OpT>
NcclAllReduceConfig GetNcclAllReduceConfig(OpT op) {
  absl::optional<ReductionKind> reduction_kind =
      NcclAllReduceThunkBase::MatchAllReduceComputation(op.computation());
  CHECK(reduction_kind.has_value());

  NcclAllReduceConfig config;
  config.config =
      GetNcclCollectiveConfigForMlir(op, op.use_global_device_ids());
  config.reduction_kind = *reduction_kind;
  return config;
}

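// A collective is "degenerate" when, for the given replica and partition
// counts, every participating group contains a single device, so the
// all-reduce is effectively a copy of its input.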
template <typename OpT>
bool IsDegenerate(OpT op, int64_t replica_count, int64_t partition_count) {
  return GetNcclCollectiveConfigForMlir(op, op.use_global_device_ids())
      .IsDegenerate(replica_count, partition_count);
}

template <typename OpT>
CollectiveOpGroupMode GetGroupMode(OpT op) {
  return GetNcclAllReduceConfig(op).config.group_mode;
}

}  // namespace impl

absl::optional<ReductionKind> NcclAllReduceThunkBase::MatchAllReduceComputation(
    mlir::Region& computation) {
  mlir::Block& block = computation.front();
  StatusOr<mlir::Operation*> reduction_op = FindReductionOp(block);
  if (!reduction_op.ok()) return absl::nullopt;
  StatusOr<HloOpcode> opcode = MhloToHloOpcode(*reduction_op);
  if (!opcode.ok()) return absl::nullopt;
  // Match the operation to a reduction kind. We can represent and/or of pred as
  // min/max. This works because pred is stored as an 8-bit int of value 0 or 1.
  PrimitiveType type =
      TypeToShape(block.getArgument(0).getType()).element_type();
  if (type == PRED) {
    switch (opcode.ValueOrDie()) {
      case HloOpcode::kAnd:
        return ReductionKind::MIN;
      case HloOpcode::kOr:
        return ReductionKind::MAX;
      default:
        return absl::nullopt;
    }
  } else if (primitive_util::IsComplexType(type)) {
    // Only addition is supported for complex types.
    if (*opcode == HloOpcode::kAdd) {
      return ReductionKind::SUM;
    } else {
      return absl::nullopt;
    }
  } else {
    switch (*opcode) {
      case HloOpcode::kAdd:
        return ReductionKind::SUM;
      case HloOpcode::kMultiply:
        return ReductionKind::PRODUCT;
      case HloOpcode::kMaximum:
        return ReductionKind::MAX;
      case HloOpcode::kMinimum:
        return ReductionKind::MIN;
      default:
        return absl::nullopt;
    }
  }
}

NcclAllReduceThunkBase::NcclAllReduceThunkBase(Thunk::Kind kind,
                                               ThunkInfo thunk_info,
                                               NcclAllReduceConfig config,
                                               std::vector<Buffer> buffers)
    : NcclCollectiveThunk(kind, thunk_info),
      config_(std::move(config)),
      buffers_(std::move(buffers)) {
  CHECK_EQ(config_.config.operand_count, buffers_.size());
}

NcclAllReduceThunk::NcclAllReduceThunk(ThunkInfo thunk_info,
                                       mlir::lmhlo::AllReduceOp op,
                                       std::vector<Buffer> buffers)
    : NcclAllReduceThunkBase(Thunk::kNcclAllReduce, thunk_info,
                             impl::GetNcclAllReduceConfig(op), buffers) {}

bool NcclAllReduceThunk::CanImplement(mlir::lmhlo::AllReduceOp op) {
  return impl::CanImplement(op);
}

bool NcclAllReduceThunk::IsDegenerate(mlir::lmhlo::AllReduceOp op,
                                      int64_t replica_count,
                                      int64_t partition_count) {
  return impl::IsDegenerate(op, replica_count, partition_count);
}

CollectiveOpGroupMode NcclAllReduceThunk::GetGroupMode(
    mlir::lmhlo::AllReduceOp op) {
  return impl::GetGroupMode(op);
}

Status NcclAllReduceThunk::RunNcclCollective(const ExecuteParams& params,
                                             ncclComm_t comm) {
  se::Stream& stream = *params.stream;
  TF_RETURN_IF_ERROR(RunAllReduce(config_, buffers_, *params.buffer_allocations,
                                  stream, comm));

  int device_ordinal = stream.parent()->device_ordinal();
  VLOG(3) << "Done performing all-reduce for ordinal: " << device_ordinal;
  return Status::OK();
}

NcclAllReduceStartThunk::NcclAllReduceStartThunk(
    ThunkInfo thunk_info, mlir::lmhlo_gpu::AllReduceStartOp op,
    std::vector<Buffer> buffers)
    : NcclAllReduceThunkBase(Thunk::kNcclAllReduceStart, thunk_info,
                             impl::GetNcclAllReduceConfig(op), buffers) {}

bool NcclAllReduceStartThunk::CanImplement(
    mlir::lmhlo_gpu::AllReduceStartOp op) {
  return impl::CanImplement(op);
}

bool NcclAllReduceStartThunk::IsDegenerate(mlir::lmhlo_gpu::AllReduceStartOp op,
                                           int64_t replica_count,
                                           int64_t partition_count) {
  return impl::IsDegenerate(op, replica_count, partition_count);
}

CollectiveOpGroupMode NcclAllReduceStartThunk::GetGroupMode(
    mlir::lmhlo_gpu::AllReduceStartOp op) {
  return impl::GetGroupMode(op);
}

Status NcclAllReduceStartThunk::RunNcclCollective(const ExecuteParams& params,
                                                  ncclComm_t comm) {
  se::Stream& async_comms_stream = *params.async_comms_stream;
  // Wait until compute inputs are ready.
  async_comms_stream.ThenWaitFor(params.stream);

  TF_RETURN_IF_ERROR(RunAllReduce(config_, buffers_, *params.buffer_allocations,
                                  async_comms_stream, comm));

  // Create an event on the async stream for the completion of the all-reduce.
  se::Event done_event(async_comms_stream.parent());
  TF_RET_CHECK(done_event.Init());
  async_comms_stream.ThenRecordEvent(&done_event);

  int device_ordinal = async_comms_stream.parent()->device_ordinal();

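  // Stash the completion event, keyed by device ordinal. The matching
  // NcclAllReduceDoneThunk consumes it via TakeDoneEvent() and makes the main
  // stream wait on it.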
  {
    absl::MutexLock lock(&mu_);
    auto result = done_events_.emplace(device_ordinal, std::move(done_event));
    TF_RET_CHECK(result.second) << "done event has not been consumed";
  }

  VLOG(3) << "Done performing all-reduce-start for ordinal: " << device_ordinal;
  return Status::OK();
}

StatusOr<se::Event> NcclAllReduceStartThunk::TakeDoneEvent(int device_ordinal) {
  absl::MutexLock lock(&mu_);
  auto it = done_events_.find(device_ordinal);
  TF_RET_CHECK(it != done_events_.end()) << "done event not found";
  // Take ownership of the event.
  se::Event done_event = std::move(it->second);
  done_events_.erase(it);
  return done_event;
}

NcclAllReduceDoneThunk::NcclAllReduceDoneThunk(
    ThunkInfo thunk_info, NcclAllReduceStartThunk& start_thunk)
    : Thunk(Thunk::kNcclAllReduceDone, thunk_info), start_thunk_(start_thunk) {}

Status NcclAllReduceDoneThunk::ExecuteOnStream(const ExecuteParams& params) {
  int device_ordinal = params.stream->parent()->device_ordinal();
  TF_ASSIGN_OR_RETURN(se::Event done_event,
                      start_thunk_.TakeDoneEvent(device_ordinal));
  params.stream->ThenWaitFor(&done_event);
  return Status::OK();
}

NcclReduceScatterThunk::NcclReduceScatterThunk(
    ThunkInfo thunk_info, mlir::lmhlo::ReduceScatterOp op,
    std::vector<NcclAllReduceThunk::Buffer> buffers)
    : NcclAllReduceThunkBase(Thunk::kNcclReduceScatter, thunk_info,
                             impl::GetNcclAllReduceConfig(op),
                             std::move(buffers)) {}

/*static*/ bool NcclReduceScatterThunk::CanImplement(
    mlir::lmhlo::ReduceScatterOp op) {
  return impl::CanImplement(op);
}

/*static*/ bool NcclReduceScatterThunk::IsDegenerate(
    mlir::lmhlo::ReduceScatterOp op, int64_t replica_count,
    int64_t partition_count) {
  return impl::IsDegenerate(op, replica_count, partition_count);
}

/*static*/ CollectiveOpGroupMode NcclReduceScatterThunk::GetGroupMode(
    mlir::lmhlo::ReduceScatterOp op) {
  return impl::GetGroupMode(op);
}

Status NcclReduceScatterThunk::RunNcclCollective(const ExecuteParams& params,
                                                 ncclComm_t comm) {
#if XLA_ENABLE_XCCL
  int device_ordinal = params.stream->parent()->device_ordinal();
  VLOG(3) << "Performing reduce-scatter from device ordinal: "
          << device_ordinal;

  ncclRedOp_t reduce_op = ToNcclReduction(config_.reduction_kind);

  cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
      params.stream->implementation()->GpuStreamMemberHack());

  int num_participants = 0;
  XLA_CUDA_RETURN_IF_ERROR(ncclCommCount(comm, &num_participants));

  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
  for (size_t i = 0; i < buffers_.size(); ++i) {
    const Buffer& buffer = buffers_[i];
    const void* send_buffer =
        params.buffer_allocations->GetDeviceAddress(buffer.source_buffer)
            .opaque();
    void* recv_buffer =
        params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer)
            .opaque();

    PrimitiveType element_type = config_.config.operand_element_type[i];
    TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
                        ToNcclDataTypeAndCountMultiplier(element_type));
    ncclDataType_t dtype = dtype_and_multiplier.first;
    int element_count = buffer.element_count * dtype_and_multiplier.second;

    // buffer.element_count is the source buffer's element count. For
    // ncclReduceScatter, we need the destination buffer's element count.
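    // For example, with 8 participants and a 1024-element source buffer, each
    // participant receives 1024 / 8 = 128 elements.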
    TF_RET_CHECK(element_count % num_participants == 0)
        << "Source buffer was not an exact multiple of the number of "
           "participants.";

    int64_t recv_count = element_count / num_participants;
    VLOG(3) << absl::StreamFormat(
        "Calling ncclReduceScatter(send_buffer=%p, recv_buffer=%p, "
        "recvcount=%d, "
        "comm=%p, stream=%p)",
        send_buffer, recv_buffer, recv_count, static_cast<const void*>(comm),
        cu_stream);
    XLA_CUDA_RETURN_IF_ERROR(ncclReduceScatter(send_buffer, recv_buffer,
                                               recv_count, dtype, reduce_op,
                                               comm, *cu_stream));
  }
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());

  VLOG(3) << "Done performing reduce-scatter for ordinal: " << device_ordinal;
  return Status::OK();
#else   // XLA_ENABLE_XCCL
  return Unimplemented(
      "NCCL support is not available: this binary was not built with a CUDA "
      "compiler, which is necessary to build the NCCL source library.");
#endif  // XLA_ENABLE_XCCL
}

}  // namespace gpu
}  // namespace xla