/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/blocking_counter.h"

namespace xla {

enum class ReductionKind { SUM, PRODUCT, MIN, MAX };

// Attempts to match computation to one of the possible cases in ReductionKind.
absl::optional<ReductionKind> MatchReductionComputation(
    const HloComputation* computation);
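
// Example (illustrative sketch, not an exhaustive list of recognized
// patterns): a reduction computation of the form
//
//   add {
//     x = f32[] parameter(0)
//     y = f32[] parameter(1)
//     ROOT sum = f32[] add(x, y)
//   }
//
// would be matched as ReductionKind::SUM, while a computation that matches
// none of the ReductionKind cases yields absl::nullopt.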

// Figures out which replicas are participating in the collective subgroup.
// An empty `replica_groups` indicates that all replicas are participating.
StatusOr<std::vector<int>> GetParticipatingReplicas(
    int replica_id, int total_replica_count,
    absl::Span<const ReplicaGroup> replica_groups);

// Figures out which devices are participating in the collective subgroup.
// An empty `replica_groups` indicates that all replicas are participating.
StatusOr<std::vector<GlobalDeviceId>> GetParticipatingDevices(
    GlobalDeviceId device_id, const DeviceAssignment& device_assignment,
    int total_replica_count, absl::Span<const ReplicaGroup> replica_groups);
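
// Example (illustrative sketch): with total_replica_count = 4 and
// replica_groups = {{0, 2}, {1, 3}}, GetParticipatingReplicas for
// replica_id = 0 returns {0, 2}, and an empty replica_groups returns
// {0, 1, 2, 3}.  GetParticipatingDevices performs the same grouping but maps
// the resulting replica ids to GlobalDeviceIds via the DeviceAssignment.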

// Key that identifies a particular Rendezvous object in our global hashtable.
// This determines which calls to ExecuteOnStream communicate with each other.
// The rules are as follows.
//
// * Only ops with the same RunId can communicate with each other. (This is the
//   whole purpose of RunId).
//
// * Only ops with the same set of participating replicas can communicate with
//   each other.  This is how we separate out different replica groups (e.g. a
//   single AllReduce HLO might do two reductions, between say GPUs {0,2} and
//   {1,3}).
//
// * Only ops with the same opcode can communicate with each other.  At the
//   moment we only support kAllReduce, so we don't check for this explicitly.
//
// * For cross-module all-reduces (i.e. instr->channel_id().has_value()),
//   only ops with the same value for channel_id() can communicate with each
//   other.
//
// * For cross-replica (i.e. same-module) all-reduces (i.e.
//   !channel_id().has_value()), only ops from the same module (as
//   identified by its unique_id()) can communicate with each other.
//
struct RendezvousKey {
  enum CollectiveOpKind {
    kCrossModule,
    kCrossReplica,
  };

  explicit RendezvousKey(const RunId& run_id,
                         std::vector<GlobalDeviceId> global_devices,
                         int num_local_participants,
                         CollectiveOpKind collective_op_kind, int64 op_id)
      : run_id(run_id),
        global_devices(std::move(global_devices)),
        num_local_participants(num_local_participants),
        collective_op_kind(collective_op_kind),
        op_id(op_id) {}

  template <typename H>
  friend H AbslHashValue(H h, const RendezvousKey& k) {
    return H::combine(std::move(h), k.run_id, k.global_devices,
                      k.num_local_participants,
                      static_cast<int>(k.collective_op_kind), k.op_id);
  }
  friend bool operator==(const RendezvousKey& a, const RendezvousKey& b) {
    return a.run_id == b.run_id && a.global_devices == b.global_devices &&
           a.num_local_participants == b.num_local_participants &&
           a.collective_op_kind == b.collective_op_kind &&  //
           a.op_id == b.op_id;
  }
  friend bool operator!=(const RendezvousKey& a, const RendezvousKey& b) {
    return !(a == b);
  }

  string ToString() const {
    return absl::StrFormat(
        "RendezvousKey{run_id=%s, global_devices=[%s], "
        "num_local_participants=%d, collective_op_kind=%d, op_id=%d}",
        run_id.ToString(), GlobalDeviceIdsToString(global_devices),
        num_local_participants, static_cast<int>(collective_op_kind), op_id);
  }

  RunId run_id;
  std::vector<GlobalDeviceId> global_devices;
  int num_local_participants;
  CollectiveOpKind collective_op_kind;
  int64 op_id;
};
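
// Example construction (illustrative sketch; `run_id` and `module_id` are
// placeholder values, not defined in this header):
//
//   RendezvousKey key(
//       run_id,
//       /*global_devices=*/{GlobalDeviceId(0), GlobalDeviceId(1)},
//       /*num_local_participants=*/2, RendezvousKey::kCrossReplica,
//       /*op_id=*/module_id);
//
// Two collective ops rendezvous with each other exactly when their keys
// compare equal under operator== above.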

template <typename DescFn>
void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
                       const DescFn& desc_fn) {
  VLOG(3) << "Begin: " << desc_fn();
  const std::chrono::milliseconds timeout(5000);
  bool ok = counter->WaitFor(timeout);
  if (ok) {
    VLOG(3) << "Finished: " << desc_fn();
    return;
  }
  LOG(ERROR) << "This thread has been waiting for " << timeout.count()
             << "ms and may be stuck: " << desc_fn();
  counter->Wait();
  LOG(ERROR) << "Thread is unstuck!  Warning above was a false positive.  "
                "Perhaps the timeout is too short: "
             << desc_fn();
}
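
// Example usage (illustrative sketch; `num_threads` is a placeholder):
//
//   tensorflow::BlockingCounter counter(num_threads);
//   // ... each worker thread eventually calls counter.DecrementCount() ...
//   WaitAndLogIfStuck(&counter, [&] {
//     return absl::StrFormat("waiting for %d worker threads", num_threads);
//   });
//
// This mirrors how Rendezvous::SubmitParticipant below uses WaitAndLogIfStuck
// to report which rendezvous a stuck thread is waiting on.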

// Participant data for each rendezvous.
struct ParticipantData {
  ParticipantData(const RendezvousKey& rendezvous_key, int64 device_ordinal,
                  se::Stream* stream)
      : rendezvous_key(rendezvous_key),
        device_ordinal(device_ordinal),
        stream(stream) {}

  virtual ~ParticipantData() {}

  RendezvousKey rendezvous_key;
  int64 device_ordinal;
  se::Stream* stream;

  virtual std::string ToString() const = 0;
};

// Encapsulates parameters to Rendezvous::SubmitParticipant.
struct AllReduceParticipantData : ParticipantData {
  AllReduceParticipantData(const RendezvousKey& rendezvous_key_p,
                           int64 device_ordinal_p, se::Stream* stream_p)
      : ParticipantData(rendezvous_key_p, device_ordinal_p, stream_p) {}

  // TODO(b/125951860): We should vet that we're buffer allocating such that
  // source_buffer == destination_buffer if that avoids a NCCL copy (will depend
  // on how well the NCCL in-place implementation performs vs the out-of-place
  // implementation).
  struct Buffer {
    int64 element_count;
    se::DeviceMemoryBase source_data;
    se::DeviceMemoryBase destination_data;
    PrimitiveType primitive_type;
  };
  std::vector<Buffer> buffers;
  const gpu::NcclUniqueIdCallback* nccl_unique_id_callback = nullptr;

  ReductionKind reduction_kind;

  // For each local all-reduce participant a (global ID, local device ordinal)
  // pair for the participant. Participants are in no particular order.
  std::vector<std::pair<GlobalDeviceId, int64>> local_devices;

  string ToString() const override {
    std::vector<std::string> buffer_strs;
    for (const Buffer& buffer : buffers) {
      buffer_strs.push_back(
          absl::StrFormat("{element_count=%d}", buffer.element_count));
    }
    return absl::StrFormat(
        "AllReduceParticipantData{buffers=[%s], rendezvous_key=%s, "
        "device_ordinal=%d, stream=%p}",
        absl::StrJoin(buffer_strs, ","), rendezvous_key.ToString(),
        device_ordinal, stream);
  }
};
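
// Example of filling in participant data (illustrative sketch; `key`, `stream`,
// `source_mem`, and `dest_mem` are placeholder values):
//
//   AllReduceParticipantData participant(key, /*device_ordinal=*/0, stream);
//   participant.reduction_kind = ReductionKind::SUM;
//   participant.buffers.push_back(
//       {/*element_count=*/1024, source_mem, dest_mem, F32});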

// The set of threads that want to do a collective op together all pick the same
// Rendezvous object out of the global cache and call SubmitParticipant.
//
// The Rendezvous instance handles waiting for all threads to join, ensuring
// that a clique exists for the desired set of GPUs, etc.
//
// Rendezvous objects can only be used once.
//
// I: Participant data.
// O: Participant output.
template <typename I, typename O,
          typename =
              std::enable_if_t<std::is_base_of<ParticipantData, I>::value>>
class Rendezvous {
 public:
  virtual ~Rendezvous() {}
  explicit Rendezvous(const RendezvousKey& k) : key_(k) {}

  // Submit a participant to the rendezvous. We get the rendezvous from
  // `rendezvous_getter`, which we can then use to drop the existing reference.
  static StatusOr<O> SubmitParticipant(
      std::function<std::shared_ptr<Rendezvous<I, O>>()> rendezvous_getter,
      I participant) {
    std::shared_ptr<Rendezvous<I, O>> rendezvous = rendezvous_getter();
    TF_ASSIGN_OR_RETURN(auto p, rendezvous->SubmitParticipant(participant));

    // Drop our reference to the Rendezvous and wait for all other threads to do
    // the same.  If we didn't do this, one of the threads could run past this
    // point, reenter ExecuteOnStream for another all-reduce, and attempt to
    // reuse the Rendezvous!
    //
    // An alternative way of accomplishing this goal would be to implement
    // RefcountingHashMap::erase() and call it during SubmitParticipant.  But
    // erase() is deceptively complex to implement correctly.
    std::shared_ptr<tensorflow::BlockingCounter> blocking_counter = p.second;
    rendezvous.reset();
    blocking_counter->DecrementCount();
    xla::WaitAndLogIfStuck(blocking_counter.get(), [&] {
      return absl::StrFormat(
          "participant waiting for all threads to drop their reference to the "
          "rendezvous: %p",
          rendezvous.get());
    });
    return std::move(p.first);
  }

 protected:
  // Runs the collective op and returns the domain-specific output O for this
  // participant.
  virtual StatusOr<O> RunCollectiveOp(const I& participant) = 0;

  // Initialize the rendezvous by the first ("primary") thread which reaches the
  // barrier. Returns whether this thread is primary.
  bool InitializationBarrier() {
    tensorflow::mutex_lock lock(mu_);
    if (!initialized_) {
      initialized_ = true;
      return true;
    }
    return false;
  }

  tensorflow::mutex mu_;

  bool initialized_ TF_GUARDED_BY(mu_) = false;

  std::vector<I> participants_ TF_GUARDED_BY(mu_);

 private:
  // Runs the all-reduce on the given thread.  If successful, returns
  //  - a handle to the clique that was used, so that the caller may keep the
  //    clique alive if it chooses.
  //  - a BlockingCounter initialized to the number of participants, so that
  //    the caller can coordinate with the participants one last time if it
  //    chooses.  This is useful for coordinating destruction of the Rendezvous.
  StatusOr<std::pair<O, std::shared_ptr<tensorflow::BlockingCounter>>>
  SubmitParticipant(const I& participant) {
    {
      tensorflow::mutex_lock lock(mu_);
      CHECK(!initialized_);

      // Spot check for consistent replica counts among submitting threads.
      if (!participants_.empty() &&
          participants_.back().rendezvous_key != participant.rendezvous_key) {
        return InvalidArgument(
            "Mismatch among all-reduce participants.  Expected same "
            "replica-count, element-count, and rendezvous-key but were %s and "
            "%s",
            participants_.back().ToString(), participant.ToString());
      }
      participants_.push_back(participant);
    }

    // Wait for all participants to arrive.
    all_participants_present_.DecrementCount();
    WaitAndLogIfStuck(&all_participants_present_, [&] {
      return absl::StrFormat(
          "participant for device ordinal %d, stream %p waiting for all "
          "participants to arrive at rendezvous %s",
          participant.device_ordinal, participant.stream, key_.ToString());
    });

    TF_ASSIGN_OR_RETURN(O output, RunCollectiveOp(participant));
    return std::make_pair(std::move(output), returned_blocking_counter_);
  }

  const RendezvousKey key_;

  tensorflow::BlockingCounter all_participants_present_{
      key_.num_local_participants};

  // tensorflow::BlockingCounter returned by SubmitParticipant.
  std::shared_ptr<tensorflow::BlockingCounter> returned_blocking_counter_{
      std::make_shared<tensorflow::BlockingCounter>(
          key_.num_local_participants)};
};
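
// Example of a concrete rendezvous (illustrative sketch only; MyOutput,
// MyRendezvous, and GetOrCreateRendezvous are hypothetical names, not part of
// this header):
//
//   struct MyOutput { /* domain-specific result */ };
//
//   class MyRendezvous
//       : public Rendezvous<AllReduceParticipantData, MyOutput> {
//    public:
//     explicit MyRendezvous(const RendezvousKey& k) : Rendezvous(k) {}
//
//    protected:
//     StatusOr<MyOutput> RunCollectiveOp(
//         const AllReduceParticipantData& participant) override {
//       // The first thread to reach the barrier does one-time setup.
//       bool primary = InitializationBarrier();
//       (void)primary;
//       // ... run the collective for `participant` ...
//       return MyOutput{};
//     }
//   };
//
//   // Each participating thread submits its data and blocks until all local
//   // participants have arrived:
//   StatusOr<MyOutput> result = MyRendezvous::SubmitParticipant(
//       /*rendezvous_getter=*/[&] { return GetOrCreateRendezvous(key); },
//       participant);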

}  // end namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_