/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h"

#include <chrono>  // NOLINT (required by TF interfaces)
#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {
namespace gpu {

static NcclAllToAllConfig GetNcclAllToAllConfig(mlir::lmhlo::AllToAllOp op,
                                                int64 replica_count) {
  NcclAllToAllConfig config;
  config.config = GetNcclCollectiveConfigForMlir(op, replica_count);
  config.has_split_dimension = op.split_dimension().hasValue();
  return config;
}

/*static*/ bool NcclAllToAllThunk::CanImplement(mlir::lmhlo::AllToAllOp op) {
  bool operands_are_supported =
      absl::c_all_of(op.operands(), [](mlir::Value operand) {
        Shape shape = TypeToShape(operand.getType());
        return LayoutUtil::IsDenseArray(shape) &&
               IsTypeSupportedByNccl(shape.element_type());
      });
  return op.split_dimension().getValueOr(0) == 0 && operands_are_supported;
}

NcclAllToAllThunk::NcclAllToAllThunk(
    ThunkInfo thunk_info, mlir::lmhlo::AllToAllOp op, int64 replica_count,
    std::vector<NcclAllToAllThunk::Buffer> buffers)
    : NcclCollectiveThunk(Thunk::kNcclAllToAll, thunk_info),
      config_(GetNcclAllToAllConfig(op, replica_count)),
      buffers_(std::move(buffers)) {
  CHECK_EQ(config_.config.operand_count, buffers_.size());
}

Status NcclAllToAllThunk::RunNcclCollective(const ExecuteParams& params,
                                            ncclComm_t comm) {
#if XLA_ENABLE_XCCL
  int device_ordinal = params.stream->parent()->device_ordinal();
  VLOG(3) << "Performing all-to-all from device ordinal: " << device_ordinal;

  cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
      params.stream->implementation()->GpuStreamMemberHack());

  int num_participants;
  XLA_CUDA_RETURN_IF_ERROR(ncclCommCount(comm, &num_participants));

  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
  // AllToAll can operate in two modes. Either it specifies a split dimension,
  // in which case inputs are split and outputs concatenated in that dimension
  // (here, we only support dimension 0), or it takes a list of inputs
  // and produces a tuple of outputs.
  if (config_.has_split_dimension) {
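    // In split-dimension mode each input is divided into num_participants
    // equally sized contiguous chunks. Inside the ncclGroupStart/ncclGroupEnd
    // group below, chunk r of the input is sent to rank r and the chunk
    // received from rank r is written at offset r of the output. For example,
    // with two participants holding [a0, a1] and [b0, b1], the results are
    // [a0, b0] and [a1, b1].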
    for (size_t i = 0; i < buffers_.size(); ++i) {
      const Buffer& buffer = buffers_[i];
      const uint8* send_buffer = static_cast<uint8*>(
          params.buffer_allocations->GetDeviceAddress(buffer.source_buffer)
              .opaque());
      uint8* recv_buffer = static_cast<uint8*>(
          params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer)
              .opaque());

      PrimitiveType element_type = config_.config.operand_element_type[i];
      TF_ASSIGN_OR_RETURN(ncclDataType_t datatype,
                          ToNcclDataType(element_type));

      TF_RET_CHECK(buffer.element_count % num_participants == 0)
          << "Buffer was not an exact multiple of the number of participants.";
      size_t chunk_elements = buffer.element_count / num_participants;
      size_t chunk_bytes =
          chunk_elements * ShapeUtil::ByteSizeOfPrimitiveType(element_type);

      for (int rank = 0; rank < num_participants; ++rank) {
        XLA_CUDA_RETURN_IF_ERROR(ncclSend(send_buffer + rank * chunk_bytes,
                                          chunk_elements, datatype, rank, comm,
                                          *cu_stream));
        XLA_CUDA_RETURN_IF_ERROR(ncclRecv(recv_buffer + rank * chunk_bytes,
                                          chunk_elements, datatype, rank, comm,
                                          *cu_stream));
      }
    }
  } else {
    TF_RET_CHECK(buffers_.size() == num_participants)
        << "Number of inputs didn't match the number of participants.";

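    // In tuple mode operand i holds the data destined for participant i and
    // the i-th result receives the operand sent by participant i, so each
    // send/recv pair below exchanges one whole buffer with rank i.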
    for (size_t i = 0; i < buffers_.size(); ++i) {
      const Buffer& buffer = buffers_[i];
      const uint8* send_buffer = static_cast<uint8*>(
          params.buffer_allocations->GetDeviceAddress(buffer.source_buffer)
              .opaque());
      uint8* recv_buffer = static_cast<uint8*>(
          params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer)
              .opaque());

      PrimitiveType element_type = config_.config.operand_element_type[i];
      TF_ASSIGN_OR_RETURN(ncclDataType_t datatype,
                          ToNcclDataType(element_type));

      XLA_CUDA_RETURN_IF_ERROR(ncclSend(send_buffer, buffer.element_count,
                                        datatype, /*rank=*/i, comm,
                                        *cu_stream));
      XLA_CUDA_RETURN_IF_ERROR(ncclRecv(recv_buffer, buffer.element_count,
                                        datatype, /*rank=*/i, comm,
                                        *cu_stream));
    }
  }
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());

  VLOG(3) << "Done performing all-to-all for ordinal: " << device_ordinal;
  return Status::OK();
#else   // XLA_ENABLE_XCCL
  return Unimplemented(
      "NCCL support is not available: this binary was not built with a CUDA "
      "compiler, which is necessary to build the NCCL source library.");
#endif  // XLA_ENABLE_XCCL
}

}  // namespace gpu
}  // namespace xla