/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nccl_collective_permute_thunk.h"

#include <map>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace gpu {

/*static*/ NcclCollectivePermuteConfig
NcclCollectivePermuteThunk::GetNcclCollectivePermuteConfig(
    mlir::lmhlo::CollectivePermuteOp op, int64_t replica_count,
    int64_t partition_count) {
  NcclCollectivePermuteConfig config;

  config.operand_count = 1;
  const Shape shape = GetShape(op.operand());
  config.operand_element_type.push_back(shape.element_type());
  config.SetCollectiveOpKindAndID(op);
  config.group_mode = GetGroupMode(op);

  // With a collective permute, all execution instances together form one
  // replica group.
  const int64_t num_participants =
      config.group_mode == CollectiveOpGroupMode::kCrossReplica
          ? replica_count
          : partition_count;
  config.replica_groups.emplace_back();
  ReplicaGroup& replica_group = config.replica_groups.front();
  for (int i = 0; i < num_participants; ++i) {
    replica_group.add_replica_ids(i);
  }

  const std::vector<std::pair<int64, int64>> source_target_pairs =
      ConvertNx2Attribute(op.source_target_pairs()).ValueOrDie();
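  // For each (source, target) pair, record the source from which `target`
  // receives and the target to which `source` sends. IDs that appear in no
  // pair neither send nor receive.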
  for (const std::pair<int64, int64>& source_target : source_target_pairs) {
    int64_t source = source_target.first;
    int64_t target = source_target.second;

    config.id_to_source_target.insert({target, {}}).first->second.source =
        source;
    config.id_to_source_target.insert({source, {}}).first->second.target =
        target;
  }

  return config;
}

// The collective permute is degenerate if all source-target pairs are
// identity, and all the IDs appear in the list.
/*static*/ bool NcclCollectivePermuteThunk::IsDegenerate(
    mlir::lmhlo::CollectivePermuteOp op, int64_t replica_count,
    int64_t partition_count) {
  const std::vector<std::pair<int64, int64>> source_target_pairs =
      ConvertNx2Attribute(op.source_target_pairs()).ValueOrDie();
  // Each ID can appear only once as a source and as a target. So if all pairs
  // are identity, all IDs must appear in the list, which is the case exactly
  // when the list size equals the number of replicas/partitions.
  const int64_t expected_size =
      op.channel_id() ? partition_count : replica_count;
  return source_target_pairs.size() == expected_size &&
         absl::c_all_of(source_target_pairs,
                        [](const std::pair<int64, int64>& source_target) {
                          return source_target.first == source_target.second;
                        });
}

/*static*/ bool NcclCollectivePermuteThunk::CanImplement(
    mlir::lmhlo::CollectivePermuteOp op) {
  const Shape shape = GetShape(op.operand());
  return IsTypeSupportedByNccl(shape.element_type());
}

NcclCollectivePermuteThunk::NcclCollectivePermuteThunk(
    ThunkInfo thunk_info, mlir::lmhlo::CollectivePermuteOp op,
    int64_t replica_count, int64_t partition_count, const Buffer& buffer)
    : NcclCollectiveThunk(Thunk::kCollectivePermute, thunk_info),
      config_(
          GetNcclCollectivePermuteConfig(op, replica_count, partition_count)),
      buffer_(buffer) {}

Status NcclCollectivePermuteThunk::RunNcclCollective(
    const ExecuteParams& params, ncclComm_t comm) {
#if XLA_ENABLE_XCCL
  // Determine the source and target IDs for this instance. The source ID is
  // the ID which will copy its data to this instance. The target ID is the ID
  // to which this instance will copy its data. Either is optional.
  //
  // No source and no target:
  //  - This instance does not participate: no one sends it any data, and it
  //    does not have to send any data either. Since there is no source, just
  //    memzero() the destination buffer, as required by the collective
  //    permute semantics.
  //
  // No source, target present:
  //  - This instance has to send data to 'target'. Issue a send of the input.
  //    Since there is no source, memzero the destination buffer.
  //
  // Source present, no target:
  //  - This instance receives data from the source and does not have to send
  //    data to anyone. Issue a receive.
  //
  // Source and target both present:
  //  - Issue a send of the input to the target and a receive for the output
  //    from the source.

  int device_ordinal = params.stream->parent()->device_ordinal();
  VLOG(3) << "Performing collective permute from device ordinal: "
          << device_ordinal;

  TF_ASSIGN_OR_RETURN(const GlobalDeviceId global_device_id,
                      params.GetGlobalDeviceId());
  TF_ASSIGN_OR_RETURN(const DeviceAssignment::LogicalID current_logical_id,
                      params.device_assn->LogicalIdForDevice(global_device_id));
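  // The IDs in the source-target pairs are replica IDs for a cross-replica
  // permute and partition (computation) IDs for a cross-partition permute.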
  const int64_t current_id =
      config_.group_mode == CollectiveOpGroupMode::kCrossReplica
          ? current_logical_id.replica_id
          : current_logical_id.computation_id;
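  // Look up this instance's peers; either may be absent if the current ID
  // does not appear as a target (no source) or as a source (no target) in
  // the source-target pairs.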
  const NcclCollectivePermuteConfig::SourceTargetMapEntry source_target =
      config_.GetSourceTarget(current_id);
  const absl::optional<int64> source_id = source_target.source;
  const absl::optional<int64> target_id = source_target.target;

  // NCCL 2.8.x has an issue with point-to-point communication primitives if
  // different ranks process different amounts of data. This can happen with a
  // collective permute as certain nodes may do no sends or receives, or may
  // only send or only receive. Sending and receiving to self (an identity
  // pair) also causes this imbalance. NCCL 2.8.x requires
  // NCCL_LAUNCH_MODE=PARALLEL to avoid these issues. See
  // https://docs.nvidia.com/deeplearning/nccl/release-notes/rel_2-8-4.html#rel_2-8-4
  if (!IsNcclLaunchModeParallel()) {
    LOG(WARNING) << "NCCL based collective permute may not work correctly if "
                    "NCCL_LAUNCH_MODE is not set to PARALLEL";
  }

  se::DeviceMemoryBase src_addr =
      params.buffer_allocations->GetDeviceAddress(buffer_.source_buffer);
  se::DeviceMemoryBase dest_addr =
      params.buffer_allocations->GetDeviceAddress(buffer_.destination_buffer);

  VLOG(3) << absl::StreamFormat("%s : id = %d, source_id = %d, target_id = %d",
                                GetDeviceString(params), current_id,
                                source_id.value_or(-1), target_id.value_or(-1));
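  // Issue the send and the receive inside one NCCL group so that both
  // point-to-point operations are launched together; grouping them avoids
  // blocking when this rank must both send and receive.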
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());

  PrimitiveType element_type = config_.operand_element_type[0];
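  // Some element types are transferred as several elements of a narrower NCCL
  // type (e.g. complex values as pairs of reals); the multiplier scales the
  // element count accordingly.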
  TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
                      ToNcclDataTypeAndCountMultiplier(element_type));
  ncclDataType_t dtype = dtype_and_multiplier.first;
  int element_count = buffer_.element_count * dtype_and_multiplier.second;
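  // NCCL operates on the raw CUDA stream underlying the StreamExecutor
  // stream.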
  cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
      params.stream->implementation()->GpuStreamMemberHack());

  // Send the source buffer to the target peer if needed.
  if (target_id) {
    VLOG(3) << absl::StreamFormat(
        "%s : Calling ncclSend(sendbuff=%p, count=%d, peer=%d "
        "comm=%p, stream=%p)",
        GetDeviceString(params), src_addr.opaque(), element_count, *target_id,
        static_cast<const void*>(comm), *cu_stream);
    XLA_CUDA_RETURN_IF_ERROR(ncclSend(src_addr.opaque(), element_count, dtype,
                                      *target_id, comm, *cu_stream));
  }

  // Receive data from the source peer into the destination buffer.
  if (source_id) {
    VLOG(3) << absl::StreamFormat(
        "%s : Calling ncclRecv(recvbuff=%p, count=%d, peer=%d comm=%p, "
        "stream=%p)",
        GetDeviceString(params), dest_addr.opaque(), element_count, *source_id,
        static_cast<const void*>(comm), *cu_stream);
    XLA_CUDA_RETURN_IF_ERROR(ncclRecv(dest_addr.opaque(), element_count, dtype,
                                      *source_id, comm, *cu_stream));
  }
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());

  if (!source_id) {
    // If there is no source peer, i.e. no one sends us any data, zero out the
    // destination buffer.
    VLOG(3) << absl::StreamFormat("%s : collective-Permute: Issuing MemZero",
                                  GetDeviceString(params));
    params.stream->ThenMemZero(&dest_addr, dest_addr.size());
  }
  return Status::OK();
#else   // XLA_ENABLE_XCCL
  return Unimplemented(
      "NCCL support is not available: this binary was not built with a CUDA "
      "compiler, which is necessary to build the NCCL source library.");
#endif  // XLA_ENABLE_XCCL
}

}  // namespace gpu
}  // namespace xla