/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"

#include <chrono>  // NOLINT (required by TF interfaces)
#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_format.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace gpu {

// Attempts to match computation to one of the possible cases in ReductionKind.
static absl::optional<ReductionKind> MatchReductionComputation(
    mlir::lmhlo::AllReduceOp op) {
  mlir::Block& block = op.computation().front();
  if (!llvm::hasSingleElement(block.without_terminator())) return absl::nullopt;
  // The single operation should use both block arguments and produce a single
  // result (all of the same type).
  mlir::Operation* reduction_op = &block.front();
  if (reduction_op->getNumOperands() != 2 || reduction_op->getNumResults() != 1)
    return absl::nullopt;
  mlir::BlockArgument arg0 =
      reduction_op->getOperand(0).dyn_cast<mlir::BlockArgument>();
  mlir::BlockArgument arg1 =
      reduction_op->getOperand(1).dyn_cast<mlir::BlockArgument>();
  mlir::OpResult result = reduction_op->getResult(0);
  // Both operands should be block arguments of the reduction computation block
  // and be different arguments of that block.
  if (!arg0 || !arg1 || arg0.getOwner() != &block ||
      arg1.getOwner() != &block || arg0 == arg1 ||
      arg0.getType() != arg1.getType() || arg0.getType() != result.getType())
    return absl::nullopt;
  StatusOr<HloOpcode> opcode = MhloToHloOpcode(reduction_op);
  if (!opcode.ok()) return absl::nullopt;
  // Match the operation to a reduction kind. We can represent and/or of pred as
  // min/max. This works because pred is stored as an 8-bit int of value 0 or 1.
  PrimitiveType type = TypeToShape(result.getType()).element_type();
  if (type == PRED) {
    switch (opcode.ValueOrDie()) {
      case HloOpcode::kAnd:
        return ReductionKind::MIN;
      case HloOpcode::kOr:
        return ReductionKind::MAX;
      default:
        return absl::nullopt;
    }
  } else {
    switch (opcode.ValueOrDie()) {
      case HloOpcode::kAdd:
        return ReductionKind::SUM;
      case HloOpcode::kMultiply:
        return ReductionKind::PRODUCT;
      case HloOpcode::kMaximum:
        return ReductionKind::MAX;
      case HloOpcode::kMinimum:
        return ReductionKind::MIN;
      default:
        return absl::nullopt;
    }
  }
}

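// Builds the NcclAllReduceConfig for an all-reduce op: derives the generic
// collective config from the MLIR op and replica count, and records the
// reduction kind matched from the op's reduction computation (which must
// match; see MatchReductionComputation).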
static NcclAllReduceConfig GetNcclAllReduceConfig(mlir::lmhlo::AllReduceOp op,
                                                  int64 replica_count) {
  auto reduction_kind = MatchReductionComputation(op);
  CHECK(reduction_kind.has_value());

  NcclAllReduceConfig config;
  config.config = GetNcclCollectiveConfigForMlir(op, replica_count);
  config.reduction_kind = *reduction_kind;
  return config;
}

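// An all-reduce is implementable by this thunk if every operand is a dense
// array whose element type is supported by NCCL and the reduction computation
// matches one of the supported ReductionKinds.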
/*static*/ bool NcclAllReduceThunk::CanImplement(mlir::lmhlo::AllReduceOp op) {
  bool operands_are_supported =
      absl::c_all_of(op.operands(), [](mlir::Value operand) {
        Shape shape = TypeToShape(operand.getType());
        return LayoutUtil::IsDenseArray(shape) &&
               IsTypeSupportedByNccl(shape.element_type());
      });
  return operands_are_supported && MatchReductionComputation(op).has_value();
}

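// The constructor captures the all-reduce configuration and the per-operand
// source/destination buffers; the number of buffers must equal the operand
// count recorded in the config.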
NcclAllReduceThunk::NcclAllReduceThunk(
    ThunkInfo thunk_info, mlir::lmhlo::AllReduceOp op, int64 replica_count,
    std::vector<NcclAllReduceThunk::Buffer> buffers)
    : NcclCollectiveThunk(Thunk::kNcclAllReduce, thunk_info),
      config_(GetNcclAllReduceConfig(op, replica_count)),
      buffers_(std::move(buffers)) {
  CHECK_EQ(config_.config.operand_count, buffers_.size());
}

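// Issues one ncclAllReduce per buffer on the thunk's stream, all inside a
// single ncclGroupStart/ncclGroupEnd group. Only available when the binary is
// built with NCCL support (XLA_ENABLE_XCCL); otherwise returns Unimplemented.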
Status NcclAllReduceThunk::RunNcclCollective(const ExecuteParams& params,
                                             ncclComm_t comm) {
#if XLA_ENABLE_XCCL
  int device_ordinal = params.stream->parent()->device_ordinal();
  VLOG(3) << "Performing all-reduce from device ordinal: " << device_ordinal;

  ncclRedOp_t reduce_op = ToNcclReduction(config_.reduction_kind);

  cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
      params.stream->implementation()->GpuStreamMemberHack());

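  // Aggregate the per-buffer all-reduce calls into a single NCCL group so
  // they are launched together.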
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
  for (size_t i = 0; i < buffers_.size(); ++i) {
    const Buffer& buffer = buffers_[i];
    const void* send_buffer =
        params.buffer_allocations->GetDeviceAddress(buffer.source_buffer)
            .opaque();
    void* recv_buffer =
        params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer)
            .opaque();

    TF_ASSIGN_OR_RETURN(ncclDataType_t datatype,
                        ToNcclDataType(config_.config.operand_element_type[i]));

    VLOG(3) << absl::StreamFormat(
        "Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, "
        "comm=%p, stream=%p)",
        send_buffer, recv_buffer, buffer.element_count,
        static_cast<const void*>(comm), cu_stream);

    XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer,
                                           buffer.element_count, datatype,
                                           reduce_op, comm, *cu_stream));
  }
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());

  VLOG(3) << "Done performing all-reduce for ordinal: " << device_ordinal;
  return Status::OK();
#else   // XLA_ENABLE_XCCL
  return Unimplemented(
      "NCCL support is not available: this binary was not built with a CUDA "
      "compiler, which is necessary to build the NCCL source library.");
#endif  // XLA_ENABLE_XCCL
}

}  // namespace gpu
}  // namespace xla