1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h"
17
18 #include <chrono> // NOLINT (required by TF interfaces)
19 #include <cstdlib>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/strings/str_format.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/compiler/xla/layout_util.h"
28 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
29 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
30 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/util.h"
33
34 namespace xla {
35 namespace gpu {
36
GetNcclAllToAllConfig(mlir::lmhlo::AllToAllOp op)37 /*static*/ NcclAllToAllConfig NcclAllToAllThunk::GetNcclAllToAllConfig(
38 mlir::lmhlo::AllToAllOp op) {
39 NcclAllToAllConfig config;
40 // FIXME(b/180174349): LMHLO AllToAll incorrectly has use_global_device_ids
41 // attribute and it should be removed.
42 config.config = GetNcclCollectiveConfigForMlir(op, absl::nullopt);
43 config.has_split_dimension = op.split_dimension().hasValue();
44 return config;
45 }
46
CanImplement(mlir::lmhlo::AllToAllOp op)47 /*static*/ bool NcclAllToAllThunk::CanImplement(mlir::lmhlo::AllToAllOp op) {
48 return absl::c_all_of(op.operands(), [&op](mlir::Value operand) {
49 Shape shape = GetShape(operand);
50 return LayoutUtil::IsDenseArray(shape) &&
51 IsTypeSupportedByNccl(shape.element_type()) &&
52 (!op.split_dimension() ||
53 LayoutUtil::MinorToMajor(shape).back() == *op.split_dimension());
54 });
55 }
56
NcclAllToAllThunk(ThunkInfo thunk_info,mlir::lmhlo::AllToAllOp op,std::vector<NcclAllToAllThunk::Buffer> buffers)57 NcclAllToAllThunk::NcclAllToAllThunk(
58 ThunkInfo thunk_info, mlir::lmhlo::AllToAllOp op,
59 std::vector<NcclAllToAllThunk::Buffer> buffers)
60 : NcclCollectiveThunk(Thunk::kNcclAllToAll, thunk_info),
61 config_(GetNcclAllToAllConfig(op)),
62 buffers_(std::move(buffers)) {
63 CHECK_EQ(config_.config.operand_count, buffers_.size());
64 }
65
RunNcclCollective(const ExecuteParams & params,ncclComm_t comm)66 Status NcclAllToAllThunk::RunNcclCollective(const ExecuteParams& params,
67 ncclComm_t comm) {
68 #if XLA_ENABLE_XCCL
69 int device_ordinal = params.stream->parent()->device_ordinal();
70 VLOG(3) << "Performing all-to-all from device ordinal: " << device_ordinal;
71
72 cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
73 params.stream->implementation()->GpuStreamMemberHack());
74
75 int num_participants;
76 XLA_CUDA_RETURN_IF_ERROR(ncclCommCount(comm, &num_participants));
77
78 XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
79 // AllToAll can operate in two modes. Either it specifies a split dimension,
80 // in which case inputs are split and outputs concatenated in that dimension
81 // (here, we only support dimension 0), or it takes a list of inputs
82 // and produces a tuple of outputs.
83 if (config_.has_split_dimension) {
84 for (size_t i = 0; i < buffers_.size(); ++i) {
85 const Buffer& buffer = buffers_[i];
86 const uint8* send_buffer = static_cast<uint8*>(
87 params.buffer_allocations->GetDeviceAddress(buffer.source_buffer)
88 .opaque());
89 uint8* recv_buffer = static_cast<uint8*>(
90 params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer)
91 .opaque());
92
93 PrimitiveType element_type = config_.config.operand_element_type[i];
94 TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
95 ToNcclDataTypeAndCountMultiplier(element_type));
96 ncclDataType_t dtype = dtype_and_multiplier.first;
97 int element_count = buffer.element_count * dtype_and_multiplier.second;
98
99 TF_RET_CHECK(element_count % num_participants == 0)
100 << "Buffer was not an exact multiple of the number of participants.";
101 size_t chunk_elements = element_count / num_participants;
102 size_t chunk_bytes =
103 chunk_elements * ShapeUtil::ByteSizeOfPrimitiveType(element_type);
104
105 for (int rank = 0; rank < num_participants; ++rank) {
106 XLA_CUDA_RETURN_IF_ERROR(ncclSend(send_buffer + rank * chunk_bytes,
107 chunk_elements, dtype, rank, comm,
108 *cu_stream));
109 XLA_CUDA_RETURN_IF_ERROR(ncclRecv(recv_buffer + rank * chunk_bytes,
110 chunk_elements, dtype, rank, comm,
111 *cu_stream));
112 }
113 }
114 } else {
115 TF_RET_CHECK(buffers_.size() == num_participants)
116 << "Number of inputs didn't match the number of participants.";
117
118 for (size_t i = 0; i < buffers_.size(); ++i) {
119 const Buffer& buffer = buffers_[i];
120 const uint8* send_buffer = static_cast<uint8*>(
121 params.buffer_allocations->GetDeviceAddress(buffer.source_buffer)
122 .opaque());
123 uint8* recv_buffer = static_cast<uint8*>(
124 params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer)
125 .opaque());
126
127 PrimitiveType element_type = config_.config.operand_element_type[i];
128 TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
129 ToNcclDataTypeAndCountMultiplier(element_type));
130 ncclDataType_t dtype = dtype_and_multiplier.first;
131 int element_count = buffer.element_count * dtype_and_multiplier.second;
132
133 XLA_CUDA_RETURN_IF_ERROR(ncclSend(send_buffer, element_count, dtype,
134 /*rank=*/i, comm, *cu_stream));
135 XLA_CUDA_RETURN_IF_ERROR(ncclRecv(recv_buffer, element_count, dtype,
136 /*rank=*/i, comm, *cu_stream));
137 }
138 }
139 XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());
140
141 VLOG(3) << "Done performing all-to-all for ordinal: " << device_ordinal;
142 return Status::OK();
143 #else // XLA_ENABLE_XCCL
144 return Unimplemented(
145 "NCCL support is not available: this binary was not built with a CUDA "
146 "compiler, which is necessary to build the NCCL source library.");
147 #endif // XLA_ENABLE_XCCL
148 }
149
150 } // namespace gpu
151 } // namespace xla
152