/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <functional>
#include <limits>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace se = ::stream_executor;

namespace xla {
namespace cpu {
namespace runtime {

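// Returns the XfeedManager for |device_ordinal|, creating it on first use.
// The map and its mutex are intentionally leaked so they stay valid for the
// lifetime of the process.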
XfeedManager* GetXfeedManager(int device_ordinal) {
  static auto* managers = new absl::flat_hash_map<int, XfeedManager*>();
  static absl::Mutex* mutex = new absl::Mutex();

  absl::MutexLock lock(mutex);
  auto it = managers->find(device_ordinal);
  if (it == managers->end()) {
    it = managers->emplace(device_ordinal, new XfeedManager()).first;
  }
  return it->second;
}

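// Names under which the runtime entry points in this file are registered;
// JIT-compiled XLA code resolves its external calls against these symbols.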
extern const char* const kEigenMatMulF16SymbolName =
    "__xla_cpu_runtime_EigenMatMulF16";
extern const char* const kEigenMatMulF32SymbolName =
    "__xla_cpu_runtime_EigenMatMulF32";
extern const char* const kEigenMatMulF64SymbolName =
    "__xla_cpu_runtime_EigenMatMulF64";
extern const char* const kEigenMatMulC64SymbolName =
    "__xla_cpu_runtime_EigenMatMulC64";
extern const char* const kEigenMatMulC128SymbolName =
    "__xla_cpu_runtime_EigenMatMulC128";
extern const char* const kEigenMatMulS32SymbolName =
    "__xla_cpu_runtime_EigenMatMulS32";
extern const char* const kMKLConvF32SymbolName = "__xla_cpu_runtime_MKLConvF32";
extern const char* const kMKLMatMulF32SymbolName =
    "__xla_cpu_runtime_MKLMatMulF32";
extern const char* const kMKLMatMulF64SymbolName =
    "__xla_cpu_runtime_MKLMatMulF64";
extern const char* const kMKLSingleThreadedMatMulF32SymbolName =
    "__xla_cpu_runtime_MKLSingleThreadedMatMulF32";
extern const char* const kMKLSingleThreadedMatMulF64SymbolName =
    "__xla_cpu_runtime_MKLSingleThreadedMatMulF64";
extern const char* const kEigenConvF16SymbolName =
    "__xla_cpu_runtime_EigenConvF16";
extern const char* const kEigenConvF32SymbolName =
    "__xla_cpu_runtime_EigenConvF32";
extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft";
extern const char* const kEigenSingleThreadedFftSymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedFft";
extern const char* const kEigenSingleThreadedMatMulF16SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulF16";
extern const char* const kEigenSingleThreadedMatMulF32SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulF32";
extern const char* const kEigenSingleThreadedMatMulF64SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulF64";
extern const char* const kEigenSingleThreadedMatMulC64SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulC64";
extern const char* const kEigenSingleThreadedMatMulC128SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulC128";
extern const char* const kEigenSingleThreadedMatMulS32SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulS32";
extern const char* const kEigenSingleThreadedConvF16SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedConvF16";
extern const char* const kEigenSingleThreadedConvF32SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedConvF32";
extern const char* const kAcquireInfeedBufferForDequeueSymbolName =
    "__xla_cpu_runtime_AcquireInfeedBufferForDequeue";
extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName =
    "__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue";
extern const char* const kAcquireOutfeedBufferForPopulationSymbolName =
    "__xla_cpu_runtime_AcquireOutfeedBufferForPopulation";
extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName =
    "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation";
extern const char* const kParallelForkJoinSymbolName =
    "__xla_cpu_runtime_ParallelForkJoin";
extern const char* const kKeyValueSortSymbolName =
    "__xla_cpu_runtime_KeyValueSort";
extern const char* const kTopKF32SymbolName = "__xla_cpu_runtime_TopKF32";
extern const char* const kTracingStartSymbolName =
    "__xla_cpu_runtime_TracingStart";
extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd";
extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce";
extern const char* const kAllToAllSymbolName = "__xla_cpu_runtime_AllToAll";
extern const char* const kCollectivePermuteSymbolName =
    "__xla_cpu_runtime_CollectivePermute";
extern const char* const kReplicaIdSymbolName = "__xla_cpu_runtime_ReplicaId";

}  // namespace runtime
}  // namespace cpu
}  // namespace xla

namespace {

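// Participant state for a CollectivePermute rendezvous: one source/destination
// buffer pair plus the replica ids this participant's data is copied to.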
struct CollectivePermuteParticipantData : xla::ParticipantData {
  CollectivePermuteParticipantData(const xla::RendezvousKey& rendezvous_key_p,
                                   xla::int64 device_ordinal_p,
                                   se::Stream* stream_p)
      : ParticipantData(rendezvous_key_p, device_ordinal_p, stream_p) {}

  int replica_id;
  se::DeviceMemoryBase source_data;
  se::DeviceMemoryBase destination_data;
  xla::int64 byte_size;
  std::vector<int> replica_ids_to_copy_to;

  std::string ToString() const override {
    return absl::StrFormat(
        "CollectivePermuteParticipantData{replica_id=%d, "
        "source_data=%p, destination_data=%p, byte_size=%d, "
        "replica_ids_to_copy_to=[%s]}",
        replica_id, source_data.opaque(), destination_data.opaque(), byte_size,
        absl::StrJoin(replica_ids_to_copy_to, ", "));
  }
};

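// Participant state for an AllToAll rendezvous: one source and one destination
// buffer per participating replica.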
struct AllToAllParticipantData : xla::ParticipantData {
  AllToAllParticipantData(const xla::RendezvousKey& rendezvous_key_p,
                          xla::int64 device_ordinal_p, se::Stream* stream_p)
      : ParticipantData(rendezvous_key_p, device_ordinal_p, stream_p) {}

  std::vector<se::DeviceMemoryBase> source_buffers;
  std::vector<se::DeviceMemoryBase> destination_buffers;
  int replica_id;

  // Replica ids participating in AllToAll; concatenation happens in the order
  // of appearance.
  std::vector<int> replica_ids_to_copy_to;

  std::string ToString() const override {
    auto addr_formatter = [](std::string* out,
                             const se::DeviceMemoryBase& mem) {
      absl::StrAppend(out, absl::StrFormat("%p", mem.opaque()));
    };
    return absl::StrFormat(
        "AllToAllParticipantData{replica_id=%d, "
        "replica_ids_to_copy_to=[%s], source_buffers=[%s], "
        "destination_buffers=[%s]}",
        replica_id, absl::StrJoin(replica_ids_to_copy_to, ", "),
        absl::StrJoin(source_buffers, ", ", addr_formatter),
        absl::StrJoin(destination_buffers, ", ", addr_formatter));
  }
};

// Decodes a Shape protobuf that the compiler encoded into an LLVM global
// variable; this is the inverse of that encoding.
xla::StatusOr<xla::Shape> DecodeSelfDescribingShapeConstant(
    const void* shape_ptr, xla::int32 size_bytes) {
  xla::ShapeProto shape_proto;
  if (!shape_proto.ParseFromArray(shape_ptr, size_bytes)) {
    return tensorflow::errors::Internal("Failed parsing the shape proto");
  }
  xla::Shape shape(shape_proto);
  auto status = xla::ShapeUtil::ValidateShape(shape);
  if (!status.ok()) {
    return status;
  }
  return std::move(shape);
}

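// Returns a human-readable string for the encoded shape, or "<invalid shape>"
// if the encoding cannot be decoded.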
tensorflow::string ShapeString(const void* shape_ptr, xla::int32 shape_length) {
  xla::StatusOr<xla::Shape> shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
  if (shape.ok()) {
    return xla::ShapeUtil::HumanStringWithLayout(shape.ValueOrDie());
  }
  return "<invalid shape>";
}

}  // namespace

extern "C" {

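// Profiling hooks: JIT-compiled code brackets traced regions with
// TracingStart/TracingEnd, which map onto profiler TraceMe activities.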
TF_ATTRIBUTE_NO_SANITIZE_MEMORY xla::int64 __xla_cpu_runtime_TracingStart(
    const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
    const char* name) {
  VLOG(3) << "TracingStart " << name;
  return tensorflow::profiler::TraceMe::ActivityStart(name);
}

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_TracingEnd(
    const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
    xla::int64 id) {
  VLOG(3) << "TracingEnd " << id;
  tensorflow::profiler::TraceMe::ActivityEnd(id);
}

}  // extern "C"

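// Blocks until the device's infeed manager has a buffer available, then checks
// that the buffer's length matches what the compiled program expects before
// handing it out.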
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void*
__xla_cpu_runtime_AcquireInfeedBufferForDequeue(
    const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
    const void* shape, xla::int32 shape_length) {
  int device_ordinal =
      run_options ? run_options->stream()->parent()->device_ordinal() : 0;

  VLOG(2) << "AcquireInfeedBufferForDequeue: "
          << ShapeString(shape, shape_length) << " on stream executor "
          << device_ordinal;

  xla::cpu::runtime::XfeedManager* xfeed =
      xla::cpu::runtime::GetXfeedManager(device_ordinal);
  // Wait until there's a buffer to dequeue.
  xla::cpu::runtime::XfeedBuffer* buffer =
      xfeed->infeed()->BlockingDequeueBuffer();
  CHECK_EQ(buffer->length(), buffer_length)
      << "XLA program infeed request buffer size " << buffer_length
      << " did not match the runtime's infeed buffer length "
      << buffer->length() << "; program reports desired shape: "
      << ShapeString(shape, shape_length);
  return buffer->data();
}

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
    const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
    void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) {
  int device_ordinal =
      run_options ? run_options->stream()->parent()->device_ordinal() : 0;

  VLOG(2) << "ReleaseInfeedBufferAfterDequeue: "
          << ShapeString(shape_ptr, shape_length) << " on stream executor "
          << device_ordinal;

  xla::cpu::runtime::XfeedManager* xfeed =
      xla::cpu::runtime::GetXfeedManager(device_ordinal);
  xla::StatusOr<xla::Shape> shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
  xfeed->infeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr,
                                        std::move(shape));
}

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void*
__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
    const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
    const void* shape_ptr, xla::int32 shape_length) {
  int device_ordinal =
      run_options ? run_options->stream()->parent()->device_ordinal() : 0;

  VLOG(2) << "AcquireOutfeedBufferForPopulation: "
          << ShapeString(shape_ptr, shape_length) << " on stream executor "
          << device_ordinal;

  xla::cpu::runtime::XfeedManager* xfeed =
      xla::cpu::runtime::GetXfeedManager(device_ordinal);
  // Wait until there's a buffer to dequeue.
  xla::cpu::runtime::XfeedBuffer* buffer =
      xfeed->outfeed()->BlockingDequeueBuffer();
  CHECK_EQ(buffer->length(), buffer_length)
      << "XLA program outfeed request buffer size " << buffer_length
      << " did not match the runtime's outfeed buffer length "
      << buffer->length() << "; program reports outfed shape: "
      << ShapeString(shape_ptr, shape_length);
  return buffer->data();
}

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
    const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
    void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) {
  int device_ordinal =
      run_options ? run_options->stream()->parent()->device_ordinal() : 0;

  VLOG(2) << "ReleaseOutfeedBufferAfterPopulation: "
          << ShapeString(shape_ptr, shape_length) << " on stream executor "
          << device_ordinal;

  xla::cpu::runtime::XfeedManager* xfeed =
      xla::cpu::runtime::GetXfeedManager(device_ordinal);
  xla::StatusOr<xla::Shape> shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
  xfeed->outfeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr,
                                         std::move(shape));
}

namespace {

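// Rendezvous for AllToAll. Once every local participant has arrived, the
// primary thread performs all copies: each sender's i-th source buffer is
// copied into the i-th receiver's destination buffer at the sender's rank.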
class CpuAllToAllRendezvous
    : public xla::Rendezvous<AllToAllParticipantData, std::nullptr_t> {
 public:
  explicit CpuAllToAllRendezvous(const xla::RendezvousKey& k)
      : xla::Rendezvous<AllToAllParticipantData, std::nullptr_t>(k) {}

 protected:
  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
      const AllToAllParticipantData& /*participant*/) override {
    bool is_primary = InitializationBarrier();

    if (is_primary) {
      tensorflow::mutex_lock lock(mu_);

      CHECK(!participants_.empty());
      CHECK(!participants_[0].source_buffers.empty());
      int expected_buffer_size = participants_[0].source_buffers[0].size();

      // Replica id -> position in participants_.
      absl::flat_hash_map<int, int> replica_id_map;

      for (int pos = 0; pos < participants_.size(); pos++) {
        const AllToAllParticipantData& p = participants_[pos];
        CHECK_EQ(p.source_buffers.size(), p.destination_buffers.size());
        CHECK_EQ(p.source_buffers.size(), participants_.size());
        for (int i = 0; i < p.source_buffers.size(); i++) {
          CHECK_EQ(p.destination_buffers[i].size(), expected_buffer_size);
          CHECK_EQ(p.source_buffers[i].size(), expected_buffer_size);
        }
        replica_id_map[p.replica_id] = pos;
      }

      const std::vector<int>& replica_ids_to_copy_to =
          participants_[0].replica_ids_to_copy_to;

      // Replica id -> rank
      absl::flat_hash_map<int, int> replica_ranks;
      for (int rank = 0; rank < replica_ids_to_copy_to.size(); ++rank) {
        int replica_id = replica_ids_to_copy_to[rank];
        replica_ranks[replica_id] = rank;
      }

      for (const AllToAllParticipantData& sender : participants_) {
        VLOG(3) << "Processing AllToAll participant: " << sender.ToString();

        int rank = xla::FindOrDie(replica_ranks, sender.replica_id);

        for (int i = 0; i < participants_.size(); ++i) {
          int replica_id = replica_ids_to_copy_to[i];
          int participant_num = xla::FindOrDie(replica_id_map, replica_id);
          AllToAllParticipantData& receiver = participants_[participant_num];

          std::memcpy(receiver.destination_buffers[rank].opaque(),
                      sender.source_buffers[i].opaque(), expected_buffer_size);
        }
      }
    }
    return nullptr;
  }
};

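// Rendezvous for CollectivePermute. The primary thread copies each
// participant's source buffer to every destination replica it names, then
// zero-fills the destination buffers of replicas nobody copied into.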
class CpuCollectivePermuteRendezvous
    : public xla::Rendezvous<CollectivePermuteParticipantData, std::nullptr_t> {
 public:
  explicit CpuCollectivePermuteRendezvous(const xla::RendezvousKey& k)
      : xla::Rendezvous<CollectivePermuteParticipantData, std::nullptr_t>(k) {}

 protected:
  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
      const CollectivePermuteParticipantData& /*participant*/) override {
    bool primary = InitializationBarrier();

    // Perform all copies from the primary thread.
    if (primary) {
      tensorflow::mutex_lock lock(mu_);

      std::map<int, int> replica_idx_to_participant_idx;
      for (int p_idx = 0; p_idx < participants_.size(); p_idx++) {
        replica_idx_to_participant_idx[participants_[p_idx].replica_id] = p_idx;
      }

      for (auto& p : participants_) {
        for (int dest_replica : p.replica_ids_to_copy_to) {
          auto& dest_p = participants_[xla::FindOrDie(
              replica_idx_to_participant_idx, dest_replica)];
          std::memcpy(dest_p.destination_data.opaque(), p.source_data.opaque(),
                      p.byte_size);

          // Each replica may be copied into only once.
          replica_idx_to_participant_idx.erase(dest_replica);
        }
      }

      // Zero out untouched participants.
      for (auto& replica_p : replica_idx_to_participant_idx) {
        auto& p = participants_[replica_p.second];
        std::memset(p.destination_data.opaque(), 0, p.byte_size);
      }
    }
    return nullptr;
  }
};

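// Rendezvous for AllReduce. The primary thread reduces element-wise across all
// participants' input buffers and writes the result into every participant's
// output buffers.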
class CpuAllReduceRendezvous
    : public xla::Rendezvous<xla::AllReduceParticipantData, std::nullptr_t> {
 public:
  explicit CpuAllReduceRendezvous(const xla::RendezvousKey& k)
      : xla::Rendezvous<xla::AllReduceParticipantData, std::nullptr_t>(k) {}

 protected:
  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
      const xla::AllReduceParticipantData& participant) override {
    xla::PrimitiveType datatype = participant.buffers.front().primitive_type;
    bool primary = InitializationBarrier();

    if (primary) {
      switch (datatype) {
        case xla::S8:
          DoAllReduce<xla::S8>(participant);
          break;
        case xla::PRED:
        case xla::U8:
          DoAllReduce<xla::U8>(participant);
          break;
        case xla::S32:
          DoAllReduce<xla::S32>(participant);
          break;
        case xla::U32:
          DoAllReduce<xla::U32>(participant);
          break;
        case xla::S64:
          DoAllReduce<xla::S64>(participant);
          break;
        case xla::U64:
          DoAllReduce<xla::U64>(participant);
          break;
        case xla::F16:
          DoAllReduce<xla::F16>(participant);
          break;
        case xla::F32:
          DoAllReduce<xla::F32>(participant);
          break;
        case xla::F64:
          DoAllReduce<xla::F64>(participant);
          break;
        default:
          LOG(FATAL) << "Unexpected datatype";
      }
    }
    return nullptr;
  }

 private:
  template <xla::PrimitiveType PT>
  void DoAllReduce(xla::AllReduceParticipantData participant) {
    using T = typename xla::primitive_util::PrimitiveTypeToNative<PT>::type;
    tensorflow::mutex_lock lock(mu_);
    CHECK(!participants_.empty());
    xla::ReductionKind reduction_kind = participant.reduction_kind;
    for (const auto& p : participants_) {
      CHECK(p.reduction_kind == reduction_kind);
    }
    int num_participants = participants_.size();

    // participant_idx -> buffer_idx -> buffer.
    std::vector<std::vector<absl::Span<T>>> input_buffers;
    std::vector<std::vector<absl::Span<T>>> output_buffers;
    input_buffers.reserve(num_participants);
    output_buffers.reserve(num_participants);
    const xla::AllReduceParticipantData& first_participant =
        participants_.front();

    int buffers_per_participant = first_participant.buffers.size();
    for (xla::AllReduceParticipantData& p : participants_) {
      CHECK_EQ(p.buffers.size(), buffers_per_participant);

      input_buffers.emplace_back();
      output_buffers.emplace_back();
      std::vector<absl::Span<T>>& participant_input_buffers =
          input_buffers.back();
      std::vector<absl::Span<T>>& participant_output_buffers =
          output_buffers.back();
      participant_input_buffers.reserve(p.buffers.size());
      participant_output_buffers.reserve(p.buffers.size());

      for (int buffer_idx = 0; buffer_idx < buffers_per_participant;
           buffer_idx++) {
        auto& participant_buffer = p.buffers[buffer_idx];
        participant_input_buffers.emplace_back(
            static_cast<T*>(participant_buffer.source_data.opaque()),
            participant_buffer.element_count);
        participant_output_buffers.emplace_back(
            static_cast<T*>(participant_buffer.destination_data.opaque()),
            participant_buffer.element_count);
        CHECK_EQ(participant_buffer.element_count,
                 first_participant.buffers[buffer_idx].element_count);
      }
    }

    for (int buffer_idx = 0; buffer_idx < buffers_per_participant;
         buffer_idx++) {
      int element_count = first_participant.buffers[buffer_idx].element_count;
      for (int idx = 0; idx < element_count; idx++) {
        T out = GetInitialValue<T>(reduction_kind);
        for (int participant_idx = 0; participant_idx < participants_.size();
             participant_idx++) {
          out = PerformReductionStep<T>(
              reduction_kind, out,
              input_buffers[participant_idx][buffer_idx][idx]);
        }
        for (int participant_idx = 0; participant_idx < participants_.size();
             participant_idx++) {
          output_buffers[participant_idx][buffer_idx][idx] = out;
        }
      }
    }
  }

  template <typename T>
  T GetInitialValue(xla::ReductionKind reduction_kind) {
    switch (reduction_kind) {
      case xla::ReductionKind::SUM:
        return static_cast<T>(0);
      case xla::ReductionKind::PRODUCT:
        return static_cast<T>(1);
      case xla::ReductionKind::MIN:
        return std::numeric_limits<T>::max();
      case xla::ReductionKind::MAX:
        // lowest() rather than min(): for floating-point types min() is the
        // smallest positive value, not the identity for MAX.
        return std::numeric_limits<T>::lowest();
    }
  }

  template <typename T>
  T PerformReductionStep(xla::ReductionKind reduction_kind, T a, T b) {
    switch (reduction_kind) {
      case xla::ReductionKind::SUM:
        return a + b;
      case xla::ReductionKind::PRODUCT:
        return a * b;
      case xla::ReductionKind::MIN:
        return std::min(a, b);
      case xla::ReductionKind::MAX:
        return std::max(a, b);
    }
  }
};

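// Process-wide maps from rendezvous key to in-flight rendezvous. Entries are
// reference-counted so they can be reclaimed once every participant is done
// with them.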
xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>&
GlobalAllReduceRendezvousMap() {
  static auto& m =
      *new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>;
  return m;
}

xla::RefcountingHashMap<xla::RendezvousKey, CpuCollectivePermuteRendezvous>&
GlobalCollectivePermuteRendezvousMap() {
  static auto& m = *new xla::RefcountingHashMap<xla::RendezvousKey,
                                                CpuCollectivePermuteRendezvous>;
  return m;
}

xla::RefcountingHashMap<xla::RendezvousKey, CpuAllToAllRendezvous>&
GlobalAllToAllRendezvousMap() {
  static auto& m =
      *new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllToAllRendezvous>;
  return m;
}

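// Returns the device ordinal for this run, preferring the stream's executor
// when a stream was provided.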
int GetDeviceOrdinal(const xla::ExecutableRunOptions* run_options) {
  if (run_options->stream()) {
    return run_options->stream()->parent()->device_ordinal();
  } else {
    return run_options->device_ordinal();
  }
}

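// Builds the rendezvous key that identifies this collective operation: the run
// id, the participating devices, the local participant count, whether the op
// is cross-module or cross-replica, and the op id.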
xla::RendezvousKey GetRendezvousKey(
    const xla::ExecutableRunOptions* run_options,
    std::vector<xla::ReplicaGroup> group, xla::int32 channel_id_present,
    xla::int64 op_id) {
  const xla::DeviceAssignment& device_assignment =
      *run_options->device_assignment();
  int device_ordinal = GetDeviceOrdinal(run_options);
  xla::RendezvousKey::CollectiveOpKind op_kind =
      channel_id_present ? xla::RendezvousKey::kCrossModule
                         : xla::RendezvousKey::kCrossReplica;
  std::vector<xla::GlobalDeviceId> participating_devices =
      xla::GetParticipatingDevices(xla::GlobalDeviceId(device_ordinal),
                                   device_assignment,
                                   device_assignment.replica_count(), group)
          .ValueOrDie();
  int num_local_participants = participating_devices.size();
  return xla::RendezvousKey{run_options->run_id(),
                            std::move(participating_devices),
                            num_local_participants, op_kind, op_id};
}

}  // namespace

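// Runtime entry point for AllToAll. Each participant contributes num_buffers
// source buffers of buffer_size bytes and receives one buffer from every
// replica in its group; replica_groups_str is the textual replica-groups
// attribute produced by the compiler.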
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllToAll(
    const xla::ExecutableRunOptions* run_options, xla::int32 channel_id_present,
    xla::int64 op_id, const void* replica_groups_str,
    xla::int32 replica_groups_str_size, xla::int32 num_buffers,
    xla::int64 buffer_size, void** source_buffers, void** destination_buffers) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  xla::int32 replica_id =
      run_options->device_assignment()
          ->ReplicaIdForDevice(xla::GlobalDeviceId(device_ordinal))
          .ValueOrDie();
  absl::string_view replica_groups_serialized(
      static_cast<const char*>(replica_groups_str), replica_groups_str_size);
  std::vector<xla::ReplicaGroup> group =
      xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
  xla::RendezvousKey rendezvous_key =
      GetRendezvousKey(run_options, group, channel_id_present, op_id);

  AllToAllParticipantData participant(rendezvous_key, device_ordinal,
                                      run_options->stream());
  participant.replica_id = replica_id;
  participant.replica_ids_to_copy_to =
      xla::GetParticipatingReplicas(
          replica_id, run_options->device_assignment()->replica_count(), group)
          .ValueOrDie();
  for (int i = 0; i < num_buffers; i++) {
    participant.source_buffers.emplace_back(source_buffers[i], buffer_size);
    participant.destination_buffers.emplace_back(destination_buffers[i],
                                                 buffer_size);
  }
  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
    return absl::make_unique<CpuAllToAllRendezvous>(k);
  };
  TF_CHECK_OK(CpuAllToAllRendezvous::SubmitParticipant(
                  [&] {
                    return GlobalAllToAllRendezvousMap().GetOrCreateIfAbsent(
                        rendezvous_key, make_cpu_rendezvous);
                  },
                  participant)
                  .status());
}

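// Runtime entry point for AllReduce over num_buffers dense buffers (the shape
// is a tuple when num_buffers > 1), using the reduction kind chosen by the
// compiler.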
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
    const xla::ExecutableRunOptions* run_options,
    const void* replica_groups_str, xla::int32 replica_groups_str_size,
    xla::int32 channel_id_present, xla::int64 op_id, xla::int32 reduction_kind,
    const void* shape_ptr, xla::int32 shape_length, xla::int32 num_buffers,
    void** input_buffers, void** output_buffers) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  absl::string_view replica_groups_serialized(
      static_cast<const char*>(replica_groups_str), replica_groups_str_size);
  std::vector<xla::ReplicaGroup> group =
      xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
  xla::RendezvousKey rendezvous_key =
      GetRendezvousKey(run_options, group, channel_id_present, op_id);
  auto shape_str = ShapeString(shape_ptr, shape_length);
  VLOG(2) << "All-reduce input/output shape : " << shape_str;

  xla::Shape shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length).ValueOrDie();

  CHECK((num_buffers > 1 && shape.IsTuple()) ||
        (num_buffers == 1 && xla::LayoutUtil::IsDenseArray(shape)));

  xla::AllReduceParticipantData participant(rendezvous_key, device_ordinal,
                                            run_options->stream());
  participant.reduction_kind = static_cast<xla::ReductionKind>(reduction_kind);
  for (int i = 0; i < num_buffers; i++) {
    xla::Shape subshape = num_buffers == 1 ? shape : shape.tuple_shapes(i);
    xla::AllReduceParticipantData::Buffer buffer;
    buffer.element_count = xla::ShapeUtil::ElementsIn(subshape);
    buffer.primitive_type = subshape.element_type();
    buffer.source_data = se::DeviceMemoryBase(
        input_buffers[i], xla::ShapeUtil::ByteSizeOf(subshape));
    buffer.destination_data = se::DeviceMemoryBase(
        output_buffers[i], xla::ShapeUtil::ByteSizeOf(subshape));
    participant.buffers.push_back(buffer);
  }

  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
    return absl::make_unique<CpuAllReduceRendezvous>(k);
  };

  TF_CHECK_OK(CpuAllReduceRendezvous::SubmitParticipant(
                  [&] {
                    return GlobalAllReduceRendezvousMap().GetOrCreateIfAbsent(
                        rendezvous_key, make_cpu_rendezvous);
                  },
                  participant)
                  .status());
}

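// Runtime entry point for ReplicaId: writes the calling replica's id into
// output_buffer as a 32-bit integer.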
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ReplicaId(
    const xla::ExecutableRunOptions* run_options, void* output_buffer) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  xla::int32 replica_id =
      run_options->device_assignment()
          ->ReplicaIdForDevice(xla::GlobalDeviceId(device_ordinal))
          .ValueOrDie();
  std::memcpy(output_buffer, &replica_id, 4);
}

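// Runtime entry point for CollectivePermute. source_target_pairs is a
// comma-separated list of "source=target" replica pairs; this participant
// copies its input to every target whose source is its own replica id.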
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_CollectivePermute(
    const xla::ExecutableRunOptions* run_options, xla::int32 channel_id_present,
    xla::int64 op_id, xla::int32 byte_size, void* input_buffer,
    void* output_buffer, const void* source_target_pairs,
    xla::int32 source_target_pairs_size) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  absl::string_view source_target_pairs_serialized(
      static_cast<const char*>(source_target_pairs), source_target_pairs_size);
  auto pairs = absl::StrSplit(source_target_pairs_serialized, ',');
  xla::int32 replica_id =
      run_options->device_assignment()
          ->ReplicaIdForDevice(xla::GlobalDeviceId(device_ordinal))
          .ValueOrDie();
  std::vector<int> copy_to;
  for (auto& p : pairs) {
    std::vector<std::string> mapping = absl::StrSplit(p, '=');
    CHECK_EQ(mapping.size(), 2);
    int from = std::stoi(mapping[0]);
    int to = std::stoi(mapping[1]);
    if (from == replica_id) {
      copy_to.push_back(to);
    }
  }
  xla::RendezvousKey rendezvous_key =
      GetRendezvousKey(run_options, {}, channel_id_present, op_id);

  CollectivePermuteParticipantData participant(rendezvous_key, device_ordinal,
                                               run_options->stream());
  participant.replica_id = replica_id;
  participant.source_data = se::DeviceMemoryBase(input_buffer, byte_size);
  participant.destination_data = se::DeviceMemoryBase(output_buffer, byte_size);
  participant.replica_ids_to_copy_to = copy_to;
  participant.byte_size = byte_size;

  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
    return absl::make_unique<CpuCollectivePermuteRendezvous>(k);
  };
  TF_CHECK_OK(
      CpuCollectivePermuteRendezvous::SubmitParticipant(
          [&] {
            return GlobalCollectivePermuteRendezvousMap().GetOrCreateIfAbsent(
                rendezvous_key, make_cpu_rendezvous);
          },
          participant)
          .status());
}