/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nccl_collective_permute_thunk.h"

#include <map>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace gpu {

/*static*/ NcclCollectivePermuteConfig
NcclCollectivePermuteThunk::GetNcclCollectivePermuteConfig(
    mlir::lmhlo::CollectivePermuteOp op, int64_t replica_count,
    int64_t partition_count) {
  NcclCollectivePermuteConfig config;

  config.operand_count = 1;
  const Shape shape = GetShape(op.operand());
  config.operand_element_type.push_back(shape.element_type());
  config.SetCollectiveOpKindAndID(op);
  config.group_mode = GetGroupMode(op);

  // With a collective permute, all execution instances together form one
  // replica group.
  const int64_t num_participants =
      config.group_mode == CollectiveOpGroupMode::kCrossReplica
          ? replica_count
          : partition_count;
  config.replica_groups.emplace_back();
  ReplicaGroup& replica_group = config.replica_groups.front();
  for (int i = 0; i < num_participants; ++i) {
    replica_group.add_replica_ids(i);
  }

  const std::vector<std::pair<int64, int64>> source_target_pairs =
      ConvertNx2Attribute(op.source_target_pairs()).ValueOrDie();

  for (const std::pair<int64, int64>& source_target : source_target_pairs) {
    int64_t source = source_target.first;
    int64_t target = source_target.second;

    config.id_to_source_target.insert({target, {}}).first->second.source =
        source;
    config.id_to_source_target.insert({source, {}}).first->second.target =
        target;
  }
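  // Illustrative example: source_target_pairs {{0,1},{1,0}} (ids 0 and 1 swap
  // data) yields id_to_source_target = {0: {source: 1, target: 1},
  // 1: {source: 0, target: 0}}.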

  return config;
}

// The collective permute is degenerate if all source-target pairs are
// identity, and all the IDs appear in the list.
/*static*/ bool NcclCollectivePermuteThunk::IsDegenerate(
    mlir::lmhlo::CollectivePermuteOp op, int64_t replica_count,
    int64_t partition_count) {
  const std::vector<std::pair<int64, int64>> source_target_pairs =
      ConvertNx2Attribute(op.source_target_pairs()).ValueOrDie();
  // Each ID can appear only once as a source and as a target. So if all pairs
  // are identity, all IDs appear in the list iff the list size == the number
  // of replicas/partitions.
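  // Illustrative example: with replica_count == 4 and no channel_id,
  // {{0,0},{1,1},{2,2},{3,3}} is degenerate, but {{0,0},{1,1}} is not, since
  // ids 2 and 3 are missing.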
  const int64_t expected_size =
      op.channel_id() ? partition_count : replica_count;
  return source_target_pairs.size() == expected_size &&
         absl::c_all_of(source_target_pairs,
                        [](const std::pair<int64, int64>& source_target) {
                          return source_target.first == source_target.second;
                        });
}

/*static*/ bool NcclCollectivePermuteThunk::CanImplement(
    mlir::lmhlo::CollectivePermuteOp op) {
  const Shape shape = GetShape(op.operand());
  return IsTypeSupportedByNccl(shape.element_type());
}

NcclCollectivePermuteThunk::NcclCollectivePermuteThunk(
    ThunkInfo thunk_info, mlir::lmhlo::CollectivePermuteOp op,
    int64_t replica_count, int64_t partition_count, const Buffer& buffer)
    : NcclCollectiveThunk(Thunk::kCollectivePermute, thunk_info),
      config_(
          GetNcclCollectivePermuteConfig(op, replica_count, partition_count)),
      buffer_(buffer) {}

Status NcclCollectivePermuteThunk::RunNcclCollective(
    const ExecuteParams& params, ncclComm_t comm) {
#if XLA_ENABLE_XCCL
  // Determine the source and target IDs for this instance. The source ID is
  // the ID that will copy its data to this instance. The target ID is the ID
  // to which this instance will copy its data. Either may be absent.
  //
  // No source and no target:
  //  - This instance does not actually participate: no one sends it any data
  //    and it does not have to send any data either. Since there is no
  //    source, just memzero() the destination buffer, as required by the
  //    collective permute semantics.
  //
  // No source, target present:
  //  - This instance has to send data to the target, so issue a send of the
  //    input. Since there is no source, memzero the destination buffer.
  //
  // Source present, no target:
  //  - This instance receives data from the source and does not have to send
  //    data to anyone, so issue a receive only.
  //
  // Source and target both present:
  //  - Issue a send of the input to the target and a receive for the output
  //    from the source.

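  // Illustrative example: with source_target_pairs {{0,1},{1,2}}, id 0 only
  // sends (and zeroes its output buffer), id 1 both sends and receives, and
  // id 2 only receives.
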
  int device_ordinal = params.stream->parent()->device_ordinal();
  VLOG(3) << "Performing collective permute from device ordinal: "
          << device_ordinal;

  TF_ASSIGN_OR_RETURN(const GlobalDeviceId global_device_id,
                      params.GetGlobalDeviceId());
  TF_ASSIGN_OR_RETURN(const DeviceAssignment::LogicalID current_logical_id,
                      params.device_assn->LogicalIdForDevice(global_device_id));
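  // The permute ids refer to replica ids in cross-replica mode and to
  // partition ids in cross-partition mode, so pick the matching logical id
  // for this device.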
  const int64_t current_id =
      config_.group_mode == CollectiveOpGroupMode::kCrossReplica
          ? current_logical_id.replica_id
          : current_logical_id.computation_id;

  const NcclCollectivePermuteConfig::SourceTargetMapEntry source_target =
      config_.GetSourceTarget(current_id);
  const absl::optional<int64> source_id = source_target.source;
  const absl::optional<int64> target_id = source_target.target;

  // NCCL 2.8.x has an issue with point-to-point communication primitives if
  // different ranks process different amounts of data. This can happen with a
  // collective permute, as certain nodes may not do any sends or receives, or
  // may only send or only receive. Sending to and receiving from self (an
  // identity pair) also causes this imbalance. NCCL 2.8.x requires
  // NCCL_LAUNCH_MODE=PARALLEL to avoid these issues. See
  // https://docs.nvidia.com/deeplearning/nccl/release-notes/rel_2-8-4.html#rel_2-8-4
  if (!IsNcclLaunchModeParallel()) {
    LOG(WARNING) << "NCCL based collective permute may not work correctly if "
                    "NCCL_LAUNCH_MODE is not set to PARALLEL";
  }

  se::DeviceMemoryBase src_addr =
      params.buffer_allocations->GetDeviceAddress(buffer_.source_buffer);
  se::DeviceMemoryBase dest_addr =
      params.buffer_allocations->GetDeviceAddress(buffer_.destination_buffer);

  VLOG(3) << absl::StreamFormat("%s : id = %d, source_id = %d, target_id = %d",
                                GetDeviceString(params), current_id,
                                source_id.value_or(-1), target_id.value_or(-1));

  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());

  PrimitiveType element_type = config_.operand_element_type[0];
  TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
                      ToNcclDataTypeAndCountMultiplier(element_type));
  ncclDataType_t dtype = dtype_and_multiplier.first;
  int element_count = buffer_.element_count * dtype_and_multiplier.second;
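  // The count multiplier accounts for element types that NCCL transfers as
  // more than one primitive element (e.g. complex values as pairs of reals).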

  cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
      params.stream->implementation()->GpuStreamMemberHack());

  // Send the source buffer to the target peer if needed.
  if (target_id) {
    VLOG(3) << absl::StreamFormat(
        "%s : Calling ncclSend(sendbuff=%p, count=%d, peer=%d "
        "comm=%p, stream=%p)",
        GetDeviceString(params), src_addr.opaque(), element_count, *target_id,
        static_cast<const void*>(comm), *cu_stream);
    XLA_CUDA_RETURN_IF_ERROR(ncclSend(src_addr.opaque(), element_count, dtype,
                                      *target_id, comm, *cu_stream));
  }

  // Receive data from the source peer into the destination buffer if needed.
  if (source_id) {
    VLOG(3) << absl::StreamFormat(
        "%s : Calling ncclRecv(recvbuff=%p, count=%d, peer=%d comm=%p, "
        "stream=%p)",
        GetDeviceString(params), dest_addr.opaque(), element_count, *source_id,
        static_cast<const void*>(comm), *cu_stream);
    XLA_CUDA_RETURN_IF_ERROR(ncclRecv(dest_addr.opaque(), element_count, dtype,
                                      *source_id, comm, *cu_stream));
  }
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());

  if (!source_id) {
    // If there is no source peer, i.e. no one sends us any data, zero out the
    // destination buffer.
    VLOG(3) << absl::StreamFormat("%s : collective-Permute: Issuing MemZero",
                                  GetDeviceString(params));
    params.stream->ThenMemZero(&dest_addr, dest_addr.size());
  }
  return Status::OK();
#else   // XLA_ENABLE_XCCL
  return Unimplemented(
      "NCCL support is not available: this binary was not built with a CUDA "
      "compiler, which is necessary to build the NCCL source library.");
#endif  // XLA_ENABLE_XCCL
}

}  // namespace gpu
}  // namespace xla