/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h"

#include <chrono>  // NOLINT (required by TF interfaces)
#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_format.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {
namespace gpu {

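// Pulls out the pieces of the LMHLO AllToAll op that this thunk needs: the
// shared collective config and whether the op carries a split dimension.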
/*static*/ NcclAllToAllConfig NcclAllToAllThunk::GetNcclAllToAllConfig(
    mlir::lmhlo::AllToAllOp op) {
  NcclAllToAllConfig config;
  // FIXME(b/180174349): LMHLO AllToAll incorrectly has use_global_device_ids
  // attribute and it should be removed.
  config.config = GetNcclCollectiveConfigForMlir(op, absl::nullopt);
  config.has_split_dimension = op.split_dimension().hasValue();
  return config;
}

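// An AllToAll op can be implemented with NCCL only if every operand is a
// dense array of an element type NCCL supports and, when a split dimension is
// given, that dimension is the most-major dimension of the operand's layout,
// so that each per-rank chunk is contiguous in memory.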
/*static*/ bool NcclAllToAllThunk::CanImplement(mlir::lmhlo::AllToAllOp op) {
  return absl::c_all_of(op.operands(), [&op](mlir::Value operand) {
    Shape shape = GetShape(operand);
    return LayoutUtil::IsDenseArray(shape) &&
           IsTypeSupportedByNccl(shape.element_type()) &&
           (!op.split_dimension() ||
            LayoutUtil::MinorToMajor(shape).back() == *op.split_dimension());
  });
}

NcclAllToAllThunk::NcclAllToAllThunk(
    ThunkInfo thunk_info, mlir::lmhlo::AllToAllOp op,
    std::vector<NcclAllToAllThunk::Buffer> buffers)
    : NcclCollectiveThunk(Thunk::kNcclAllToAll, thunk_info),
      config_(GetNcclAllToAllConfig(op)),
      buffers_(std::move(buffers)) {
  CHECK_EQ(config_.config.operand_count, buffers_.size());
}

Status NcclAllToAllThunk::RunNcclCollective(const ExecuteParams& params,
                                            ncclComm_t comm) {
#if XLA_ENABLE_XCCL
  int device_ordinal = params.stream->parent()->device_ordinal();
  VLOG(3) << "Performing all-to-all from device ordinal: " << device_ordinal;

  cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
      params.stream->implementation()->GpuStreamMemberHack());

  int num_participants;
  XLA_CUDA_RETURN_IF_ERROR(ncclCommCount(comm, &num_participants));

  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
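  // All ncclSend/ncclRecv calls issued between ncclGroupStart() and
  // ncclGroupEnd() are aggregated by NCCL and launched together, which turns
  // the per-peer send/recv pairs below into a single all-to-all exchange.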
  // AllToAll can operate in two modes. Either it specifies a split dimension,
  // in which case inputs are split and outputs concatenated in that dimension
  // (here, we only support dimension 0), or it takes a list of inputs
  // and produces a tuple of outputs.
  if (config_.has_split_dimension) {
    for (size_t i = 0; i < buffers_.size(); ++i) {
      const Buffer& buffer = buffers_[i];
      const uint8* send_buffer = static_cast<uint8*>(
          params.buffer_allocations->GetDeviceAddress(buffer.source_buffer)
              .opaque());
      uint8* recv_buffer = static_cast<uint8*>(
          params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer)
              .opaque());

      PrimitiveType element_type = config_.config.operand_element_type[i];
      TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
                          ToNcclDataTypeAndCountMultiplier(element_type));
      ncclDataType_t dtype = dtype_and_multiplier.first;
      int element_count = buffer.element_count * dtype_and_multiplier.second;

      TF_RET_CHECK(element_count % num_participants == 0)
          << "Buffer was not an exact multiple of the number of participants.";
      size_t chunk_elements = element_count / num_participants;
      size_t chunk_bytes =
          chunk_elements * ShapeUtil::ByteSizeOfPrimitiveType(element_type);

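      // Peer `rank` gets the rank-th contiguous chunk of this input, and the
      // chunk it sends back lands in the rank-th slot of the output.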
      for (int rank = 0; rank < num_participants; ++rank) {
        XLA_CUDA_RETURN_IF_ERROR(ncclSend(send_buffer + rank * chunk_bytes,
                                          chunk_elements, dtype, rank, comm,
                                          *cu_stream));
        XLA_CUDA_RETURN_IF_ERROR(ncclRecv(recv_buffer + rank * chunk_bytes,
                                          chunk_elements, dtype, rank, comm,
                                          *cu_stream));
      }
    }
  } else {
    TF_RET_CHECK(buffers_.size() == num_participants)
        << "Number of inputs didn't match the number of participants.";

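    // In the tuple form there is one buffer per participant: input i is sent
    // to rank i, and the data received from rank i is written to output i.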
    for (size_t i = 0; i < buffers_.size(); ++i) {
      const Buffer& buffer = buffers_[i];
      const uint8* send_buffer = static_cast<uint8*>(
          params.buffer_allocations->GetDeviceAddress(buffer.source_buffer)
              .opaque());
      uint8* recv_buffer = static_cast<uint8*>(
          params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer)
              .opaque());

      PrimitiveType element_type = config_.config.operand_element_type[i];
      TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
                          ToNcclDataTypeAndCountMultiplier(element_type));
      ncclDataType_t dtype = dtype_and_multiplier.first;
      int element_count = buffer.element_count * dtype_and_multiplier.second;

      XLA_CUDA_RETURN_IF_ERROR(ncclSend(send_buffer, element_count, dtype,
                                        /*rank=*/i, comm, *cu_stream));
      XLA_CUDA_RETURN_IF_ERROR(ncclRecv(recv_buffer, element_count, dtype,
                                        /*rank=*/i, comm, *cu_stream));
    }
  }
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());

  VLOG(3) << "Done performing all-to-all for ordinal: " << device_ordinal;
  return Status::OK();
#else   // XLA_ENABLE_XCCL
  return Unimplemented(
      "NCCL support is not available: this binary was not built with a CUDA "
      "compiler, which is necessary to build the NCCL source library.");
#endif  // XLA_ENABLE_XCCL
}

}  // namespace gpu
}  // namespace xla