/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_

#include <chrono>
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/blocking_counter.h"

namespace xla {

enum class ReductionKind { SUM, PRODUCT, MIN, MAX };

// Attempts to match the instruction to one of the possible cases for
// ReductionKind.
absl::optional<ReductionKind> MatchReductionInstruction(
    const HloInstruction* hlo);

// Attempts to match the computation to one of the possible cases in
// ReductionKind.
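// For example (illustrative), a computation whose root is
// `add(parameter(0), parameter(1))` matches ReductionKind::SUM.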
absl::optional<ReductionKind> MatchReductionComputation(
    const HloComputation* computation);

// Figures out which IDs are participating in the collective subgroup.
// An empty `groups` indicates that all IDs in [0, total_participant_count)
// are participating. Note that for CollectiveOpGroupMode::kFlattenedID,
// groups cannot be empty, so `total_participant_count` may be absl::nullopt
// in that case.
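//
// Example usage (an illustrative sketch):
//
//   ReplicaGroup group0, group1;
//   group0.add_replica_ids(0);
//   group1.add_replica_ids(1);
//   group1.add_replica_ids(2);
//   std::vector<ReplicaGroup> groups = {group0, group1};
//   // ID 2 is in the second group, so this returns {1, 2}.
//   StatusOr<std::vector<int>> ids = GetParticipatingIDs(
//       /*current_id=*/2, /*total_participant_count=*/absl::nullopt, groups);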
StatusOr<std::vector<int>> GetParticipatingIDs(
    int current_id, absl::optional<int> total_participant_count,
    absl::Span<const ReplicaGroup> groups);

// There are broadly 4 modes that collective communication ops use to describe
// which sets of devices are participating with a given device in the
// operation. These modes are determined by the values of channel_id (optional)
// and use_global_device_ids (optional). The modes are as follows:
//
// kCrossReplica:
//    implied by: no channel id, use_global_device_ids = false, or
//                no channel_id, no use_global_device_ids:
//    replica_groups contain replica_id, group contains all replicas for the
//    current partition
//
// kCrossPartition:
//    implied by: channel_id is set, no use_global_device_ids:
//    replica_groups contain partition_id, group contains all partitions for
//    the current replica.
//
// kCrossReplicaAndPartition:
//    implied by: channel_id is set, use_global_device_ids = false:
//    replica_groups contain replica_id, group contains all replicas for all
//    partitions (as opposed to just the current partition).
//
// kFlattenedID:
//    implied by: channel_id is set, use_global_device_ids = true:
//    replica_groups contain flattened-ids, group contains devices that are
//    listed in the flattened-id list.
//
// The remaining combinations are invalid.
//
// Since the actual value of channel_id does not matter, we use a bool argument
// `has_channel_id`, and optional<bool> for use_global_device_ids.
// Note that use_global_device_ids = true requires channel_id to be set as
// well. Additionally, if use_global_device_ids = true, replica groups cannot
// be empty (verified in the HLO verifier).
enum class CollectiveOpGroupMode {
  kCrossReplica,
  kCrossPartition,
  kCrossReplicaAndPartition,
  kFlattenedID,
};

absl::string_view CollectiveOpGroupModeToString(
    CollectiveOpGroupMode group_mode);

// Returns the group formation mode implied by (a) whether the operation has
// channel_id and (b) if it has use_global_device_ids and if yes, its value.
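//
// For example (illustrative):
//
//   // channel_id set, use_global_device_ids = true  ->  kFlattenedID.
//   TF_ASSIGN_OR_RETURN(
//       CollectiveOpGroupMode mode,
//       GetCollectiveOpGroupMode(/*has_channel_id=*/true,
//                                /*use_global_device_ids=*/true));
//
//   // No channel_id and no use_global_device_ids  ->  kCrossReplica.
//   TF_ASSIGN_OR_RETURN(
//       mode, GetCollectiveOpGroupMode(/*has_channel_id=*/false,
//                                      /*use_global_device_ids=*/absl::nullopt));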
StatusOr<CollectiveOpGroupMode> GetCollectiveOpGroupMode(
    bool has_channel_id, absl::optional<bool> use_global_device_ids);

// Figures out subgroups of participating devices from given replica_groups and
// group_mode.
//
// Returns a list of participants, where each participant is a list of
// GlobalDeviceIds.
//
// For example:
//   device_assignment={{33, 34}, {44, 45}, {55, 56}}  3 replicas 2 partitions
//   group_mode=CollectiveOpGroupMode::kCrossReplicaAndPartition
//   replica_groups={{0}, {1, 2}}
//
//   This function returns {{33, 34}, {44, 45, 55, 56}}.
//   There are 2 subgroups of participating devices: {33, 34} and
//   {44, 45, 55, 56}.
StatusOr<std::vector<std::vector<GlobalDeviceId>>>
GetParticipatingDevicesGroups(const DeviceAssignment& device_assignment,
                              absl::Span<const ReplicaGroup> replica_groups,
                              CollectiveOpGroupMode group_mode);

// Figures out which devices are participating in the collective subgroup.
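//
// Continuing the example above (illustrative): with the same
// device_assignment, replica_groups={{0}, {1, 2}} and
// group_mode=kCrossReplicaAndPartition,
//   GetParticipatingDevices(GlobalDeviceId(44), device_assignment,
//                           replica_groups, group_mode)
// returns {44, 45, 55, 56}, the subgroup that contains device 44.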
StatusOr<std::vector<GlobalDeviceId>> GetParticipatingDevices(
    GlobalDeviceId device_id, const DeviceAssignment& device_assignment,
    absl::Span<const ReplicaGroup> replica_groups,
    CollectiveOpGroupMode group_mode);

// Returns true if the two sets of replica groups are orthogonal.
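//
// For example (an illustrative sketch): first = {{0, 1}, {2, 3}} and
// second = {{0, 2}, {1, 3}} form such an orthogonal pair: each group of
// `second` takes the i-th element of every group of `first`.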
bool ReplicaGroupsOrthogonal(absl::Span<const ReplicaGroup> first,
                             absl::Span<const ReplicaGroup> second);

// Key that identifies a particular Rendezvous object in our global hashtable.
// This determines which calls to ExecuteOnStream communicate with each other.
// The rules are as follows.
//
// * Only ops with the same RunId can communicate with each other. (This is the
//   whole purpose of RunId.)
//
// * Only ops with the same set of participating replicas can communicate with
//   each other. This is how we separate out different replica groups (e.g. a
//   single AllReduce HLO might do two reductions, between say GPUs {0,2} and
//   {1,3}).
//
// * Only ops with the same opcode can communicate with each other. At the
//   moment we only support kAllReduce, so we don't check for this explicitly.
//
// * For cross-module all-reduces (i.e. instr->channel_id().has_value()),
//   only ops with the same value for channel_id() can communicate with each
//   other.
//
// * For cross-replica (i.e. same-module) all-reduces (i.e.
//   !channel_id().has_value()), only ops from the same module (as
//   identified by its unique_id()) can communicate with each other.
//
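// Example construction (an illustrative sketch; the values are hypothetical):
//
//   RendezvousKey key(
//       run_id, /*global_devices=*/{GlobalDeviceId(0), GlobalDeviceId(2)},
//       /*num_local_participants=*/2, RendezvousKey::kCrossModule,
//       /*op_id=*/*instr->channel_id());
//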
struct RendezvousKey {
  enum CollectiveOpKind {
    kCrossModule,
    kCrossReplica,
  };

  explicit RendezvousKey(const RunId& run_id,
                         std::vector<GlobalDeviceId> global_devices,
                         int num_local_participants,
                         CollectiveOpKind collective_op_kind, int64_t op_id)
      : run_id(run_id),
        global_devices(std::move(global_devices)),
        num_local_participants(num_local_participants),
        collective_op_kind(collective_op_kind),
        op_id(op_id) {}

  template <typename H>
  friend H AbslHashValue(H h, const RendezvousKey& k) {
    return H::combine(std::move(h), k.run_id, k.global_devices,
                      k.num_local_participants,
                      static_cast<int>(k.collective_op_kind), k.op_id);
  }
  friend bool operator==(const RendezvousKey& a, const RendezvousKey& b) {
    return a.run_id == b.run_id && a.global_devices == b.global_devices &&
           a.num_local_participants == b.num_local_participants &&
           a.collective_op_kind == b.collective_op_kind &&  //
           a.op_id == b.op_id;
  }
  friend bool operator!=(const RendezvousKey& a, const RendezvousKey& b) {
    return !(a == b);
  }

  absl::string_view CollectiveOpKindString() const {
    switch (collective_op_kind) {
      case kCrossModule:
        return "cross_module";
      case kCrossReplica:
        return "cross_replica";
    }
  }

  std::string ToString() const {
    return absl::StrFormat(
        "RendezvousKey{run_id=%s, global_devices=[%s], "
        "num_local_participants=%d, collective_op_kind=%s, op_id=%d}",
        run_id.ToString(), GlobalDeviceIdsToString(global_devices),
        num_local_participants, CollectiveOpKindString(), op_id);
  }

  RunId run_id;
  std::vector<GlobalDeviceId> global_devices;
  int num_local_participants;
  CollectiveOpKind collective_op_kind;
  int64 op_id;
};

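// Waits on `counter` with a short timeout; if the timeout expires, logs an
// error containing `desc_fn()` and keeps waiting, logging again once the wait
// eventually completes.
//
// Example usage (an illustrative sketch; `num_participants` and
// `device_ordinal` are hypothetical):
//
//   tensorflow::BlockingCounter counter(num_participants);
//   ...
//   counter.DecrementCount();
//   WaitAndLogIfStuck(&counter, [&] {
//     return absl::StrFormat("device %d waiting at barrier", device_ordinal);
//   });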
template <typename DescFn>
void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
                       const DescFn& desc_fn) {
  VLOG(3) << "Begin: " << desc_fn();
  const std::chrono::milliseconds timeout(5000);
  bool ok = counter->WaitFor(timeout);
  if (ok) {
    VLOG(3) << "Finished: " << desc_fn();
    return;
  }
  LOG(ERROR) << "This thread has been waiting for " << timeout.count()
             << "ms and may be stuck: " << desc_fn();
  counter->Wait();
  LOG(ERROR) << "Thread is unstuck! Warning above was a false-positive. "
                "Perhaps the timeout is too short: "
             << desc_fn();
}

// Participant data for each rendezvous.
struct ParticipantData {
  explicit ParticipantData(const RendezvousKey& rendezvous_key)
      : rendezvous_key(rendezvous_key) {}

  virtual ~ParticipantData() {}

  RendezvousKey rendezvous_key;

  virtual std::string ToString() const = 0;
};

// Encapsulates parameters to Rendezvous::SubmitParticipant.
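//
// Example of filling one in (an illustrative sketch; the ordinal, stream, and
// device memory values are hypothetical):
//
//   AllReduceParticipantData participant(rendezvous_key, device_ordinal,
//                                        stream);
//   participant.reduction_kind = ReductionKind::SUM;
//   AllReduceParticipantData::Buffer buffer;
//   buffer.element_count = 1024;
//   buffer.primitive_type = F32;
//   buffer.source_data = source_memory;
//   buffer.destination_data = destination_memory;
//   participant.buffers.push_back(buffer);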
struct AllReduceParticipantData : ParticipantData {
  AllReduceParticipantData(const RendezvousKey& rendezvous_key_p,
                           int64_t device_ordinal_p, se::Stream* stream_p)
      : ParticipantData(rendezvous_key_p),
        device_ordinal(device_ordinal_p),
        stream(stream_p) {}

  // TODO(b/125951860): We should vet that we're buffer allocating such that
  // source_buffer == destination_buffer if that avoids a NCCL copy (will
  // depend on how well the NCCL in-place implementation performs vs the
  // out-of-place implementation).
  struct Buffer {
    int64 element_count;
    se::DeviceMemoryBase source_data;
    se::DeviceMemoryBase destination_data;
    PrimitiveType primitive_type;
  };
  int64 device_ordinal;
  se::Stream* stream;
  std::vector<Buffer> buffers;

  ReductionKind reduction_kind;

  // For each local all-reduce participant, a (global ID, local device ordinal)
  // pair for the participant. Participants are in no particular order.
  std::vector<std::pair<GlobalDeviceId, int64>> local_devices;

  std::string ToString() const override {
    std::vector<std::string> buffer_strs;
    for (const Buffer& buffer : buffers) {
      buffer_strs.push_back(
          absl::StrFormat("{element_count=%d}", buffer.element_count));
    }
    return absl::StrFormat(
        "AllReduceParticipantData{buffers=[%s], rendezvous_key=%s, "
        "device_ordinal=%d, stream=%p}",
        absl::StrJoin(buffer_strs, ","), rendezvous_key.ToString(),
        device_ordinal, stream);
  }
};


// The set of threads that want to do a collective op together all pick the
// same Rendezvous object out of the global cache and call SubmitParticipant.
//
// The Rendezvous instance handles waiting for all threads to join, ensuring
// that a clique exists for the desired set of GPUs, etc.
//
// Rendezvous objects can only be used once.
//
// I: Participant data.
// O: Participant output.
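//
// A minimal sketch of a concrete subclass and its use (illustrative;
// `MyOutput`, `MyRendezvous`, and `GetOrCreateRendezvous` are hypothetical):
//
//   struct MyOutput {};
//
//   class MyRendezvous
//       : public Rendezvous<AllReduceParticipantData, MyOutput> {
//    public:
//     explicit MyRendezvous(const RendezvousKey& k) : Rendezvous(k) {}
//
//    protected:
//     StatusOr<MyOutput> RunCollectiveOp(
//         const AllReduceParticipantData& participant) override {
//       bool primary = InitializationBarrier();
//       // ... perform the reduction, using `primary` to elect a leader ...
//       return MyOutput{};
//     }
//   };
//
//   // On each participating thread:
//   TF_ASSIGN_OR_RETURN(
//       MyOutput out,
//       MyRendezvous::SubmitParticipant(
//           [&] { return GetOrCreateRendezvous(key); }, participant));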
template <typename I, typename O,
          typename =
              std::enable_if_t<std::is_base_of<ParticipantData, I>::value>>
class Rendezvous {
 public:
  virtual ~Rendezvous() {}
  explicit Rendezvous(const RendezvousKey& k) : key_(k) {}

  // Submits a participant to the rendezvous. We take the rendezvous via
  // `rendezvous_getter` so that we can drop our reference to it before
  // waiting for the other threads (see the comment below).
  static StatusOr<O> SubmitParticipant(
      std::function<std::shared_ptr<Rendezvous<I, O>>()> rendezvous_getter,
      I participant) {
    std::shared_ptr<Rendezvous<I, O>> rendezvous = rendezvous_getter();
    TF_ASSIGN_OR_RETURN(auto p, rendezvous->SubmitParticipant(participant));

    // Drop our reference to the Rendezvous and wait for all other threads to
    // do the same. If we didn't do this, one of the threads could run past
    // this point, reenter ExecuteOnStream for another all-reduce, and attempt
    // to reuse the Rendezvous!
    //
    // An alternative way of accomplishing this goal would be to implement
    // RefcountingHashMap::erase() and call it during SubmitParticipant. But
    // erase() is deceptively complex to implement correctly.
    std::shared_ptr<tensorflow::BlockingCounter> blocking_counter = p.second;
    rendezvous.reset();
    blocking_counter->DecrementCount();
    xla::WaitAndLogIfStuck(blocking_counter.get(), [&] {
      return absl::StrFormat(
          "participant waiting for all threads to drop their reference to the "
          "rendezvous: %p",
          rendezvous.get());
    });
    return std::move(p.first);
  }

 protected:
  // Runs the collective op for the given participant and returns its
  // domain-specific output O.
  virtual StatusOr<O> RunCollectiveOp(const I& participant) = 0;

  // Marks the rendezvous as initialized. Returns true for the first
  // ("primary") thread to reach this barrier and false for all others.
  bool InitializationBarrier() {
    tensorflow::mutex_lock lock(mu_);
    if (!initialized_) {
      initialized_ = true;
      return true;
    }
    return false;
  }

  tensorflow::mutex mu_;

  bool initialized_ TF_GUARDED_BY(mu_) = false;

  std::vector<I> participants_ TF_GUARDED_BY(mu_);

 private:
  // Runs the collective op on the calling thread. If successful, returns
  //  - the domain-specific output O (e.g. a handle to the clique that was
  //    used, so that the caller may keep the clique alive if it chooses).
  //  - a BlockingCounter initialized to the number of participants, so that
  //    the caller can coordinate with the participants one last time if it
  //    chooses. This is useful for coordinating destruction of the Rendezvous.
  StatusOr<std::pair<O, std::shared_ptr<tensorflow::BlockingCounter>>>
  SubmitParticipant(const I& participant) {
    {
      tensorflow::mutex_lock lock(mu_);
      CHECK(!initialized_);

      // Spot check for consistent replica counts among submitting threads.
      if (!participants_.empty() &&
          participants_.back().rendezvous_key != participant.rendezvous_key) {
        return InvalidArgument(
            "Mismatch among all-reduce participants. Expected same "
            "replica-count, element-count, and rendezvous-key but were %s and "
            "%s",
            participants_.back().ToString(), participant.ToString());
      }
      participants_.push_back(participant);
    }

    // Wait for all participants to arrive.
    all_participants_present_.DecrementCount();
    WaitAndLogIfStuck(&all_participants_present_, [&] {
      return absl::StrFormat(
          "participant %s waiting for all participants to arrive at "
          "rendezvous %s",
          participant.ToString(), key_.ToString());
    });

    TF_ASSIGN_OR_RETURN(O output, RunCollectiveOp(participant));
    return std::make_pair(std::move(output), returned_blocking_counter_);
  }

  const RendezvousKey key_;

  tensorflow::BlockingCounter all_participants_present_{
      key_.num_local_participants};

  // tensorflow::BlockingCounter returned by SubmitParticipant.
  std::shared_ptr<tensorflow::BlockingCounter> returned_blocking_counter_{
      std::make_shared<tensorflow::BlockingCounter>(
          key_.num_local_participants)};
};


}  // end namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_