/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"

#include <chrono>  // NOLINT (required by TF interfaces)
#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_format.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace gpu {

// Attempts to match computation to one of the possible cases in ReductionKind.
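// For illustration only: a reduction computation that this matcher would map
// to ReductionKind::SUM looks roughly like the following (schematic MLIR, not
// verbatim syntax; the element type is a placeholder):
//   ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
//     %sum = mhlo.add %lhs, %rhs : tensor<f32>
//     mhlo.return %sum : tensor<f32>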
static absl::optional<ReductionKind> MatchReductionComputation(
    mlir::lmhlo::AllReduceOp op) {
  mlir::Block& block = op.computation().front();
  if (!llvm::hasSingleElement(block.without_terminator())) return absl::nullopt;
  // The single operation should use both block arguments and produce a single
  // result (all of the same type).
  mlir::Operation* reduction_op = &block.front();
  if (reduction_op->getNumOperands() != 2 || reduction_op->getNumResults() != 1)
    return absl::nullopt;
  mlir::BlockArgument arg0 =
      reduction_op->getOperand(0).dyn_cast<mlir::BlockArgument>();
  mlir::BlockArgument arg1 =
      reduction_op->getOperand(1).dyn_cast<mlir::BlockArgument>();
  mlir::OpResult result = reduction_op->getResult(0);
  // Both operands should be block arguments of the reduction computation block
  // and be different arguments of that block.
  if (!arg0 || !arg1 || arg0.getOwner() != &block ||
      arg1.getOwner() != &block || arg0 == arg1 ||
      arg0.getType() != arg1.getType() || arg0.getType() != result.getType())
    return absl::nullopt;
  StatusOr<HloOpcode> opcode = MhloToHloOpcode(reduction_op);
  if (!opcode.ok()) return absl::nullopt;
  // Match the operation to a reduction kind. We can represent and/or of pred as
  // min/max. This works because pred is stored as an 8-bit int of value 0 or 1.
  PrimitiveType type = TypeToShape(result.getType()).element_type();
  if (type == PRED) {
    switch (opcode.ValueOrDie()) {
      case HloOpcode::kAnd:
        return ReductionKind::MIN;
      case HloOpcode::kOr:
        return ReductionKind::MAX;
      default:
        return absl::nullopt;
    }
  } else {
    switch (opcode.ValueOrDie()) {
      case HloOpcode::kAdd:
        return ReductionKind::SUM;
      case HloOpcode::kMultiply:
        return ReductionKind::PRODUCT;
      case HloOpcode::kMaximum:
        return ReductionKind::MAX;
      case HloOpcode::kMinimum:
        return ReductionKind::MIN;
      default:
        return absl::nullopt;
    }
  }
}

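// Builds the NcclAllReduceConfig for `op`. The reduction computation is
// expected to have already been validated (see CanImplement below), so a
// failed match here is a logic error.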
static NcclAllReduceConfig GetNcclAllReduceConfig(mlir::lmhlo::AllReduceOp op,
                                                  int64 replica_count) {
  auto reduction_kind = MatchReductionComputation(op);
  CHECK(reduction_kind.has_value());

  NcclAllReduceConfig config;
  config.config = GetNcclCollectiveConfigForMlir(op, replica_count);
  config.reduction_kind = *reduction_kind;
  return config;
}

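// An all-reduce is implementable by this thunk if every operand is a dense
// array with an element type NCCL supports and the reduction computation
// matches one of the ReductionKind cases above.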
/*static*/ bool NcclAllReduceThunk::CanImplement(mlir::lmhlo::AllReduceOp op) {
  bool operands_are_supported =
      absl::c_all_of(op.operands(), [](mlir::Value operand) {
        Shape shape = TypeToShape(operand.getType());
        return LayoutUtil::IsDenseArray(shape) &&
               IsTypeSupportedByNccl(shape.element_type());
      });
  return operands_are_supported && MatchReductionComputation(op).has_value();
}

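// The thunk takes one Buffer per operand of the all-reduce; the CHECK below
// verifies that the caller passed exactly that many.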
NcclAllReduceThunk::NcclAllReduceThunk(
    ThunkInfo thunk_info, mlir::lmhlo::AllReduceOp op, int64 replica_count,
    std::vector<NcclAllReduceThunk::Buffer> buffers)
    : NcclCollectiveThunk(Thunk::kNcclAllReduce, thunk_info),
      config_(GetNcclAllReduceConfig(op, replica_count)),
      buffers_(std::move(buffers)) {
  CHECK_EQ(config_.config.operand_count, buffers_.size());
}

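// Runs the all-reduce: issues one ncclAllReduce per buffer on the stream from
// `params`, with the calls wrapped in ncclGroupStart/ncclGroupEnd so NCCL can
// launch them together.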
Status NcclAllReduceThunk::RunNcclCollective(const ExecuteParams& params,
                                             ncclComm_t comm) {
#if XLA_ENABLE_XCCL
  int device_ordinal = params.stream->parent()->device_ordinal();
  VLOG(3) << "Performing all-reduce from device ordinal: " << device_ordinal;

  ncclRedOp_t reduce_op = ToNcclReduction(config_.reduction_kind);

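  // NCCL expects a raw cudaStream_t, so unwrap the CUDA stream that backs the
  // StreamExecutor stream.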
  cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
      params.stream->implementation()->GpuStreamMemberHack());

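  // Batch the per-buffer all-reduce calls into a single NCCL group so NCCL can
  // aggregate their launches.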
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
  for (size_t i = 0; i < buffers_.size(); ++i) {
    const Buffer& buffer = buffers_[i];
    const void* send_buffer =
        params.buffer_allocations->GetDeviceAddress(buffer.source_buffer)
            .opaque();
    void* recv_buffer =
        params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer)
            .opaque();

    TF_ASSIGN_OR_RETURN(ncclDataType_t datatype,
                        ToNcclDataType(config_.config.operand_element_type[i]));

    VLOG(3) << absl::StreamFormat(
        "Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, "
        "comm=%p, stream=%p)",
        send_buffer, recv_buffer, buffer.element_count,
        static_cast<const void*>(comm), cu_stream);

    XLA_CUDA_RETURN_IF_ERROR(ncclAllReduce(send_buffer, recv_buffer,
                                           buffer.element_count, datatype,
                                           reduce_op, comm, *cu_stream));
  }
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());

  VLOG(3) << "Done performing all-reduce for ordinal: " << device_ordinal;
  return Status::OK();
#else   // XLA_ENABLE_XCCL
  return Unimplemented(
      "NCCL support is not available: this binary was not built with a CUDA "
      "compiler, which is necessary to build the NCCL source library.");
#endif  // XLA_ENABLE_XCCL
}

}  // namespace gpu
}  // namespace xla