/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/blocking_counter.h"

namespace xla {

enum class ReductionKind { SUM, PRODUCT, MIN, MAX };

// Attempts to match instruction to one of the possible cases for ReductionKind.
absl::optional<ReductionKind> MatchReductionInstruction(
    const HloInstruction* hlo);

// Attempts to match computation to one of the possible cases in ReductionKind.
absl::optional<ReductionKind> MatchReductionComputation(
    const HloComputation* computation);
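//
// For example, a reduction computation of the form `add(x, y)` over a numeric
// type would be expected to match as ReductionKind::SUM, and one built from
// `maximum(x, y)` as ReductionKind::MAX; computations that do not correspond
// to any ReductionKind yield absl::nullopt.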

// Figures out which IDs are participating in the collective subgroup.
// An empty `groups` indicates that all [0, total_participant_count) IDs
// are participating. Note that for CollectiveOpGroupMode::kFlattenedID,
// groups cannot be empty, so `total_participant_count` is an optional.
StatusOr<std::vector<int>> GetParticipatingIDs(
    int current_id, absl::optional<int> total_participant_count,
    absl::Span<const ReplicaGroup> groups);
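//
// Illustrative sketch (informal; ReplicaGroups written as brace lists):
//   // Empty groups: every ID in [0, total_participant_count) participates,
//   // so the call below would be expected to return {0, 1, 2, 3}.
//   GetParticipatingIDs(/*current_id=*/1, /*total_participant_count=*/4, {});
//   // Non-empty groups: the subgroup containing current_id is returned, so
//   // this call would be expected to return {0, 1}.
//   GetParticipatingIDs(/*current_id=*/1, absl::nullopt, {{0, 1}, {2, 3}});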

// There are broadly 4 modes that collective communication ops use to describe
// which sets of devices are participating with a given device in the operation.
// These modes are determined by the values of channel_id (optional) and
// use_global_device_ids (optional). The modes are as follows:
//
// kCrossReplica:
//    implied by: no channel id, use_global_device_ids = false, or
//                no channel_id, no use_global_device_ids:
//    replica_groups contain replica_id, group contains all replicas for the
//    current partition
//
// kCrossPartition:
//    implied by: channel_id is set, no use_global_device_ids:
//    replica_groups contain partition_id, group contains all partitions for the
//    current replica.
//
// kCrossReplicaAndPartition:
//    implied by: channel_id is set, use_global_device_ids = false:
//    replica_groups contain replica_id, group contains all replicas for all
//    partitions (as opposed to just current partition).
//
// kFlattenedID:
//    implied by: channel_id is set, use_global_device_ids = true:
//    replica_groups contain flattened-ids, group contains devices that are
//    listed in the flattened-id list.
//
// Rest of the combinations are invalid.
//
// Since the actual value of channel_id does not matter, we use a bool argument
// `has_channel_id`, and optional<bool> for use_global_device_ids.
// Note that use_global_device_ids true requires channel_id to be set as well.
// Additionally, if use_global_device_ids = true, replica groups cannot be
// empty (verified in the HLO verifier).
enum class CollectiveOpGroupMode {
  kCrossReplica,
  kCrossPartition,
  kCrossReplicaAndPartition,
  kFlattenedID,
};

absl::string_view CollectiveOpGroupModeToString(
    CollectiveOpGroupMode group_mode);

// Returns the group formation mode implied by (a) whether the operation has
// channel_id and (b) if it has use_global_device_ids and if yes, its value.
StatusOr<CollectiveOpGroupMode> GetCollectiveOpGroupMode(
    bool has_channel_id, absl::optional<bool> use_global_device_ids);
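//
// Restating the table above as concrete calls (has_channel_id first,
// use_global_device_ids second):
//   GetCollectiveOpGroupMode(false, absl::nullopt) -> kCrossReplica
//   GetCollectiveOpGroupMode(false, false)         -> kCrossReplica
//   GetCollectiveOpGroupMode(true, absl::nullopt)  -> kCrossPartition
//   GetCollectiveOpGroupMode(true, false)          -> kCrossReplicaAndPartition
//   GetCollectiveOpGroupMode(true, true)           -> kFlattenedID
//   Any other combination (use_global_device_ids = true without a channel_id)
//   returns an error.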

// Figures out subgroups of participating devices from given replica_groups and
// group_mode.
//
// Returns list of participants, where each participant is a list of
// GlobalDeviceIds.
//
// For example:
//   device_assignment={{33, 34}, {44, 45}, {55, 56}}  3 replicas 2 partitions
//   group_mode=CollectiveOpGroupMode::kCrossReplica
//   replica_groups={{0}, {1, 2}}
//
//   This function returns {{33, 34}, {44, 45, 55, 56}}
//   There are 2 subgroups of participating devices {33, 34}, {44, 45, 55, 56}.
StatusOr<std::vector<std::vector<GlobalDeviceId>>>
GetParticipatingDevicesGroups(const DeviceAssignment& device_assignment,
                              absl::Span<const ReplicaGroup> replica_groups,
                              CollectiveOpGroupMode group_mode);

// Figures out which devices are participating in the collective subgroup.
StatusOr<std::vector<GlobalDeviceId>> GetParticipatingDevices(
    GlobalDeviceId device_id, const DeviceAssignment& device_assignment,
    absl::Span<const ReplicaGroup> replica_groups,
    CollectiveOpGroupMode group_mode);
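//
// Continuing the example above (device_assignment={{33, 34}, {44, 45},
// {55, 56}}, group_mode=kCrossReplica, replica_groups={{0}, {1, 2}}), one
// would expect
//   GetParticipatingDevices(GlobalDeviceId(44), device_assignment,
//                           replica_groups, group_mode)
// to return {44, 55}: device 44 is replica 1 / partition 0, and its replica
// group {1, 2} restricted to partition 0 contains devices 44 and 55.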

// Returns true if the two replica groups are orthogonal.
bool ReplicaGroupsOrthogonal(absl::Span<const ReplicaGroup> first,
                             absl::Span<const ReplicaGroup> second);

// Key that identifies a particular Rendezvous object in our global hashtable.
// This determines which calls to ExecuteOnStream communicate with each other.
// The rules are as follows.
//
// * Only ops with the same RunId can communicate with each other. (This is the
//   whole purpose of RunId).
//
// * Only ops with the same set of participating replicas can communicate with
//   each other.  This is how we separate out different replica groups (e.g. a
//   single AllReduce HLO might do two reductions, between say GPUs {0,2} and
//   {1,3}).
//
// * Only ops with the same opcode can communicate with each other.  At the
//   moment we only support kAllReduce, so we don't check for this explicitly.
//
// * For cross-module all-reduces (i.e. instr->channel_id().has_value()),
//   only ops with the same value for channel_id() can communicate with each
//   other.
//
// * For cross-replica (i.e. same-module) all-reduces (i.e.
//   !channel_id().has_value()), only ops from the same module (as
//   identified by its unique_id()) can communicate with each other.
//
struct RendezvousKey {
  enum CollectiveOpKind {
    kCrossModule,
    kCrossReplica,
  };

  explicit RendezvousKey(const RunId& run_id,
                         std::vector<GlobalDeviceId> global_devices,
                         int num_local_participants,
                         CollectiveOpKind collective_op_kind, int64_t op_id)
      : run_id(run_id),
        global_devices(std::move(global_devices)),
        num_local_participants(num_local_participants),
        collective_op_kind(collective_op_kind),
        op_id(op_id) {}

  template <typename H>
  friend H AbslHashValue(H h, const RendezvousKey& k) {
    return H::combine(std::move(h), k.run_id, k.global_devices,
                      k.num_local_participants,
                      static_cast<int>(k.collective_op_kind), k.op_id);
  }
  friend bool operator==(const RendezvousKey& a, const RendezvousKey& b) {
    return a.run_id == b.run_id && a.global_devices == b.global_devices &&
           a.num_local_participants == b.num_local_participants &&
           a.collective_op_kind == b.collective_op_kind &&  //
           a.op_id == b.op_id;
  }
  friend bool operator!=(const RendezvousKey& a, const RendezvousKey& b) {
    return !(a == b);
  }

  absl::string_view CollectiveOpKindString() const {
    switch (collective_op_kind) {
      case kCrossModule:
        return "cross_module";
      case kCrossReplica:
        return "cross_replica";
    }
  }

  string ToString() const {
    return absl::StrFormat(
        "RendezvousKey{run_id=%s, global_devices=[%s], "
        "num_local_participants=%d, collective_op_kind=%s, op_id=%d}",
        run_id.ToString(), GlobalDeviceIdsToString(global_devices),
        num_local_participants, CollectiveOpKindString(), op_id);
  }

  RunId run_id;
  std::vector<GlobalDeviceId> global_devices;
  int num_local_participants;
  CollectiveOpKind collective_op_kind;
  int64 op_id;
};
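//
// Example construction (the run id, device list, and op_id below are
// hypothetical; in practice they come from the current run and the collective
// instruction being executed):
//   std::vector<GlobalDeviceId> devices = {GlobalDeviceId(0),
//                                          GlobalDeviceId(1)};
//   RendezvousKey key(run_id, devices, /*num_local_participants=*/2,
//                     RendezvousKey::kCrossReplica, /*op_id=*/0);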

template <typename DescFn>
void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
                       const DescFn& desc_fn) {
  VLOG(3) << "Begin: " << desc_fn();
  const std::chrono::milliseconds timeout(5000);
  bool ok = counter->WaitFor(timeout);
  if (ok) {
    VLOG(3) << "Finished: " << desc_fn();
    return;
  }
  LOG(ERROR) << "This thread has been waiting for " << timeout.count()
             << "ms and may be stuck: " << desc_fn();
  counter->Wait();
  LOG(ERROR) << "Thread is unstuck!  Warning above was a false-positive.  "
                "Perhaps the timeout is too short: "
             << desc_fn();
}

// Participant data for each rendezvous.
struct ParticipantData {
  explicit ParticipantData(const RendezvousKey& rendezvous_key)
      : rendezvous_key(rendezvous_key) {}

  virtual ~ParticipantData() {}

  RendezvousKey rendezvous_key;

  virtual std::string ToString() const = 0;
};

// Encapsulates parameters to Rendezvous::SubmitParticipant.
struct AllReduceParticipantData : ParticipantData {
  AllReduceParticipantData(const RendezvousKey& rendezvous_key_p,
                           int64_t device_ordinal_p, se::Stream* stream_p)
      : ParticipantData(rendezvous_key_p),
        device_ordinal(device_ordinal_p),
        stream(stream_p) {}

  // TODO(b/125951860): We should vet that we're buffer allocating such that
  // source_buffer == destination_buffer if that avoids a NCCL copy (will depend
  // on how well the NCCL in-place implementation performs vs the out-of-place
  // implementation).
  struct Buffer {
    int64 element_count;
    se::DeviceMemoryBase source_data;
    se::DeviceMemoryBase destination_data;
    PrimitiveType primitive_type;
  };
  int64 device_ordinal;
  se::Stream* stream;
  std::vector<Buffer> buffers;

  ReductionKind reduction_kind;

  // For each local all-reduce participant a (global ID, local device ordinal)
  // pair for the participant. Participants are in no particular order.
  std::vector<std::pair<GlobalDeviceId, int64>> local_devices;

  string ToString() const override {
    std::vector<std::string> buffer_strs;
    for (const Buffer& buffer : buffers) {
      buffer_strs.push_back(
          absl::StrFormat("{element_count=%d}", buffer.element_count));
    }
    return absl::StrFormat(
        "AllReduceParticipantData{buffers=[%s], rendezvous_key=%s, "
        "device_ordinal=%d, stream=%p}",
        absl::StrJoin(buffer_strs, ","), rendezvous_key.ToString(),
        device_ordinal, stream);
  }
};
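//
// Illustrative sketch of filling in a participant (the stream, device memory
// handles, and rendezvous key below are assumed to come from the caller's
// execution context):
//   AllReduceParticipantData participant(key, /*device_ordinal=*/0, stream);
//   participant.reduction_kind = ReductionKind::SUM;
//   participant.buffers.push_back(
//       {/*element_count=*/1024, source_mem, destination_mem, F32});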

// The set of threads that want to do a collective op together all pick the same
// Rendezvous object out of the global cache and call SubmitParticipant.
//
// The Rendezvous instance handles waiting for all threads to join, ensuring
// that a clique exists for the desired set of GPUs, etc.
//
// Rendezvous objects can only be used once.
//
// I: Participant data.
// O: Participant output.
template <typename I, typename O,
          typename =
              std::enable_if_t<std::is_base_of<ParticipantData, I>::value>>
class Rendezvous {
 public:
  virtual ~Rendezvous() {}
  explicit Rendezvous(const RendezvousKey& k) : key_(k) {}

  // Submit a participant to the rendezvous. We get the rendezvous from
  // `rendezvous_getter`, which we can then use to drop the existing reference.
  static StatusOr<O> SubmitParticipant(
      std::function<std::shared_ptr<Rendezvous<I, O>>()> rendezvous_getter,
      I participant) {
    std::shared_ptr<Rendezvous<I, O>> rendezvous = rendezvous_getter();
    TF_ASSIGN_OR_RETURN(auto p, rendezvous->SubmitParticipant(participant));

    // Drop our reference to the Rendezvous and wait for all other threads to do
    // the same.  If we didn't do this, one of the threads could run past this
    // point, reenter ExecuteOnStream for another all-reduce, and attempt to
    // reuse the Rendezvous!
    //
    // An alternative way of accomplishing this goal would be to implement
    // RefcountingHashMap::erase() and call it during SubmitParticipant.  But
    // erase() is deceptively complex to implement correctly.
    std::shared_ptr<tensorflow::BlockingCounter> blocking_counter = p.second;
    rendezvous.reset();
    blocking_counter->DecrementCount();
    xla::WaitAndLogIfStuck(blocking_counter.get(), [&] {
      return absl::StrFormat(
          "participant waiting for all threads to drop their reference to the "
          "rendezvous: %p",
          rendezvous.get());
    });
    return std::move(p.first);
  }

 protected:
  // Returns domain-specific output O and whether this replica is primary.
  virtual StatusOr<O> RunCollectiveOp(const I& participant) = 0;

  // Initialize the rendezvous by the first ("primary") thread which reaches the
  // barrier. Returns whether this thread is primary.
  bool InitializationBarrier() {
    tensorflow::mutex_lock lock(mu_);
    if (!initialized_) {
      initialized_ = true;
      return true;
    }
    return false;
  }

  tensorflow::mutex mu_;

  bool initialized_ TF_GUARDED_BY(mu_) = false;

  std::vector<I> participants_ TF_GUARDED_BY(mu_);

 private:
  // Runs the all-reduce on the given thread.  If successful, returns
  //  - a handle to the clique that was used, so that the caller may keep the
  //    clique alive if it chooses.
  //  - a BlockingCounter initialized to the number of participants, so that
  //    the caller can coordinate with the participants one last time if it
  //    chooses.  This is useful for coordinating destruction of the Rendezvous.
  StatusOr<std::pair<O, std::shared_ptr<tensorflow::BlockingCounter>>>
  SubmitParticipant(const I& participant) {
    {
      tensorflow::mutex_lock lock(mu_);
      CHECK(!initialized_);

      // Spot check for consistent replica counts among submitting threads.
      if (!participants_.empty() &&
          participants_.back().rendezvous_key != participant.rendezvous_key) {
        return InvalidArgument(
            "Mismatch among all-reduce participants.  Expected same "
            "replica-count, element-count, and rendezvous-key but were %s and "
            "%s",
            participants_.back().ToString(), participant.ToString());
      }
      participants_.push_back(participant);
    }

    // Wait for all participants to arrive.
    all_participants_present_.DecrementCount();
    WaitAndLogIfStuck(&all_participants_present_, [&] {
      return absl::StrFormat(
          "participant %s waiting for all participants to arrive at rendezvous "
          "%s",
          participant.ToString(), key_.ToString());
    });

    TF_ASSIGN_OR_RETURN(O output, RunCollectiveOp(participant));
    return std::make_pair(std::move(output), returned_blocking_counter_);
  }

  const RendezvousKey key_;

  tensorflow::BlockingCounter all_participants_present_{
      key_.num_local_participants};

  // tensorflow::BlockingCounter returned by SubmitParticipant.
  std::shared_ptr<tensorflow::BlockingCounter> returned_blocking_counter_{
      std::make_shared<tensorflow::BlockingCounter>(
          key_.num_local_participants)};
};
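//
// Illustrative sketch of how this template is typically specialized and driven
// (the subclass, output type, and rendezvous-lookup function below are
// hypothetical, not part of this header):
//
//   struct MyOutput { /* domain-specific result */ };
//
//   class MyRendezvous
//       : public Rendezvous<AllReduceParticipantData, MyOutput> {
//    public:
//     explicit MyRendezvous(const RendezvousKey& k) : Rendezvous(k) {}
//
//    protected:
//     StatusOr<MyOutput> RunCollectiveOp(
//         const AllReduceParticipantData& participant) override {
//       // Perform this thread's portion of the collective and return its
//       // output.
//       return MyOutput{};
//     }
//   };
//
//   // Each participating thread submits through the shared object returned by
//   // a caller-provided getter (e.g. a lookup in a global rendezvous map):
//   StatusOr<MyOutput> result = MyRendezvous::SubmitParticipant(
//       get_or_create_rendezvous, participant);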

}  // end namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_