/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"

#include <algorithm>
#include <cstdarg>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <functional>
#include <limits>
#include <map>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace se = ::stream_executor;

namespace xla {
namespace cpu {
namespace runtime {
XfeedManager* GetXfeedManager(int device_ordinal) {
  static auto* managers = new absl::flat_hash_map<int, XfeedManager*>();
  static absl::Mutex* mutex = new absl::Mutex();

  absl::MutexLock lock(mutex);
  auto it = managers->find(device_ordinal);
  if (it == managers->end()) {
    it = managers->emplace(device_ordinal, new XfeedManager()).first;
  }
  return it->second;
}

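// Symbol names of the runtime entry points. The XLA CPU backend emits calls
// to these functions by name, so each string must match the symbol of the
// corresponding function definition exactly.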
extern const char* const kEigenMatMulF16SymbolName =
    "__xla_cpu_runtime_EigenMatMulF16";
extern const char* const kEigenMatMulF32SymbolName =
    "__xla_cpu_runtime_EigenMatMulF32";
extern const char* const kEigenMatMulF64SymbolName =
    "__xla_cpu_runtime_EigenMatMulF64";
extern const char* const kEigenMatMulC64SymbolName =
    "__xla_cpu_runtime_EigenMatMulC64";
extern const char* const kEigenMatMulC128SymbolName =
    "__xla_cpu_runtime_EigenMatMulC128";
extern const char* const kEigenMatMulS32SymbolName =
    "__xla_cpu_runtime_EigenMatMulS32";
extern const char* const kMKLConvF32SymbolName = "__xla_cpu_runtime_MKLConvF32";
extern const char* const kMKLMatMulF32SymbolName =
    "__xla_cpu_runtime_MKLMatMulF32";
extern const char* const kMKLMatMulF64SymbolName =
    "__xla_cpu_runtime_MKLMatMulF64";
extern const char* const kMKLSingleThreadedMatMulF32SymbolName =
    "__xla_cpu_runtime_MKLSingleThreadedMatMulF32";
extern const char* const kMKLSingleThreadedMatMulF64SymbolName =
    "__xla_cpu_runtime_MKLSingleThreadedMatMulF64";
extern const char* const kEigenConvF16SymbolName =
    "__xla_cpu_runtime_EigenConvF16";
extern const char* const kEigenConvF32SymbolName =
    "__xla_cpu_runtime_EigenConvF32";
extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft";
extern const char* const kEigenSingleThreadedFftSymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedFft";
extern const char* const kEigenSingleThreadedMatMulF16SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulF16";
extern const char* const kEigenSingleThreadedMatMulF32SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulF32";
extern const char* const kEigenSingleThreadedMatMulF64SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulF64";
extern const char* const kEigenSingleThreadedMatMulC64SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulC64";
extern const char* const kEigenSingleThreadedMatMulC128SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulC128";
extern const char* const kEigenSingleThreadedMatMulS32SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulS32";
extern const char* const kEigenSingleThreadedConvF16SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedConvF16";
extern const char* const kEigenSingleThreadedConvF32SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedConvF32";
extern const char* const kAcquireInfeedBufferForDequeueSymbolName =
    "__xla_cpu_runtime_AcquireInfeedBufferForDequeue";
extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName =
    "__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue";
extern const char* const kAcquireOutfeedBufferForPopulationSymbolName =
    "__xla_cpu_runtime_AcquireOutfeedBufferForPopulation";
extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName =
    "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation";
extern const char* const kParallelForkJoinSymbolName =
    "__xla_cpu_runtime_ParallelForkJoin";
extern const char* const kPrintfToStderrSymbolName =
    "__xla_cpu_runtime_PrintfToStderr";
extern const char* const kKeyValueSortSymbolName =
    "__xla_cpu_runtime_KeyValueSort";
extern const char* const kTopKF32SymbolName = "__xla_cpu_runtime_TopKF32";
extern const char* const kTracingStartSymbolName =
    "__xla_cpu_runtime_TracingStart";
extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd";
extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce";
extern const char* const kAllToAllSymbolName = "__xla_cpu_runtime_AllToAll";
extern const char* const kCollectivePermuteSymbolName =
    "__xla_cpu_runtime_CollectivePermute";
extern const char* const kReplicaIdSymbolName = "__xla_cpu_runtime_ReplicaId";

}  // namespace runtime
}  // namespace cpu
}  // namespace xla

namespace {

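// Participant data handed to the CPU collective rendezvous implementations
// below. Each replica fills in one of these structs and submits it to the
// matching rendezvous; the primary participant then performs the copies.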
struct CollectivePermuteParticipantData : xla::ParticipantData {
  CollectivePermuteParticipantData(const xla::RendezvousKey& rendezvous_key_p,
                                   xla::int64 device_ordinal_p,
                                   se::Stream* stream_p)
      : ParticipantData(rendezvous_key_p),
        device_ordinal(device_ordinal_p),
        stream(stream_p) {}

  xla::int64 device_ordinal;
  se::Stream* stream;
  int replica_id;
  se::DeviceMemoryBase source_data;
  se::DeviceMemoryBase destination_data;
  xla::int64 byte_size;
  std::vector<int> replica_ids_to_copy_to;

  std::string ToString() const override {
    return absl::StrFormat(
        "CollectivePermuteParticipantData{replica_id=%d, "
        "source_data=%p, destination_data=%p, byte_size=%d, "
        "replica_ids_to_copy_to=[%s], device_ordinal=%d, stream=%p}",
        replica_id, source_data.opaque(), destination_data.opaque(), byte_size,
        absl::StrJoin(replica_ids_to_copy_to, ", "), device_ordinal, stream);
  }
};

struct AllToAllParticipantData : xla::ParticipantData {
  AllToAllParticipantData(const xla::RendezvousKey& rendezvous_key_p,
                          xla::int64 device_ordinal_p, se::Stream* stream_p)
      : ParticipantData(rendezvous_key_p),
        device_ordinal(device_ordinal_p),
        stream(stream_p) {}

  xla::int64 device_ordinal;
  se::Stream* stream;
  std::vector<se::DeviceMemoryBase> source_buffers;
  std::vector<se::DeviceMemoryBase> destination_buffers;
  int replica_id;

  // Replica ids participating in AllToAll; concatenation happens in the order
  // of appearance.
  std::vector<int> replica_ids_to_copy_to;

  std::string ToString() const override {
    auto addr_formatter = [](std::string* out,
                             const se::DeviceMemoryBase& mem) {
      absl::StrAppend(out, absl::StrFormat("%p", mem.opaque()));
    };
    return absl::StrFormat(
        "AllToAllParticipantData{replica_id=%d, "
        "replica_ids_to_copy_to=[%s], source_buffers=[%s], "
        "destination_buffers=[%s], device_ordinal=%d, stream=%p}",
        replica_id, absl::StrJoin(replica_ids_to_copy_to, ", "),
        absl::StrJoin(source_buffers, ", ", addr_formatter),
        absl::StrJoin(destination_buffers, ", ", addr_formatter),
        device_ordinal, stream);
  }
};

// Decodes a Shape protobuf that was serialized into an LLVM global constant,
// returning the validated xla::Shape.
xla::StatusOr<xla::Shape> DecodeSelfDescribingShapeConstant(
    const void* shape_ptr, xla::int32 size_bytes) {
  xla::ShapeProto shape_proto;
  if (!shape_proto.ParseFromArray(shape_ptr, size_bytes)) {
    return tensorflow::errors::Internal("Failed parsing the shape proto");
  }
  xla::Shape shape(shape_proto);
  auto status = xla::ShapeUtil::ValidateShape(shape);
  if (!status.ok()) {
    return status;
  }
  return std::move(shape);
}

tensorflow::string ShapeString(const void* shape_ptr, xla::int32 shape_length) {
  xla::StatusOr<xla::Shape> shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
  if (shape.ok()) {
    return xla::ShapeUtil::HumanStringWithLayout(shape.ValueOrDie());
  }
  return "<invalid shape>";
}

// TODO(zhangqiaorjc): Prefer to make callers set and use device_ordinal
// directly since callers may not have a Stream*.
int GetDeviceOrdinal(const xla::ExecutableRunOptions* run_options) {
  if (!run_options) {
    return 0;
  } else if (run_options->device_ordinal() != -1) {
    return run_options->device_ordinal();
  }
  return run_options->stream()->parent()->device_ordinal();
}

}  // namespace

extern "C" {

TF_ATTRIBUTE_NO_SANITIZE_MEMORY int __xla_cpu_runtime_PrintfToStderr(
    const char* format, ...) {
  VLOG(3) << "__xla_cpu_runtime_PrintfToStderr " << format;
  va_list args;
  va_start(args, format);
  int result = vfprintf(stderr, format, args);
  va_end(args);
  return result;
}

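// Starts a profiler TraceMe activity named `name` and returns its activity
// id. The id is later passed to __xla_cpu_runtime_TracingEnd, which closes
// the activity.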
TF_ATTRIBUTE_NO_SANITIZE_MEMORY xla::int64 __xla_cpu_runtime_TracingStart(
    const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
    const char* name) {
  VLOG(3) << "TracingStart " << name;
  return tensorflow::profiler::TraceMe::ActivityStart(name);
}

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_TracingEnd(
    const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
    xla::int64 id) {
  VLOG(3) << "TracingEnd " << id;
  tensorflow::profiler::TraceMe::ActivityEnd(id);
}

}  // extern "C"

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void*
__xla_cpu_runtime_AcquireInfeedBufferForDequeue(
    const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
    const void* shape, xla::int32 shape_length) {
  int device_ordinal = GetDeviceOrdinal(run_options);

  VLOG(2) << "AcquireInfeedBufferForDequeue: "
          << ShapeString(shape, shape_length) << " on stream executor "
          << device_ordinal;

  xla::cpu::runtime::XfeedManager* xfeed =
      xla::cpu::runtime::GetXfeedManager(device_ordinal);
  // Wait until there's a buffer to dequeue.
  xla::cpu::runtime::XfeedBuffer* buffer =
      xfeed->infeed()->BlockingDequeueBuffer();
  CHECK_EQ(buffer->length(), buffer_length)
      << "XLA program infeed request buffer size " << buffer_length
      << " did not match the runtime's infeed buffer length "
      << buffer->length() << "; program reports desired shape: "
      << ShapeString(shape, shape_length);
  return buffer->data();
}

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
    const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
    void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) {
  int device_ordinal = GetDeviceOrdinal(run_options);

  VLOG(2) << "ReleaseInfeedBufferAfterDequeue: "
          << ShapeString(shape_ptr, shape_length) << " on stream executor "
          << device_ordinal;

  xla::cpu::runtime::XfeedManager* xfeed =
      xla::cpu::runtime::GetXfeedManager(device_ordinal);
  xla::StatusOr<xla::Shape> shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
  xfeed->infeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr,
                                        std::move(shape));
}

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void*
__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
    const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
    const void* shape_ptr, xla::int32 shape_length) {
  int device_ordinal = GetDeviceOrdinal(run_options);

  VLOG(2) << "AcquireOutfeedBufferForPopulation: "
          << ShapeString(shape_ptr, shape_length) << " on stream executor "
          << device_ordinal;

  xla::cpu::runtime::XfeedManager* xfeed =
      xla::cpu::runtime::GetXfeedManager(device_ordinal);
  // Wait until there's a buffer to dequeue.
  xla::cpu::runtime::XfeedBuffer* buffer =
      xfeed->outfeed()->BlockingDequeueBuffer();
  CHECK_EQ(buffer->length(), buffer_length)
      << "XLA program outfeed request buffer size " << buffer_length
      << " did not match the runtime's outfeed buffer length "
      << buffer->length() << "; program reports outfeed shape: "
      << ShapeString(shape_ptr, shape_length);
  return buffer->data();
}

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
    const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
    void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) {
  int device_ordinal = GetDeviceOrdinal(run_options);

  VLOG(2) << "ReleaseOutfeedBufferAfterPopulation: "
          << ShapeString(shape_ptr, shape_length) << " on stream executor "
          << device_ordinal;

  xla::cpu::runtime::XfeedManager* xfeed =
      xla::cpu::runtime::GetXfeedManager(device_ordinal);
  xla::StatusOr<xla::Shape> shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
  xfeed->outfeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr,
                                         std::move(shape));
}

namespace {

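// CPU rendezvous implementations for the collective ops. Every participating
// thread submits its data and blocks; the primary participant (the one for
// which InitializationBarrier() returns true) then performs the actual data
// movement or reduction on behalf of all participants.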
class CpuAllToAllRendezvous
    : public xla::Rendezvous<AllToAllParticipantData, std::nullptr_t> {
 public:
  explicit CpuAllToAllRendezvous(const xla::RendezvousKey& k)
      : xla::Rendezvous<AllToAllParticipantData, std::nullptr_t>(k) {}

 protected:
  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
      const AllToAllParticipantData& /*participant*/) override {
    bool is_primary = InitializationBarrier();

    if (is_primary) {
      tensorflow::mutex_lock lock(mu_);

      CHECK(!participants_.empty());
      CHECK(!participants_[0].source_buffers.empty());
      int expected_buffer_size = participants_[0].source_buffers[0].size();

      // Replica id -> position in participants_.
      absl::flat_hash_map<int, int> replica_id_map;

      for (int pos = 0; pos < participants_.size(); pos++) {
        const AllToAllParticipantData& p = participants_[pos];
        CHECK_EQ(p.source_buffers.size(), p.destination_buffers.size());
        CHECK_EQ(p.source_buffers.size(), participants_.size());
        for (int i = 0; i < p.source_buffers.size(); i++) {
          CHECK_EQ(p.destination_buffers[i].size(), expected_buffer_size);
          CHECK_EQ(p.source_buffers[i].size(), expected_buffer_size);
        }
        replica_id_map[p.replica_id] = pos;
      }

      const std::vector<int>& replica_ids_to_copy_to =
          participants_[0].replica_ids_to_copy_to;

      // Replica id -> rank.
      absl::flat_hash_map<int, int> replica_ranks;
      for (int rank = 0; rank < replica_ids_to_copy_to.size(); ++rank) {
        int replica_id = replica_ids_to_copy_to[rank];
        replica_ranks[replica_id] = rank;
      }

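      // Each sender's i-th source buffer is copied into the destination
      // buffer at index `rank` (the sender's position in
      // replica_ids_to_copy_to) of the receiver whose replica id is
      // replica_ids_to_copy_to[i].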
      for (const AllToAllParticipantData& sender : participants_) {
        VLOG(3) << "Processing AllToAll participant: " << sender.ToString();

        int rank = xla::FindOrDie(replica_ranks, sender.replica_id);

        for (int i = 0; i < participants_.size(); ++i) {
          int replica_id = replica_ids_to_copy_to[i];
          int participant_num = xla::FindOrDie(replica_id_map, replica_id);
          AllToAllParticipantData& receiver = participants_[participant_num];

          std::memcpy(receiver.destination_buffers[rank].opaque(),
                      sender.source_buffers[i].opaque(), expected_buffer_size);
        }
      }
    }
    return nullptr;
  }
};


class CpuCollectivePermuteRendezvous
    : public xla::Rendezvous<CollectivePermuteParticipantData, std::nullptr_t> {
 public:
  explicit CpuCollectivePermuteRendezvous(const xla::RendezvousKey& k)
      : xla::Rendezvous<CollectivePermuteParticipantData, std::nullptr_t>(k) {}

 protected:
  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
      const CollectivePermuteParticipantData& /*participant*/) override {
    bool primary = InitializationBarrier();

    // Perform all copies from the primary thread.
    if (primary) {
      tensorflow::mutex_lock lock(mu_);

      std::map<int, int> replica_idx_to_participant_idx;
      for (int p_idx = 0; p_idx < participants_.size(); p_idx++) {
        replica_idx_to_participant_idx[participants_[p_idx].replica_id] = p_idx;
      }

      for (auto& p : participants_) {
        for (int dest_replica : p.replica_ids_to_copy_to) {
          auto& dest_p = participants_[xla::FindOrDie(
              replica_idx_to_participant_idx, dest_replica)];
          std::memcpy(dest_p.destination_data.opaque(), p.source_data.opaque(),
                      p.byte_size);

          // Each destination replica may be copied into only once, so erase
          // it from the map once it has received data.
          replica_idx_to_participant_idx.erase(dest_replica);
        }
      }

      // Replicas that were not the target of any copy get zero-filled output.
      for (auto& replica_p : replica_idx_to_participant_idx) {
        auto& p = participants_[replica_p.second];
        std::memset(p.destination_data.opaque(), 0, p.byte_size);
      }
    }
    return nullptr;
  }
};

class CpuAllReduceRendezvous
    : public xla::Rendezvous<xla::AllReduceParticipantData, std::nullptr_t> {
 public:
  explicit CpuAllReduceRendezvous(const xla::RendezvousKey& k)
      : xla::Rendezvous<xla::AllReduceParticipantData, std::nullptr_t>(k) {}

 protected:
  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
      const xla::AllReduceParticipantData& participant) override {
    xla::PrimitiveType datatype = participant.buffers.front().primitive_type;
    bool primary = InitializationBarrier();

    if (primary) {
      switch (datatype) {
        case xla::S8:
          DoAllReduce<xla::S8>(participant);
          break;
        case xla::PRED:
        case xla::U8:
          DoAllReduce<xla::U8>(participant);
          break;
        case xla::S32:
          DoAllReduce<xla::S32>(participant);
          break;
        case xla::U32:
          DoAllReduce<xla::U32>(participant);
          break;
        case xla::S64:
          DoAllReduce<xla::S64>(participant);
          break;
        case xla::U64:
          DoAllReduce<xla::U64>(participant);
          break;
        case xla::F16:
          DoAllReduce<xla::F16>(participant);
          break;
        case xla::F32:
          DoAllReduce<xla::F32>(participant);
          break;
        case xla::F64:
          DoAllReduce<xla::F64>(participant);
          break;
        default:
          LOG(FATAL) << "Unexpected datatype";
      }
    }
    return nullptr;
  }

 private:
  template <xla::PrimitiveType PT>
  void DoAllReduce(xla::AllReduceParticipantData participant) {
    using T = typename xla::primitive_util::PrimitiveTypeToNative<PT>::type;
    tensorflow::mutex_lock lock(mu_);
    CHECK(!participants_.empty());
    xla::ReductionKind reduction_kind = participant.reduction_kind;
    for (const auto& p : participants_) {
      CHECK(p.reduction_kind == reduction_kind);
    }
    int num_participants = participants_.size();

    // participant_idx -> buffer_idx -> buffer.
    std::vector<std::vector<absl::Span<T>>> input_buffers;
    std::vector<std::vector<absl::Span<T>>> output_buffers;
    input_buffers.reserve(num_participants);
    output_buffers.reserve(num_participants);
    const xla::AllReduceParticipantData& first_participant =
        participants_.front();

    int buffers_per_participant = first_participant.buffers.size();
    for (xla::AllReduceParticipantData& p : participants_) {
      CHECK_EQ(p.buffers.size(), buffers_per_participant);

      input_buffers.emplace_back();
      output_buffers.emplace_back();
      std::vector<absl::Span<T>>& participant_input_buffers =
          input_buffers.back();
      std::vector<absl::Span<T>>& participant_output_buffers =
          output_buffers.back();
      participant_input_buffers.reserve(p.buffers.size());
      participant_output_buffers.reserve(p.buffers.size());

      for (int buffer_idx = 0; buffer_idx < buffers_per_participant;
           buffer_idx++) {
        auto& participant_buffer = p.buffers[buffer_idx];
        participant_input_buffers.emplace_back(
            static_cast<T*>(participant_buffer.source_data.opaque()),
            participant_buffer.element_count);
        participant_output_buffers.emplace_back(
            static_cast<T*>(participant_buffer.destination_data.opaque()),
            participant_buffer.element_count);
        CHECK_EQ(participant_buffer.element_count,
                 first_participant.buffers[buffer_idx].element_count);
      }
    }

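    // Element-wise reduction: for every element position, fold together the
    // corresponding element from each participant's input buffer and write
    // the reduced value into every participant's output buffer.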
    for (int buffer_idx = 0; buffer_idx < buffers_per_participant;
         buffer_idx++) {
      int element_count = first_participant.buffers[buffer_idx].element_count;
      for (int idx = 0; idx < element_count; idx++) {
        T out = GetInitialValue<T>(reduction_kind);
        for (int participant_idx = 0; participant_idx < participants_.size();
             participant_idx++) {
          out = PerformReductionStep<T>(
              reduction_kind, out,
              input_buffers[participant_idx][buffer_idx][idx]);
        }
        for (int participant_idx = 0; participant_idx < participants_.size();
             participant_idx++) {
          output_buffers[participant_idx][buffer_idx][idx] = out;
        }
      }
    }
  }

  template <typename T>
  T GetInitialValue(xla::ReductionKind reduction_kind) {
    switch (reduction_kind) {
      case xla::ReductionKind::SUM:
        return static_cast<T>(0);
      case xla::ReductionKind::PRODUCT:
        return static_cast<T>(1);
      case xla::ReductionKind::MIN:
        return std::numeric_limits<T>::max();
      case xla::ReductionKind::MAX:
        // lowest() rather than min() so the identity is also correct for
        // floating-point types (min() is the smallest positive value there).
        return std::numeric_limits<T>::lowest();
    }
  }

  template <typename T>
  T PerformReductionStep(xla::ReductionKind reduction_kind, T a, T b) {
    switch (reduction_kind) {
      case xla::ReductionKind::SUM:
        return a + b;
      case xla::ReductionKind::PRODUCT:
        return a * b;
      case xla::ReductionKind::MIN:
        return std::min(a, b);
      case xla::ReductionKind::MAX:
        return std::max(a, b);
    }
  }
};

xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>&
GlobalAllReduceRendezvousMap() {
  static auto& m =
      *new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>;
  return m;
}

xla::RefcountingHashMap<xla::RendezvousKey, CpuCollectivePermuteRendezvous>&
GlobalCollectivePermuteRendezvousMap() {
  static auto& m = *new xla::RefcountingHashMap<xla::RendezvousKey,
                                                CpuCollectivePermuteRendezvous>;
  return m;
}

xla::RefcountingHashMap<xla::RendezvousKey, CpuAllToAllRendezvous>&
GlobalAllToAllRendezvousMap() {
  static auto& m =
      *new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllToAllRendezvous>;
  return m;
}

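// Builds the rendezvous key identifying one dynamic instance of a collective
// op: the run id, the set of participating devices derived from the replica
// groups, the op id, and whether the op is cross-module (channel id present)
// or cross-replica.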
xla::RendezvousKey GetRendezvousKey(
    const xla::ExecutableRunOptions* run_options,
    std::vector<xla::ReplicaGroup> group, xla::int32 channel_id_present,
    xla::int64 op_id) {
  const xla::DeviceAssignment& device_assignment =
      *run_options->device_assignment();
  int device_ordinal = GetDeviceOrdinal(run_options);
  xla::RendezvousKey::CollectiveOpKind op_kind =
      channel_id_present ? xla::RendezvousKey::kCrossModule
                         : xla::RendezvousKey::kCrossReplica;
  std::vector<xla::GlobalDeviceId> participating_devices =
      xla::GetParticipatingDevices(xla::GlobalDeviceId(device_ordinal),
                                   device_assignment, group,
                                   xla::CollectiveOpGroupMode::kCrossReplica)
          .ValueOrDie();
  int num_local_participants = participating_devices.size();
  return xla::RendezvousKey{run_options->run_id(),
                            std::move(participating_devices),
                            num_local_participants, op_kind, op_id};
}

}  // namespace

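// Runtime entry point for the all-to-all collective. Each replica passes
// num_buffers source buffers of buffer_size bytes; source buffer i is
// delivered to the i-th replica of this replica's group, and destination
// buffer j receives the buffer that the j-th group member addressed to this
// replica.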
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllToAll(
    const xla::ExecutableRunOptions* run_options, xla::int32 channel_id_present,
    xla::int64 op_id, const void* replica_groups_str,
    xla::int32 replica_groups_str_size, xla::int32 num_buffers,
    xla::int64 buffer_size, void** source_buffers, void** destination_buffers) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  xla::int32 replica_id =
      run_options->device_assignment()
          ->ReplicaIdForDevice(xla::GlobalDeviceId(device_ordinal))
          .ValueOrDie();
  absl::string_view replica_groups_serialized(
      static_cast<const char*>(replica_groups_str), replica_groups_str_size);
  std::vector<xla::ReplicaGroup> group =
      xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
  xla::RendezvousKey rendezvous_key =
      GetRendezvousKey(run_options, group, channel_id_present, op_id);

  AllToAllParticipantData participant(rendezvous_key, device_ordinal,
                                      run_options->stream());
  participant.replica_id = replica_id;
  participant.replica_ids_to_copy_to =
      xla::GetParticipatingIDs(
          replica_id, run_options->device_assignment()->replica_count(), group)
          .ValueOrDie();
  for (int i = 0; i < num_buffers; i++) {
    participant.source_buffers.emplace_back(source_buffers[i], buffer_size);
    participant.destination_buffers.emplace_back(destination_buffers[i],
                                                 buffer_size);
  }
  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
    return absl::make_unique<CpuAllToAllRendezvous>(k);
  };
  TF_CHECK_OK(CpuAllToAllRendezvous::SubmitParticipant(
                  [&] {
                    return GlobalAllToAllRendezvousMap().GetOrCreateIfAbsent(
                        rendezvous_key, make_cpu_rendezvous);
                  },
                  participant)
                  .status());
}

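// Runtime entry point for all-reduce. Reduces the corresponding elements of
// every participating replica's input buffers with the given reduction kind
// and writes the result to every replica's output buffers. The shape encodes
// either a single dense array (num_buffers == 1) or a tuple with one element
// per buffer.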
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
    const xla::ExecutableRunOptions* run_options,
    const void* replica_groups_str, xla::int32 replica_groups_str_size,
    xla::int32 channel_id_present, xla::int64 op_id, xla::int32 reduction_kind,
    const void* shape_ptr, xla::int32 shape_length, xla::int32 num_buffers,
    void** input_buffers, void** output_buffers) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  absl::string_view replica_groups_serialized(
      static_cast<const char*>(replica_groups_str), replica_groups_str_size);
  std::vector<xla::ReplicaGroup> group =
      xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
  xla::RendezvousKey rendezvous_key =
      GetRendezvousKey(run_options, group, channel_id_present, op_id);
  auto shape_str = ShapeString(shape_ptr, shape_length);
  VLOG(2) << "All-reduce input/output shape: " << shape_str;

  xla::Shape shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length).ValueOrDie();

  CHECK((num_buffers > 1 && shape.IsTuple()) ||
        (num_buffers == 1 && xla::LayoutUtil::IsDenseArray(shape)));

  xla::AllReduceParticipantData participant(rendezvous_key, device_ordinal,
                                            run_options->stream());
  participant.reduction_kind = static_cast<xla::ReductionKind>(reduction_kind);
  for (int i = 0; i < num_buffers; i++) {
    xla::Shape subshape = num_buffers == 1 ? shape : shape.tuple_shapes(i);
    xla::AllReduceParticipantData::Buffer buffer;
    buffer.element_count = xla::ShapeUtil::ElementsIn(subshape);
    buffer.primitive_type = subshape.element_type();
    buffer.source_data = se::DeviceMemoryBase(
        input_buffers[i], xla::ShapeUtil::ByteSizeOf(subshape));
    buffer.destination_data = se::DeviceMemoryBase(
        output_buffers[i], xla::ShapeUtil::ByteSizeOf(subshape));
    participant.buffers.push_back(buffer);
  }

  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
    return absl::make_unique<CpuAllReduceRendezvous>(k);
  };

  TF_CHECK_OK(CpuAllReduceRendezvous::SubmitParticipant(
                  [&] {
                    return GlobalAllReduceRendezvousMap().GetOrCreateIfAbsent(
                        rendezvous_key, make_cpu_rendezvous);
                  },
                  participant)
                  .status());
}

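// Runtime entry point for the replica-id op: writes the 32-bit replica id of
// the current device into output_buffer.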
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ReplicaId(
    const xla::ExecutableRunOptions* run_options, void* output_buffer) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  xla::int32 replica_id =
      run_options->device_assignment()
          ->ReplicaIdForDevice(xla::GlobalDeviceId(device_ordinal))
          .ValueOrDie();
  std::memcpy(output_buffer, &replica_id, sizeof(replica_id));
}

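// Runtime entry point for collective-permute. source_target_pairs is a
// comma-separated list of "source=target" replica pairs, e.g. "0=1,1=2,2=0":
// each listed source replica's input buffer is copied to the corresponding
// target replica's output buffer, and replicas that are not the target of any
// pair have their output zero-filled.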
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_CollectivePermute(
    const xla::ExecutableRunOptions* run_options, xla::int32 channel_id_present,
    xla::int64 op_id, xla::int32 byte_size, void* input_buffer,
    void* output_buffer, const void* source_target_pairs,
    xla::int32 source_target_pairs_size) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  absl::string_view source_target_pairs_serialized(
      static_cast<const char*>(source_target_pairs), source_target_pairs_size);
  auto pairs = absl::StrSplit(source_target_pairs_serialized, ',');
  xla::int32 replica_id =
      run_options->device_assignment()
          ->ReplicaIdForDevice(xla::GlobalDeviceId(device_ordinal))
          .ValueOrDie();
  std::vector<int> copy_to;
  for (auto& p : pairs) {
    std::vector<std::string> mapping = absl::StrSplit(p, '=');
    CHECK_EQ(mapping.size(), 2);
    int from = std::stoi(mapping[0]);
    int to = std::stoi(mapping[1]);
    if (from == replica_id) {
      copy_to.push_back(to);
    }
  }
  xla::RendezvousKey rendezvous_key =
      GetRendezvousKey(run_options, {}, channel_id_present, op_id);

  CollectivePermuteParticipantData participant(rendezvous_key, device_ordinal,
                                               run_options->stream());
  participant.replica_id = replica_id;
  participant.source_data = se::DeviceMemoryBase(input_buffer, byte_size);
  participant.destination_data = se::DeviceMemoryBase(output_buffer, byte_size);
  participant.replica_ids_to_copy_to = copy_to;
  participant.byte_size = byte_size;

  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
    return absl::make_unique<CpuCollectivePermuteRendezvous>(k);
  };
  TF_CHECK_OK(
      CpuCollectivePermuteRendezvous::SubmitParticipant(
          [&] {
            return GlobalCollectivePermuteRendezvousMap().GetOrCreateIfAbsent(
                rendezvous_key, make_cpu_rendezvous);
          },
          participant)
          .status());
}