/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/blocking_counter.h"

namespace xla {

enum class ReductionKind { SUM, PRODUCT, MIN, MAX };

// Attempts to match computation to one of the possible cases in ReductionKind.
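//
// For example (illustrative HLO, not normative), a computation of the form
//
//   add {
//     x = f32[] parameter(0)
//     y = f32[] parameter(1)
//     ROOT sum = f32[] add(x, y)
//   }
//
// would be expected to match as ReductionKind::SUM, while one whose root is a
// maximum of the two parameters would match as ReductionKind::MAX.  When no
// case matches, absl::nullopt is expected to be returned.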
absl::optional<ReductionKind> MatchReductionComputation(
    const HloComputation* computation);

// Figures out which replicas are participating in the collective subgroup.
// An empty `replica_groups` indicates that all replicas are participating.
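//
// Illustrative example (not normative): with total_replica_count=4 and
// replica_groups={{0, 2}, {1, 3}}, a call with replica_id=2 would be expected
// to return {0, 2}; with empty replica_groups it would return {0, 1, 2, 3}.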
StatusOr<std::vector<int>> GetParticipatingReplicas(
    int replica_id, int total_replica_count,
    absl::Span<const ReplicaGroup> replica_groups);

// Figures out which devices are participating in the collective subgroup.
// An empty `replica_groups` indicates that all replicas are participating.
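//
// Illustrative example (not normative): with a DeviceAssignment mapping
// replicas {0, 1, 2, 3} to global devices {10, 11, 12, 13} and
// replica_groups={{0, 2}, {1, 3}}, a call with device_id=12 would be expected
// to return {10, 12}.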
StatusOr<std::vector<GlobalDeviceId>> GetParticipatingDevices(
    GlobalDeviceId device_id, const DeviceAssignment& device_assignment,
    int total_replica_count, absl::Span<const ReplicaGroup> replica_groups);

// Key that identifies a particular Rendezvous object in our global hashtable.
// This determines which calls to ExecuteOnStream communicate with each other.
// The rules are as follows.
//
// * Only ops with the same RunId can communicate with each other. (This is the
//   whole purpose of RunId.)
//
// * Only ops with the same set of participating replicas can communicate with
//   each other. This is how we separate out different replica groups (e.g. a
//   single AllReduce HLO might do two reductions, between say GPUs {0,2} and
//   {1,3}).
//
// * Only ops with the same opcode can communicate with each other. At the
//   moment we only support kAllReduce, so we don't check for this explicitly.
//
// * For cross-module all-reduces (i.e. instr->channel_id().has_value()),
//   only ops with the same value for channel_id() can communicate with each
//   other.
//
// * For cross-replica (i.e. same-module) all-reduces (i.e.
//   !channel_id().has_value()), only ops from the same module (as
//   identified by its unique_id()) can communicate with each other.
//
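// As an illustrative (non-normative) example: two threads executing the same
// cross-replica all-reduce in one execution might each construct
//
//   RendezvousKey key(run_id,
//                     /*global_devices=*/{GlobalDeviceId(0), GlobalDeviceId(1)},
//                     /*num_local_participants=*/2,
//                     RendezvousKey::kCrossReplica, /*op_id=*/42);
//
// Because the two keys hash and compare equal, both threads find the same
// Rendezvous object in the global hashtable and therefore communicate with
// each other.
//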
struct RendezvousKey {
  enum CollectiveOpKind {
    kCrossModule,
    kCrossReplica,
  };

  explicit RendezvousKey(const RunId& run_id,
                         std::vector<GlobalDeviceId> global_devices,
                         int num_local_participants,
                         CollectiveOpKind collective_op_kind, int64 op_id)
      : run_id(run_id),
        global_devices(std::move(global_devices)),
        num_local_participants(num_local_participants),
        collective_op_kind(collective_op_kind),
        op_id(op_id) {}

  template <typename H>
  friend H AbslHashValue(H h, const RendezvousKey& k) {
    return H::combine(std::move(h), k.run_id, k.global_devices,
                      k.num_local_participants,
                      static_cast<int>(k.collective_op_kind), k.op_id);
  }
  friend bool operator==(const RendezvousKey& a, const RendezvousKey& b) {
    return a.run_id == b.run_id && a.global_devices == b.global_devices &&
           a.num_local_participants == b.num_local_participants &&
           a.collective_op_kind == b.collective_op_kind &&  //
           a.op_id == b.op_id;
  }
  friend bool operator!=(const RendezvousKey& a, const RendezvousKey& b) {
    return !(a == b);
  }

  string ToString() const {
    return absl::StrFormat(
        "RendezvousKey{run_id=%s, global_devices=[%s], "
        "num_local_participants=%d, collective_op_kind=%d, op_id=%d}",
        run_id.ToString(), GlobalDeviceIdsToString(global_devices),
        num_local_participants, static_cast<int>(collective_op_kind), op_id);
  }

  RunId run_id;
  std::vector<GlobalDeviceId> global_devices;
  int num_local_participants;
  CollectiveOpKind collective_op_kind;
  int64 op_id;
};

template <typename DescFn>
void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
                       const DescFn& desc_fn) {
  VLOG(3) << "Begin: " << desc_fn();
  const std::chrono::milliseconds timeout(5000);
  bool ok = counter->WaitFor(timeout);
  if (ok) {
    VLOG(3) << "Finished: " << desc_fn();
    return;
  }
  LOG(ERROR) << "This thread has been waiting for " << timeout.count()
             << "ms and may be stuck: " << desc_fn();
  counter->Wait();
  LOG(ERROR) << "Thread is unstuck! Warning above was a false positive. "
                "Perhaps the timeout is too short: "
             << desc_fn();
}
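
// Example use of WaitAndLogIfStuck (sketch): block until `num_participants`
// threads have each decremented the counter, logging if the wait takes
// suspiciously long.
//
//   tensorflow::BlockingCounter counter(num_participants);
//   // ... each participant thread eventually calls counter.DecrementCount().
//   WaitAndLogIfStuck(&counter, [&] {
//     return absl::StrFormat("waiting on %d participants", num_participants);
//   });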

// Participant data for each rendezvous.
struct ParticipantData {
  ParticipantData(const RendezvousKey& rendezvous_key, int64 device_ordinal,
                  se::Stream* stream)
      : rendezvous_key(rendezvous_key),
        device_ordinal(device_ordinal),
        stream(stream) {}

  virtual ~ParticipantData() {}

  RendezvousKey rendezvous_key;
  int64 device_ordinal;
  se::Stream* stream;

  virtual std::string ToString() const = 0;
};

// Encapsulates parameters to Rendezvous::SubmitParticipant.
struct AllReduceParticipantData : ParticipantData {
  AllReduceParticipantData(const RendezvousKey& rendezvous_key_p,
                           int64 device_ordinal_p, se::Stream* stream_p)
      : ParticipantData(rendezvous_key_p, device_ordinal_p, stream_p) {}

  // TODO(b/125951860): We should vet that we're buffer allocating such that
  // source_buffer == destination_buffer if that avoids a NCCL copy (will
  // depend on how well the NCCL in-place implementation performs vs the
  // out-of-place implementation).
  struct Buffer {
    int64 element_count;
    se::DeviceMemoryBase source_data;
    se::DeviceMemoryBase destination_data;
    PrimitiveType primitive_type;
  };
  std::vector<Buffer> buffers;
  const gpu::NcclUniqueIdCallback* nccl_unique_id_callback = nullptr;

  ReductionKind reduction_kind;

  // For each local all-reduce participant, the (global ID, local device
  // ordinal) pair for that participant. Participants are in no particular
  // order.
  std::vector<std::pair<GlobalDeviceId, int64>> local_devices;

  string ToString() const override {
    std::vector<std::string> buffer_strs;
    for (const Buffer& buffer : buffers) {
      buffer_strs.push_back(
          absl::StrFormat("{element_count=%d}", buffer.element_count));
    }
    return absl::StrFormat(
        "AllReduceParticipantData{buffers=[%s], rendezvous_key=%s, "
        "device_ordinal=%d, stream=%p}",
        absl::StrJoin(buffer_strs, ","), rendezvous_key.ToString(),
        device_ordinal, stream);
  }
};
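
// Example (sketch) of filling in an AllReduceParticipantData for a single
// f32 buffer; `key`, `device_ordinal`, `stream`, `src`, and `dst` are assumed
// to be provided by the caller.
//
//   AllReduceParticipantData participant(key, device_ordinal, stream);
//   AllReduceParticipantData::Buffer buffer;
//   buffer.element_count = 1024;
//   buffer.source_data = src;       // se::DeviceMemoryBase
//   buffer.destination_data = dst;  // se::DeviceMemoryBase
//   buffer.primitive_type = F32;
//   participant.buffers.push_back(buffer);
//   participant.reduction_kind = ReductionKind::SUM;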

// The set of threads that want to do a collective op together all pick the
// same Rendezvous object out of the global cache and call SubmitParticipant.
//
// The Rendezvous instance handles waiting for all threads to join, ensuring
// that a clique exists for the desired set of GPUs, etc.
//
// Rendezvous objects can only be used once.
//
// I: Participant data.
// O: Participant output.
template <typename I, typename O,
          typename =
              std::enable_if_t<std::is_base_of<ParticipantData, I>::value>>
class Rendezvous {
 public:
  virtual ~Rendezvous() {}
  explicit Rendezvous(const RendezvousKey& k) : key_(k) {}

  // Submit a participant to the rendezvous. We get the rendezvous from
  // `rendezvous_getter`, which we can then use to drop the existing reference.
  static StatusOr<O> SubmitParticipant(
      std::function<std::shared_ptr<Rendezvous<I, O>>()> rendezvous_getter,
      I participant) {
    std::shared_ptr<Rendezvous<I, O>> rendezvous = rendezvous_getter();
    TF_ASSIGN_OR_RETURN(auto p, rendezvous->SubmitParticipant(participant));

    // Drop our reference to the Rendezvous and wait for all other threads to
    // do the same. If we didn't do this, one of the threads could run past
    // this point, reenter ExecuteOnStream for another all-reduce, and attempt
    // to reuse the Rendezvous!
    //
    // An alternative way of accomplishing this goal would be to implement
    // RefcountingHashMap::erase() and call it during SubmitParticipant. But
    // erase() is deceptively complex to implement correctly.
    std::shared_ptr<tensorflow::BlockingCounter> blocking_counter = p.second;
    rendezvous.reset();
    blocking_counter->DecrementCount();
    xla::WaitAndLogIfStuck(blocking_counter.get(), [&] {
      return absl::StrFormat(
          "participant waiting for all threads to drop their reference to the "
          "rendezvous: %p",
          rendezvous.get());
    });
    return std::move(p.first);
  }

 protected:
  // Returns domain-specific output O and whether this replica is primary.
  virtual StatusOr<O> RunCollectiveOp(const I& participant) = 0;

  // Initialize the rendezvous by the first ("primary") thread which reaches
  // the barrier. Returns whether this thread is primary.
  bool InitializationBarrier() {
    tensorflow::mutex_lock lock(mu_);
    if (!initialized_) {
      initialized_ = true;
      return true;
    }
    return false;
  }

  tensorflow::mutex mu_;

  bool initialized_ TF_GUARDED_BY(mu_) = false;

  std::vector<I> participants_ TF_GUARDED_BY(mu_);

 private:
  // Runs the all-reduce on the given thread. If successful, returns
  //  - a handle to the clique that was used, so that the caller may keep the
  //    clique alive if it chooses.
  //  - a BlockingCounter initialized to the number of participants, so that
  //    the caller can coordinate with the participants one last time if it
  //    chooses. This is useful for coordinating destruction of the Rendezvous.
  StatusOr<std::pair<O, std::shared_ptr<tensorflow::BlockingCounter>>>
  SubmitParticipant(const I& participant) {
    {
      tensorflow::mutex_lock lock(mu_);
      CHECK(!initialized_);

      // Spot check for consistent replica counts among submitting threads.
      if (!participants_.empty() &&
          participants_.back().rendezvous_key != participant.rendezvous_key) {
        return InvalidArgument(
            "Mismatch among all-reduce participants. Expected same "
            "replica-count, element-count, and rendezvous-key but were %s and "
            "%s",
            participants_.back().ToString(), participant.ToString());
      }
      participants_.push_back(participant);
    }

    // Wait for all participants to arrive.
    all_participants_present_.DecrementCount();
    WaitAndLogIfStuck(&all_participants_present_, [&] {
      return absl::StrFormat(
          "participant for device ordinal %d, stream %p waiting for all "
          "participants to arrive at rendezvous %s",
          participant.device_ordinal, participant.stream, key_.ToString());
    });

    TF_ASSIGN_OR_RETURN(O output, RunCollectiveOp(participant));
    return std::make_pair(std::move(output), returned_blocking_counter_);
  }

  const RendezvousKey key_;

  tensorflow::BlockingCounter all_participants_present_{
      key_.num_local_participants};

  // tensorflow::BlockingCounter returned by SubmitParticipant.
  std::shared_ptr<tensorflow::BlockingCounter> returned_blocking_counter_{
      std::make_shared<tensorflow::BlockingCounter>(
          key_.num_local_participants)};
};
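
// Example (sketch) of how a backend might define and drive a Rendezvous
// subclass. The subclass name, output type, and GetOrCreateRendezvous helper
// below are hypothetical; in practice the getter is typically backed by a
// refcounted global cache (e.g. a RefcountingHashMap keyed on RendezvousKey).
//
//   class MyAllReduceRendezvous
//       : public Rendezvous<AllReduceParticipantData, MyOutput> {
//    public:
//     explicit MyAllReduceRendezvous(const RendezvousKey& k) : Rendezvous(k) {}
//
//    protected:
//     StatusOr<MyOutput> RunCollectiveOp(
//         const AllReduceParticipantData& participant) override {
//       // ... perform this participant's share of the reduction ...
//       return MyOutput{};
//     }
//   };
//
//   // Each participating thread then submits itself:
//   StatusOr<MyOutput> result =
//       MyAllReduceRendezvous::SubmitParticipant(
//           [&] { return GetOrCreateRendezvous(key); }, participant);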

}  // end namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_OPS_UTILS_H_