/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <functional>
#include <limits>
#include <map>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace se = ::stream_executor;

namespace xla {
namespace cpu {
namespace runtime {

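// Returns the XfeedManager for the given device ordinal, creating it on first
// use. Managers live in a process-wide map guarded by a mutex and are
// intentionally never destroyed, so the returned pointer remains valid for the
// lifetime of the process.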
XfeedManager* GetXfeedManager(int device_ordinal) {
  static auto* managers = new absl::flat_hash_map<int, XfeedManager*>();
  static absl::Mutex* mutex = new absl::Mutex();

  absl::MutexLock lock(mutex);
  auto it = managers->find(device_ordinal);
  if (it == managers->end()) {
    it = managers->emplace(device_ordinal, new XfeedManager()).first;
  }
  return it->second;
}

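// Symbol names for the runtime entry points defined below and for the matmul,
// convolution, FFT, sort, and fork-join kernels emitted elsewhere in the CPU
// backend. The JIT-compiled code refers to these functions by name, so the
// strings must stay in sync with the corresponding extern "C" definitions.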
extern const char* const kEigenMatMulF16SymbolName =
    "__xla_cpu_runtime_EigenMatMulF16";
extern const char* const kEigenMatMulF32SymbolName =
    "__xla_cpu_runtime_EigenMatMulF32";
extern const char* const kEigenMatMulF64SymbolName =
    "__xla_cpu_runtime_EigenMatMulF64";
extern const char* const kEigenMatMulC64SymbolName =
    "__xla_cpu_runtime_EigenMatMulC64";
extern const char* const kEigenMatMulC128SymbolName =
    "__xla_cpu_runtime_EigenMatMulC128";
extern const char* const kEigenMatMulS32SymbolName =
    "__xla_cpu_runtime_EigenMatMulS32";
extern const char* const kMKLConvF32SymbolName = "__xla_cpu_runtime_MKLConvF32";
extern const char* const kMKLMatMulF32SymbolName =
    "__xla_cpu_runtime_MKLMatMulF32";
extern const char* const kMKLMatMulF64SymbolName =
    "__xla_cpu_runtime_MKLMatMulF64";
extern const char* const kMKLSingleThreadedMatMulF32SymbolName =
    "__xla_cpu_runtime_MKLSingleThreadedMatMulF32";
extern const char* const kMKLSingleThreadedMatMulF64SymbolName =
    "__xla_cpu_runtime_MKLSingleThreadedMatMulF64";
extern const char* const kEigenConvF16SymbolName =
    "__xla_cpu_runtime_EigenConvF16";
extern const char* const kEigenConvF32SymbolName =
    "__xla_cpu_runtime_EigenConvF32";
extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft";
extern const char* const kEigenSingleThreadedFftSymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedFft";
extern const char* const kEigenSingleThreadedMatMulF16SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulF16";
extern const char* const kEigenSingleThreadedMatMulF32SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulF32";
extern const char* const kEigenSingleThreadedMatMulF64SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulF64";
extern const char* const kEigenSingleThreadedMatMulC64SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulC64";
extern const char* const kEigenSingleThreadedMatMulC128SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulC128";
extern const char* const kEigenSingleThreadedMatMulS32SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulS32";
extern const char* const kEigenSingleThreadedConvF16SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedConvF16";
extern const char* const kEigenSingleThreadedConvF32SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedConvF32";
extern const char* const kAcquireInfeedBufferForDequeueSymbolName =
    "__xla_cpu_runtime_AcquireInfeedBufferForDequeue";
extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName =
    "__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue";
extern const char* const kAcquireOutfeedBufferForPopulationSymbolName =
    "__xla_cpu_runtime_AcquireOutfeedBufferForPopulation";
extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName =
    "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation";
extern const char* const kParallelForkJoinSymbolName =
    "__xla_cpu_runtime_ParallelForkJoin";
extern const char* const kKeyValueSortSymbolName =
    "__xla_cpu_runtime_KeyValueSort";
extern const char* const kTopKF32SymbolName = "__xla_cpu_runtime_TopKF32";
extern const char* const kTracingStartSymbolName =
    "__xla_cpu_runtime_TracingStart";
extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd";
extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce";
extern const char* const kAllToAllSymbolName = "__xla_cpu_runtime_AllToAll";
extern const char* const kCollectivePermuteSymbolName =
    "__xla_cpu_runtime_CollectivePermute";
extern const char* const kReplicaIdSymbolName = "__xla_cpu_runtime_ReplicaId";

}  // namespace runtime
}  // namespace cpu
}  // namespace xla

namespace {

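// Per-replica payload handed to CpuCollectivePermuteRendezvous: this replica's
// source and destination buffers plus the replica ids it should copy its data
// to.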
struct CollectivePermuteParticipantData : xla::ParticipantData {
  CollectivePermuteParticipantData(const xla::RendezvousKey& rendezvous_key_p,
                                   xla::int64 device_ordinal_p,
                                   se::Stream* stream_p)
      : ParticipantData(rendezvous_key_p, device_ordinal_p, stream_p) {}

  int replica_id;
  se::DeviceMemoryBase source_data;
  se::DeviceMemoryBase destination_data;
  xla::int64 byte_size;
  std::vector<int> replica_ids_to_copy_to;

  std::string ToString() const override {
    return absl::StrFormat(
        "CollectivePermuteParticipantData{replica_id=%d, "
        "source_data=%p, destination_data=%p, byte_size=%d, "
        "replica_ids_to_copy_to=[%s]}",
        replica_id, source_data.opaque(), destination_data.opaque(), byte_size,
        absl::StrJoin(replica_ids_to_copy_to, ", "));
  }
};

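// Per-replica payload handed to CpuAllToAllRendezvous: one source and one
// destination buffer per participating replica, plus the replica ids that
// define the concatenation order.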
struct AllToAllParticipantData : xla::ParticipantData {
  AllToAllParticipantData(const xla::RendezvousKey& rendezvous_key_p,
                          xla::int64 device_ordinal_p, se::Stream* stream_p)
      : ParticipantData(rendezvous_key_p, device_ordinal_p, stream_p) {}

  std::vector<se::DeviceMemoryBase> source_buffers;
  std::vector<se::DeviceMemoryBase> destination_buffers;
  int replica_id;

  // Replica ids participating in AllToAll; concatenation happens in the order
  // of appearance.
  std::vector<int> replica_ids_to_copy_to;

  std::string ToString() const override {
    auto addr_formatter = [](std::string* out,
                             const se::DeviceMemoryBase& mem) {
      absl::StrAppend(out, absl::StrFormat("%p", mem.opaque()));
    };
    return absl::StrFormat(
        "AllToAllParticipantData{replica_id=%d, "
        "replica_ids_to_copy_to=[%s], source_buffers=[%s], "
        "destination_buffers=[%s]}",
        replica_id, absl::StrJoin(replica_ids_to_copy_to, ", "),
        absl::StrJoin(source_buffers, ", ", addr_formatter),
        absl::StrJoin(destination_buffers, ", ", addr_formatter));
  }
};

// Decodes a Shape that was serialized as a ShapeProto and embedded in an LLVM
// global variable, i.e. inverts the compiler's shape encoding.
xla::StatusOr<xla::Shape> DecodeSelfDescribingShapeConstant(
    const void* shape_ptr, xla::int32 size_bytes) {
  xla::ShapeProto shape_proto;
  if (!shape_proto.ParseFromArray(shape_ptr, size_bytes)) {
    return tensorflow::errors::Internal("Failed parsing the shape proto");
  }
  xla::Shape shape(shape_proto);
  auto status = xla::ShapeUtil::ValidateShape(shape);
  if (!status.ok()) {
    return status;
  }
  return std::move(shape);
}

tensorflow::string ShapeString(const void* shape_ptr, xla::int32 shape_length) {
  xla::StatusOr<xla::Shape> shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
  if (shape.ok()) {
    return xla::ShapeUtil::HumanStringWithLayout(shape.ValueOrDie());
  }
  return "<invalid shape>";
}

}  // namespace

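// The functions below are the runtime entry points called from code that the
// XLA CPU backend JIT-compiles. They are resolved by their unmangled symbol
// names (see the k*SymbolName constants above), so their names and signatures
// must stay in sync with the calls emitted by the compiler.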
extern "C" {

TF_ATTRIBUTE_NO_SANITIZE_MEMORY xla::int64 __xla_cpu_runtime_TracingStart(
    const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
    const char* name) {
  VLOG(3) << "TracingStart " << name;
  return tensorflow::profiler::TraceMe::ActivityStart(name);
}

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_TracingEnd(
    const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
    xla::int64 id) {
  VLOG(3) << "TracingEnd " << id;
  tensorflow::profiler::TraceMe::ActivityEnd(id);
}

}  // extern "C"

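// Blocks until the host has enqueued an infeed buffer for this device, checks
// that its size matches what the compiled program expects, and returns the
// buffer's data pointer. The program is expected to hand the buffer back later
// via __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue.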
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void*
__xla_cpu_runtime_AcquireInfeedBufferForDequeue(
    const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
    const void* shape, xla::int32 shape_length) {
  int device_ordinal =
      run_options ? run_options->stream()->parent()->device_ordinal() : 0;

  VLOG(2) << "AcquireInfeedBufferForDequeue: "
          << ShapeString(shape, shape_length) << " on stream executor "
          << device_ordinal;

  xla::cpu::runtime::XfeedManager* xfeed =
      xla::cpu::runtime::GetXfeedManager(device_ordinal);
  // Wait until there's a buffer to dequeue.
  xla::cpu::runtime::XfeedBuffer* buffer =
      xfeed->infeed()->BlockingDequeueBuffer();
  CHECK_EQ(buffer->length(), buffer_length)
      << "XLA program infeed request buffer size " << buffer_length
      << " did not match the runtime's infeed buffer length "
      << buffer->length() << "; program reports desired shape: "
      << ShapeString(shape, shape_length);
  return buffer->data();
}

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
    const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
    void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) {
  int device_ordinal =
      run_options ? run_options->stream()->parent()->device_ordinal() : 0;

  VLOG(2) << "ReleaseInfeedBufferAfterDequeue: "
          << ShapeString(shape_ptr, shape_length) << " on stream executor "
          << device_ordinal;

  xla::cpu::runtime::XfeedManager* xfeed =
      xla::cpu::runtime::GetXfeedManager(device_ordinal);
  xla::StatusOr<xla::Shape> shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
  xfeed->infeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr,
                                        std::move(shape));
}

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void*
__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
    const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
    const void* shape_ptr, xla::int32 shape_length) {
  int device_ordinal =
      run_options ? run_options->stream()->parent()->device_ordinal() : 0;

  VLOG(2) << "AcquireOutfeedBufferForPopulation: "
          << ShapeString(shape_ptr, shape_length) << " on stream executor "
          << device_ordinal;

  xla::cpu::runtime::XfeedManager* xfeed =
      xla::cpu::runtime::GetXfeedManager(device_ordinal);
  // Wait until there's a buffer to dequeue.
  xla::cpu::runtime::XfeedBuffer* buffer =
      xfeed->outfeed()->BlockingDequeueBuffer();
  CHECK_EQ(buffer->length(), buffer_length)
      << "XLA program outfeed request buffer size " << buffer_length
      << " did not match the runtime's outfeed buffer length "
      << buffer->length() << "; program reports outfeed shape: "
      << ShapeString(shape_ptr, shape_length);
  return buffer->data();
}

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
    const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
    void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) {
  int device_ordinal =
      run_options ? run_options->stream()->parent()->device_ordinal() : 0;

  VLOG(2) << "ReleaseOutfeedBufferAfterPopulation: "
          << ShapeString(shape_ptr, shape_length) << " on stream executor "
          << device_ordinal;

  xla::cpu::runtime::XfeedManager* xfeed =
      xla::cpu::runtime::GetXfeedManager(device_ordinal);
  xla::StatusOr<xla::Shape> shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
  xfeed->outfeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr,
                                         std::move(shape));
}

namespace {

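// Rendezvous for a CPU all-to-all. Each participating replica's calling thread
// submits an AllToAllParticipantData; once every participant has arrived, the
// primary thread alone performs the copies: a sender's i-th source buffer is
// written into destination buffer `rank(sender)` of the replica at position i
// of replica_ids_to_copy_to.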
class CpuAllToAllRendezvous
    : public xla::Rendezvous<AllToAllParticipantData, std::nullptr_t> {
 public:
  explicit CpuAllToAllRendezvous(const xla::RendezvousKey& k)
      : xla::Rendezvous<AllToAllParticipantData, std::nullptr_t>(k) {}

 protected:
  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
      const AllToAllParticipantData& /*participant*/) override {
    bool is_primary = InitializationBarrier();

    if (is_primary) {
      tensorflow::mutex_lock lock(mu_);

      CHECK(!participants_.empty());
      CHECK(!participants_[0].source_buffers.empty());
      int expected_buffer_size = participants_[0].source_buffers[0].size();

      // Replica id -> position in participants_.
      absl::flat_hash_map<int, int> replica_id_map;

      for (int pos = 0; pos < participants_.size(); pos++) {
        const AllToAllParticipantData& p = participants_[pos];
        CHECK_EQ(p.source_buffers.size(), p.destination_buffers.size());
        CHECK_EQ(p.source_buffers.size(), participants_.size());
        for (int i = 0; i < p.source_buffers.size(); i++) {
          CHECK_EQ(p.destination_buffers[i].size(), expected_buffer_size);
          CHECK_EQ(p.source_buffers[i].size(), expected_buffer_size);
        }
        replica_id_map[p.replica_id] = pos;
      }

      const std::vector<int>& replica_ids_to_copy_to =
          participants_[0].replica_ids_to_copy_to;

      // Replica id -> rank
      absl::flat_hash_map<int, int> replica_ranks;
      for (int rank = 0; rank < replica_ids_to_copy_to.size(); ++rank) {
        int replica_id = replica_ids_to_copy_to[rank];
        replica_ranks[replica_id] = rank;
      }

      for (const AllToAllParticipantData& sender : participants_) {
        VLOG(3) << "Processing AllToAll participant: " << sender.ToString();

        int rank = xla::FindOrDie(replica_ranks, sender.replica_id);

        for (int i = 0; i < participants_.size(); ++i) {
          int replica_id = replica_ids_to_copy_to[i];
          int participant_num = xla::FindOrDie(replica_id_map, replica_id);
          AllToAllParticipantData& receiver = participants_[participant_num];

          std::memcpy(receiver.destination_buffers[rank].opaque(),
                      sender.source_buffers[i].opaque(), expected_buffer_size);
        }
      }
    }
    return nullptr;
  }
};

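// Rendezvous for a CPU collective-permute. The primary thread copies each
// participant's source buffer into the destination buffer of every replica
// listed in its replica_ids_to_copy_to, then zero-fills the destination of any
// replica that nobody sent to.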
class CpuCollectivePermuteRendezvous
    : public xla::Rendezvous<CollectivePermuteParticipantData, std::nullptr_t> {
 public:
  explicit CpuCollectivePermuteRendezvous(const xla::RendezvousKey& k)
      : xla::Rendezvous<CollectivePermuteParticipantData, std::nullptr_t>(k) {}

 protected:
  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
      const CollectivePermuteParticipantData& /*participant*/) override {
    bool primary = InitializationBarrier();

    // Perform all copies from the primary thread.
    if (primary) {
      tensorflow::mutex_lock lock(mu_);

      std::map<int, int> replica_idx_to_participant_idx;
      for (int p_idx = 0; p_idx < participants_.size(); p_idx++) {
        replica_idx_to_participant_idx[participants_[p_idx].replica_id] = p_idx;
      }

      for (auto& p : participants_) {
        for (int dest_replica : p.replica_ids_to_copy_to) {
          auto& dest_p = participants_[xla::FindOrDie(
              replica_idx_to_participant_idx, dest_replica)];
          std::memcpy(dest_p.destination_data.opaque(), p.source_data.opaque(),
                      p.byte_size);

          // Each replica may be copied into only once.
          replica_idx_to_participant_idx.erase(dest_replica);
        }
      }

      // Zero out untouched participants.
      for (auto& replica_p : replica_idx_to_participant_idx) {
        auto& p = participants_[replica_p.second];
        std::memset(p.destination_data.opaque(), 0, p.byte_size);
      }
    }
    return nullptr;
  }
};

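// Rendezvous for a CPU all-reduce. The primary thread reduces element-wise
// across all participants' input buffers and writes the result into every
// participant's output buffer. For example (illustrative values only), a SUM
// all-reduce over two participants holding {1, 2} and {3, 4} leaves {4, 6} in
// both participants' output buffers.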
class CpuAllReduceRendezvous
    : public xla::Rendezvous<xla::AllReduceParticipantData, std::nullptr_t> {
 public:
  explicit CpuAllReduceRendezvous(const xla::RendezvousKey& k)
      : xla::Rendezvous<xla::AllReduceParticipantData, std::nullptr_t>(k) {}

 protected:
  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
      const xla::AllReduceParticipantData& participant) override {
    xla::PrimitiveType datatype = participant.buffers.front().primitive_type;
    bool primary = InitializationBarrier();

    if (primary) {
      switch (datatype) {
        case xla::S8:
          DoAllReduce<xla::S8>(participant);
          break;
        case xla::PRED:
        case xla::U8:
          DoAllReduce<xla::U8>(participant);
          break;
        case xla::S32:
          DoAllReduce<xla::S32>(participant);
          break;
        case xla::U32:
          DoAllReduce<xla::U32>(participant);
          break;
        case xla::S64:
          DoAllReduce<xla::S64>(participant);
          break;
        case xla::U64:
          DoAllReduce<xla::U64>(participant);
          break;
        case xla::F16:
          DoAllReduce<xla::F16>(participant);
          break;
        case xla::F32:
          DoAllReduce<xla::F32>(participant);
          break;
        case xla::F64:
          DoAllReduce<xla::F64>(participant);
          break;
        default:
          LOG(FATAL) << "Unexpected datatype";
      }
    }
    return nullptr;
  }

 private:
  template <xla::PrimitiveType PT>
  void DoAllReduce(xla::AllReduceParticipantData participant) {
    using T = typename xla::primitive_util::PrimitiveTypeToNative<PT>::type;
    tensorflow::mutex_lock lock(mu_);
    CHECK(!participants_.empty());
    xla::ReductionKind reduction_kind = participant.reduction_kind;
    for (const auto& p : participants_) {
      CHECK(p.reduction_kind == reduction_kind);
    }
    int num_participants = participants_.size();

    // participant_idx -> buffer_idx -> buffer.
    std::vector<std::vector<absl::Span<T>>> input_buffers;
    std::vector<std::vector<absl::Span<T>>> output_buffers;
    input_buffers.reserve(num_participants);
    output_buffers.reserve(num_participants);
    const xla::AllReduceParticipantData& first_participant =
        participants_.front();

    int buffers_per_participant = first_participant.buffers.size();
    for (xla::AllReduceParticipantData& p : participants_) {
      CHECK_EQ(p.buffers.size(), buffers_per_participant);

      input_buffers.emplace_back();
      output_buffers.emplace_back();
      std::vector<absl::Span<T>>& participant_input_buffers =
          input_buffers.back();
      std::vector<absl::Span<T>>& participant_output_buffers =
          output_buffers.back();
      participant_input_buffers.reserve(p.buffers.size());
      participant_output_buffers.reserve(p.buffers.size());

      for (int buffer_idx = 0; buffer_idx < buffers_per_participant;
           buffer_idx++) {
        auto& participant_buffer = p.buffers[buffer_idx];
        participant_input_buffers.emplace_back(
            static_cast<T*>(participant_buffer.source_data.opaque()),
            participant_buffer.element_count);
        participant_output_buffers.emplace_back(
            static_cast<T*>(participant_buffer.destination_data.opaque()),
            participant_buffer.element_count);
        CHECK_EQ(participant_buffer.element_count,
                 first_participant.buffers[buffer_idx].element_count);
      }
    }

    for (int buffer_idx = 0; buffer_idx < buffers_per_participant;
         buffer_idx++) {
      int element_count = first_participant.buffers[buffer_idx].element_count;
      for (int idx = 0; idx < element_count; idx++) {
        T out = GetInitialValue<T>(reduction_kind);
        for (int participant_idx = 0; participant_idx < participants_.size();
             participant_idx++) {
          out = PerformReductionStep<T>(
              reduction_kind, out,
              input_buffers[participant_idx][buffer_idx][idx]);
        }
        for (int participant_idx = 0; participant_idx < participants_.size();
             participant_idx++) {
          output_buffers[participant_idx][buffer_idx][idx] = out;
        }
      }
    }
  }

  template <typename T>
  T GetInitialValue(xla::ReductionKind reduction_kind) {
    switch (reduction_kind) {
      case xla::ReductionKind::SUM:
        return static_cast<T>(0);
      case xla::ReductionKind::PRODUCT:
        return static_cast<T>(1);
      case xla::ReductionKind::MIN:
        return std::numeric_limits<T>::max();
      case xla::ReductionKind::MAX:
        // lowest() is the identity for MAX; for floating-point types min()
        // would be the smallest positive value, not the most negative one.
        return std::numeric_limits<T>::lowest();
    }
  }

  template <typename T>
  T PerformReductionStep(xla::ReductionKind reduction_kind, T a, T b) {
    switch (reduction_kind) {
      case xla::ReductionKind::SUM:
        return a + b;
      case xla::ReductionKind::PRODUCT:
        return a * b;
      case xla::ReductionKind::MIN:
        return std::min(a, b);
      case xla::ReductionKind::MAX:
        return std::max(a, b);
    }
  }
};

xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>&
GlobalAllReduceRendezvousMap() {
  static auto& m =
      *new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>;
  return m;
}

xla::RefcountingHashMap<xla::RendezvousKey, CpuCollectivePermuteRendezvous>&
GlobalCollectivePermuteRendezvousMap() {
  static auto& m = *new xla::RefcountingHashMap<xla::RendezvousKey,
                                                CpuCollectivePermuteRendezvous>;
  return m;
}

xla::RefcountingHashMap<xla::RendezvousKey, CpuAllToAllRendezvous>&
GlobalAllToAllRendezvousMap() {
  static auto& m =
      *new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllToAllRendezvous>;
  return m;
}

int GetDeviceOrdinal(const xla::ExecutableRunOptions* run_options) {
  if (run_options->stream()) {
    return run_options->stream()->parent()->device_ordinal();
  } else {
    return run_options->device_ordinal();
  }
}

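// Builds the rendezvous key that identifies one dynamic execution of a
// collective op: the run id, the set of participating devices derived from the
// device assignment and replica groups, and whether the op is keyed by a
// channel id (cross-module) or by replica (cross-replica).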
xla::RendezvousKey GetRendezvousKey(
    const xla::ExecutableRunOptions* run_options,
    std::vector<xla::ReplicaGroup> group, xla::int32 channel_id_present,
    xla::int64 op_id) {
  const xla::DeviceAssignment& device_assignment =
      *run_options->device_assignment();
  int device_ordinal = GetDeviceOrdinal(run_options);
  xla::RendezvousKey::CollectiveOpKind op_kind =
      channel_id_present ? xla::RendezvousKey::kCrossModule
                         : xla::RendezvousKey::kCrossReplica;
  std::vector<xla::GlobalDeviceId> participating_devices =
      xla::GetParticipatingDevices(xla::GlobalDeviceId(device_ordinal),
                                   device_assignment,
                                   device_assignment.replica_count(), group)
          .ValueOrDie();
  int num_local_participants = participating_devices.size();
  return xla::RendezvousKey{run_options->run_id(),
                            std::move(participating_devices),
                            num_local_participants, op_kind, op_id};
}

}  // namespace

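// Entry point for the all-to-all collective. replica_groups_str points at a
// serialized replica-group string that is parsed with
// xla::ParseReplicaGroupsOnly; source_buffers and destination_buffers each
// hold num_buffers pointers to buffers of buffer_size bytes.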
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllToAll(
    const xla::ExecutableRunOptions* run_options, xla::int32 channel_id_present,
    xla::int64 op_id, const void* replica_groups_str,
    xla::int32 replica_groups_str_size, xla::int32 num_buffers,
    xla::int64 buffer_size, void** source_buffers, void** destination_buffers) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  xla::int32 replica_id =
      run_options->device_assignment()
          ->ReplicaIdForDevice(xla::GlobalDeviceId(device_ordinal))
          .ValueOrDie();
  absl::string_view replica_groups_serialized(
      static_cast<const char*>(replica_groups_str), replica_groups_str_size);
  std::vector<xla::ReplicaGroup> group =
      xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
  xla::RendezvousKey rendezvous_key =
      GetRendezvousKey(run_options, group, channel_id_present, op_id);

  AllToAllParticipantData participant(rendezvous_key, device_ordinal,
                                      run_options->stream());
  participant.replica_id = replica_id;
  participant.replica_ids_to_copy_to =
      xla::GetParticipatingReplicas(
          replica_id, run_options->device_assignment()->replica_count(), group)
          .ValueOrDie();
  for (int i = 0; i < num_buffers; i++) {
    participant.source_buffers.emplace_back(source_buffers[i], buffer_size);
    participant.destination_buffers.emplace_back(destination_buffers[i],
                                                 buffer_size);
  }
  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
    return absl::make_unique<CpuAllToAllRendezvous>(k);
  };
  TF_CHECK_OK(CpuAllToAllRendezvous::SubmitParticipant(
                  [&] {
                    return GlobalAllToAllRendezvousMap().GetOrCreateIfAbsent(
                        rendezvous_key, make_cpu_rendezvous);
                  },
                  participant)
                  .status());
}

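// Entry point for the all-reduce collective. shape_ptr/shape_length describe
// the operand shape: a single dense array when num_buffers == 1, or a tuple
// with one element per buffer when num_buffers > 1.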
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
    const xla::ExecutableRunOptions* run_options,
    const void* replica_groups_str, xla::int32 replica_groups_str_size,
    xla::int32 channel_id_present, xla::int64 op_id, xla::int32 reduction_kind,
    const void* shape_ptr, xla::int32 shape_length, xla::int32 num_buffers,
    void** input_buffers, void** output_buffers) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  absl::string_view replica_groups_serialized(
      static_cast<const char*>(replica_groups_str), replica_groups_str_size);
  std::vector<xla::ReplicaGroup> group =
      xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
  xla::RendezvousKey rendezvous_key =
      GetRendezvousKey(run_options, group, channel_id_present, op_id);
  auto shape_str = ShapeString(shape_ptr, shape_length);
  VLOG(2) << "All-reduce input/output shape: " << shape_str;

  xla::Shape shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length).ValueOrDie();

  CHECK((num_buffers > 1 && shape.IsTuple()) ||
        (num_buffers == 1 && xla::LayoutUtil::IsDenseArray(shape)));

  xla::AllReduceParticipantData participant(rendezvous_key, device_ordinal,
                                            run_options->stream());
  participant.reduction_kind = static_cast<xla::ReductionKind>(reduction_kind);
  for (int i = 0; i < num_buffers; i++) {
    xla::Shape subshape = num_buffers == 1 ? shape : shape.tuple_shapes(i);
    xla::AllReduceParticipantData::Buffer buffer;
    buffer.element_count = xla::ShapeUtil::ElementsIn(subshape);
    buffer.primitive_type = subshape.element_type();
    buffer.source_data = se::DeviceMemoryBase(
        input_buffers[i], xla::ShapeUtil::ByteSizeOf(subshape));
    buffer.destination_data = se::DeviceMemoryBase(
        output_buffers[i], xla::ShapeUtil::ByteSizeOf(subshape));
    participant.buffers.push_back(buffer);
  }

  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
    return absl::make_unique<CpuAllReduceRendezvous>(k);
  };

  TF_CHECK_OK(CpuAllReduceRendezvous::SubmitParticipant(
                  [&] {
                    return GlobalAllReduceRendezvousMap().GetOrCreateIfAbsent(
                        rendezvous_key, make_cpu_rendezvous);
                  },
                  participant)
                  .status());
}

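// Entry point for the replica-id op: looks up the replica id of the current
// device in the device assignment and writes it into output_buffer as a
// 4-byte xla::int32.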
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ReplicaId(
    const xla::ExecutableRunOptions* run_options, void* output_buffer) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  xla::int32 replica_id =
      run_options->device_assignment()
          ->ReplicaIdForDevice(xla::GlobalDeviceId(device_ordinal))
          .ValueOrDie();
  std::memcpy(output_buffer, &replica_id, 4);
}

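// Entry point for collective-permute. source_target_pairs is a comma-separated
// list of "source=target" replica pairs (e.g. "0=1,1=2"); this replica copies
// its input to every target it appears as a source for, and replicas that
// receive nothing have their output zero-filled by the rendezvous.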
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_CollectivePermute(
    const xla::ExecutableRunOptions* run_options, xla::int32 channel_id_present,
    xla::int64 op_id, xla::int32 byte_size, void* input_buffer,
    void* output_buffer, const void* source_target_pairs,
    xla::int32 source_target_pairs_size) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  absl::string_view source_target_pairs_serialized(
      static_cast<const char*>(source_target_pairs), source_target_pairs_size);
  auto pairs = absl::StrSplit(source_target_pairs_serialized, ',');
  xla::int32 replica_id =
      run_options->device_assignment()
          ->ReplicaIdForDevice(xla::GlobalDeviceId(device_ordinal))
          .ValueOrDie();
  std::vector<int> copy_to;
  for (auto& p : pairs) {
    std::vector<std::string> mapping = absl::StrSplit(p, '=');
    CHECK_EQ(mapping.size(), 2);
    int from = std::stoi(mapping[0]);
    int to = std::stoi(mapping[1]);
    if (from == replica_id) {
      copy_to.push_back(to);
    }
  }
  xla::RendezvousKey rendezvous_key =
      GetRendezvousKey(run_options, {}, channel_id_present, op_id);

  CollectivePermuteParticipantData participant(rendezvous_key, device_ordinal,
                                               run_options->stream());
  participant.replica_id = replica_id;
  participant.source_data = se::DeviceMemoryBase(input_buffer, byte_size);
  participant.destination_data = se::DeviceMemoryBase(output_buffer, byte_size);
  participant.replica_ids_to_copy_to = copy_to;
  participant.byte_size = byte_size;

  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
    return absl::make_unique<CpuCollectivePermuteRendezvous>(k);
  };
  TF_CHECK_OK(
      CpuCollectivePermuteRendezvous::SubmitParticipant(
          [&] {
            return GlobalCollectivePermuteRendezvousMap().GetOrCreateIfAbsent(
                rendezvous_key, make_cpu_rendezvous);
          },
          participant)
          .status());
}