/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"

#include <algorithm>
#include <cstdarg>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <functional>
#include <limits>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace se = ::stream_executor;

namespace xla {
namespace cpu {
namespace runtime {

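// Returns the XfeedManager for the given device ordinal, creating it on first
// use. The map and its mutex are heap-allocated and never freed, so they
// remain valid for the entire lifetime of the process.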
XfeedManager* GetXfeedManager(int device_ordinal) {
  static auto* managers = new absl::flat_hash_map<int, XfeedManager*>();
  static absl::Mutex* mutex = new absl::Mutex();

  absl::MutexLock lock(mutex);
  auto it = managers->find(device_ordinal);
  if (it == managers->end()) {
    it = managers->emplace(device_ordinal, new XfeedManager()).first;
  }
  return it->second;
}

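// Names under which the runtime functions below are resolved by JIT-compiled
// code. They all carry the kXlaCpuRuntimeSymbolNamePrefix prefix defined at
// the end of this list.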
extern const char* const kEigenMatMulF16SymbolName =
    "__xla_cpu_runtime_EigenMatMulF16";
extern const char* const kEigenMatMulF32SymbolName =
    "__xla_cpu_runtime_EigenMatMulF32";
extern const char* const kEigenMatMulF64SymbolName =
    "__xla_cpu_runtime_EigenMatMulF64";
extern const char* const kEigenMatMulC64SymbolName =
    "__xla_cpu_runtime_EigenMatMulC64";
extern const char* const kEigenMatMulC128SymbolName =
    "__xla_cpu_runtime_EigenMatMulC128";
extern const char* const kEigenMatMulS32SymbolName =
    "__xla_cpu_runtime_EigenMatMulS32";
extern const char* const kMKLConvF32SymbolName = "__xla_cpu_runtime_MKLConvF32";
extern const char* const kMKLMatMulF32SymbolName =
    "__xla_cpu_runtime_MKLMatMulF32";
extern const char* const kMKLMatMulF64SymbolName =
    "__xla_cpu_runtime_MKLMatMulF64";
extern const char* const kMKLSingleThreadedMatMulF32SymbolName =
    "__xla_cpu_runtime_MKLSingleThreadedMatMulF32";
extern const char* const kMKLSingleThreadedMatMulF64SymbolName =
    "__xla_cpu_runtime_MKLSingleThreadedMatMulF64";
extern const char* const kEigenConvF16SymbolName =
    "__xla_cpu_runtime_EigenConvF16";
extern const char* const kEigenConvF32SymbolName =
    "__xla_cpu_runtime_EigenConvF32";
extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft";
extern const char* const kEigenSingleThreadedFftSymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedFft";
extern const char* const kEigenSingleThreadedMatMulF16SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulF16";
extern const char* const kEigenSingleThreadedMatMulF32SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulF32";
extern const char* const kEigenSingleThreadedMatMulF64SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulF64";
extern const char* const kEigenSingleThreadedMatMulC64SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulC64";
extern const char* const kEigenSingleThreadedMatMulC128SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulC128";
extern const char* const kEigenSingleThreadedMatMulS32SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulS32";
extern const char* const kEigenSingleThreadedConvF16SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedConvF16";
extern const char* const kEigenSingleThreadedConvF32SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedConvF32";
extern const char* const kAcquireInfeedBufferForDequeueSymbolName =
    "__xla_cpu_runtime_AcquireInfeedBufferForDequeue";
extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName =
    "__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue";
extern const char* const kAcquireOutfeedBufferForPopulationSymbolName =
    "__xla_cpu_runtime_AcquireOutfeedBufferForPopulation";
extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName =
    "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation";
extern const char* const kParallelForkJoinSymbolName =
    "__xla_cpu_runtime_ParallelForkJoin";
extern const char* const kPrintfToStderrSymbolName =
    "__xla_cpu_runtime_PrintfToStderr";
extern const char* const kKeyValueSortSymbolName =
    "__xla_cpu_runtime_KeyValueSort";
extern const char* const kTopKF32SymbolName = "__xla_cpu_runtime_TopKF32";
extern const char* const kTracingStartSymbolName =
    "__xla_cpu_runtime_TracingStart";
extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd";
extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce";
extern const char* const kAllToAllSymbolName = "__xla_cpu_runtime_AllToAll";
extern const char* const kCollectivePermuteSymbolName =
    "__xla_cpu_runtime_CollectivePermute";
extern const char* const kReplicaIdSymbolName = "__xla_cpu_runtime_ReplicaId";

}  // namespace runtime
}  // namespace cpu
}  // namespace xla

namespace {

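// Per-participant state handed to the CPU collective rendezvous classes
// defined further below. One instance is created for each replica taking part
// in a collective operation.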
struct CollectivePermuteParticipantData : xla::ParticipantData {
  CollectivePermuteParticipantData(const xla::RendezvousKey& rendezvous_key_p,
                                   xla::int64 device_ordinal_p,
                                   se::Stream* stream_p)
      : ParticipantData(rendezvous_key_p),
        device_ordinal(device_ordinal_p),
        stream(stream_p) {}

  xla::int64 device_ordinal;
  se::Stream* stream;
  int replica_id;
  se::DeviceMemoryBase source_data;
  se::DeviceMemoryBase destination_data;
  xla::int64 byte_size;
  std::vector<int> replica_ids_to_copy_to;

  std::string ToString() const override {
    return absl::StrFormat(
        "CollectivePermuteParticipantData{replica_id=%d, "
        "source_data=%p, destination_data=%p, byte_size=%d, "
        "replica_ids_to_copy_to=[%s], device_ordinal=%d, stream=%p}",
        replica_id, source_data.opaque(), destination_data.opaque(), byte_size,
        absl::StrJoin(replica_ids_to_copy_to, ", "), device_ordinal, stream);
  }
};

struct AllToAllParticipantData : xla::ParticipantData {
  AllToAllParticipantData(const xla::RendezvousKey& rendezvous_key_p,
                          xla::int64 device_ordinal_p, se::Stream* stream_p)
      : ParticipantData(rendezvous_key_p),
        device_ordinal(device_ordinal_p),
        stream(stream_p) {}

  xla::int64 device_ordinal;
  se::Stream* stream;
  std::vector<se::DeviceMemoryBase> source_buffers;
  std::vector<se::DeviceMemoryBase> destination_buffers;
  int replica_id;

  // Replica ids participating in AllToAll; concatenation happens in the order
  // of appearance.
  std::vector<int> replica_ids_to_copy_to;

  std::string ToString() const override {
    auto addr_formatter = [](std::string* out,
                             const se::DeviceMemoryBase& mem) {
      absl::StrAppend(out, absl::StrFormat("%p", mem.opaque()));
    };
    return absl::StrFormat(
        "AllToAllParticipantData{replica_id=%d, "
        "replica_ids_to_copy_to=[%s], source_buffers=[%s], "
        "destination_buffers=[%s], device_ordinal=%d, stream=%p}",
        replica_id, absl::StrJoin(replica_ids_to_copy_to, ", "),
        absl::StrJoin(source_buffers, ", ", addr_formatter),
        absl::StrJoin(destination_buffers, ", ", addr_formatter),
        device_ordinal, stream);
  }
};

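// Helpers for the self-describing shape arguments that compiled code passes to
// the runtime entry points below: a pointer to a serialized ShapeProto plus
// its size in bytes (the matching encoder lives on the compiler side, in the
// CPU backend).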
// Decodes a Shape protobuf that was serialized into an LLVM global constant.
xla::StatusOr<xla::Shape> DecodeSelfDescribingShapeConstant(
    const void* shape_ptr, xla::int32 size_bytes) {
  xla::ShapeProto shape_proto;
  if (!shape_proto.ParseFromArray(shape_ptr, size_bytes)) {
    return tensorflow::errors::Internal("Failed parsing the shape proto");
  }
  xla::Shape shape(shape_proto);
  auto status = xla::ShapeUtil::ValidateShape(shape);
  if (!status.ok()) {
    return status;
  }
  return std::move(shape);
}

tensorflow::string ShapeString(const void* shape_ptr, xla::int32 shape_length) {
  xla::StatusOr<xla::Shape> shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
  if (shape.ok()) {
    return xla::ShapeUtil::HumanStringWithLayout(shape.ValueOrDie());
  }
  return "<invalid shape>";
}

// TODO(zhangqiaorjc): Prefer to make callers set and use device_ordinal
// directly since callers may not have a Stream*.
int GetDeviceOrdinal(const xla::ExecutableRunOptions* run_options) {
  if (!run_options) {
    return 0;
  } else if (run_options->device_ordinal() != -1) {
    return run_options->device_ordinal();
  }
  return run_options->stream()->parent()->device_ordinal();
}

}  // namespace

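// The runtime entry points below must be resolvable by the unmangled symbol
// names listed above, so they are given C linkage.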
extern "C" {

TF_ATTRIBUTE_NO_SANITIZE_MEMORY int __xla_cpu_runtime_PrintfToStderr(
    const char* format, ...) {
  VLOG(3) << "__xla_cpu_runtime_PrintfToStderr " << format;
  va_list args;
  va_start(args, format);
  int result = vfprintf(stderr, format, args);
  va_end(args);
  return result;
}

TF_ATTRIBUTE_NO_SANITIZE_MEMORY xla::int64 __xla_cpu_runtime_TracingStart(
    const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
    const char* name) {
  VLOG(3) << "TracingStart " << name;
  return tensorflow::profiler::TraceMe::ActivityStart(name);
}

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_TracingEnd(
    const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
    xla::int64 id) {
  VLOG(3) << "TracingEnd " << id;
  tensorflow::profiler::TraceMe::ActivityEnd(id);
}

}  // extern "C"

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void*
__xla_cpu_runtime_AcquireInfeedBufferForDequeue(
    const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
    const void* shape, xla::int32 shape_length) {
  int device_ordinal = GetDeviceOrdinal(run_options);

  VLOG(2) << "AcquireInfeedBufferForDequeue: "
          << ShapeString(shape, shape_length) << " on stream executor "
          << device_ordinal;

  xla::cpu::runtime::XfeedManager* xfeed =
      xla::cpu::runtime::GetXfeedManager(device_ordinal);
  // Wait until there's a buffer to dequeue.
  xla::cpu::runtime::XfeedBuffer* buffer =
      xfeed->infeed()->BlockingDequeueBuffer();
  CHECK_EQ(buffer->length(), buffer_length)
      << "XLA program infeed request buffer size " << buffer_length
      << " did not match the runtime's infeed buffer length "
      << buffer->length() << "; program reports desired shape: "
      << ShapeString(shape, shape_length);
  return buffer->data();
}

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
    const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
    void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) {
  int device_ordinal = GetDeviceOrdinal(run_options);

  VLOG(2) << "ReleaseInfeedBufferAfterDequeue: "
          << ShapeString(shape_ptr, shape_length) << " on stream executor "
          << device_ordinal;

  xla::cpu::runtime::XfeedManager* xfeed =
      xla::cpu::runtime::GetXfeedManager(device_ordinal);
  xla::StatusOr<xla::Shape> shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
  xfeed->infeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr,
                                        std::move(shape));
}

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void*
__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
    const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
    const void* shape_ptr, xla::int32 shape_length) {
  int device_ordinal = GetDeviceOrdinal(run_options);

  VLOG(2) << "AcquireOutfeedBufferForPopulation: "
          << ShapeString(shape_ptr, shape_length) << " on stream executor "
          << device_ordinal;

  xla::cpu::runtime::XfeedManager* xfeed =
      xla::cpu::runtime::GetXfeedManager(device_ordinal);
  // Wait until there's a buffer to dequeue.
  xla::cpu::runtime::XfeedBuffer* buffer =
      xfeed->outfeed()->BlockingDequeueBuffer();
  CHECK_EQ(buffer->length(), buffer_length)
      << "XLA program outfeed request buffer size " << buffer_length
      << " did not match the runtime's outfeed buffer length "
      << buffer->length() << "; program reports outfed shape: "
      << ShapeString(shape_ptr, shape_length);
  return buffer->data();
}

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
    const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
    void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) {
  int device_ordinal = GetDeviceOrdinal(run_options);

  VLOG(2) << "ReleaseOutfeedBufferAfterPopulation: "
          << ShapeString(shape_ptr, shape_length) << " on stream executor "
          << device_ordinal;

  xla::cpu::runtime::XfeedManager* xfeed =
      xla::cpu::runtime::GetXfeedManager(device_ordinal);
  xla::StatusOr<xla::Shape> shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
  xfeed->outfeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr,
                                         std::move(shape));
}

namespace {

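// Rendezvous implementations for the CPU collectives. Every participating
// thread submits its ParticipantData and blocks; the participant for which
// InitializationBarrier() returns true ("primary") performs the actual data
// movement under mu_ on behalf of all participants.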
class CpuAllToAllRendezvous
    : public xla::Rendezvous<AllToAllParticipantData, std::nullptr_t> {
 public:
  explicit CpuAllToAllRendezvous(const xla::RendezvousKey& k)
      : xla::Rendezvous<AllToAllParticipantData, std::nullptr_t>(k) {}

 protected:
  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
      const AllToAllParticipantData& /*participant*/) override {
    bool is_primary = InitializationBarrier();

    if (is_primary) {
      tensorflow::mutex_lock lock(mu_);

      CHECK(!participants_.empty());
      CHECK(!participants_[0].source_buffers.empty());
      int expected_buffer_size = participants_[0].source_buffers[0].size();

      // Replica id -> position in participants_.
      absl::flat_hash_map<int, int> replica_id_map;

      for (int pos = 0; pos < participants_.size(); pos++) {
        const AllToAllParticipantData& p = participants_[pos];
        CHECK_EQ(p.source_buffers.size(), p.destination_buffers.size());
        CHECK_EQ(p.source_buffers.size(), participants_.size());
        for (int i = 0; i < p.source_buffers.size(); i++) {
          CHECK_EQ(p.destination_buffers[i].size(), expected_buffer_size);
          CHECK_EQ(p.source_buffers[i].size(), expected_buffer_size);
        }
        replica_id_map[p.replica_id] = pos;
      }

      const std::vector<int>& replica_ids_to_copy_to =
          participants_[0].replica_ids_to_copy_to;

      // Replica id -> rank
      absl::flat_hash_map<int, int> replica_ranks;
      for (int rank = 0; rank < replica_ids_to_copy_to.size(); ++rank) {
        int replica_id = replica_ids_to_copy_to[rank];
        replica_ranks[replica_id] = rank;
      }

      for (const AllToAllParticipantData& sender : participants_) {
        VLOG(3) << "Processing AllToAll participant: " << sender.ToString();

        int rank = xla::FindOrDie(replica_ranks, sender.replica_id);

        for (int i = 0; i < participants_.size(); ++i) {
          int replica_id = replica_ids_to_copy_to[i];
          int participant_num = xla::FindOrDie(replica_id_map, replica_id);
          AllToAllParticipantData& receiver = participants_[participant_num];

          std::memcpy(receiver.destination_buffers[rank].opaque(),
                      sender.source_buffers[i].opaque(), expected_buffer_size);
        }
      }
    }
    return nullptr;
  }
};

class CpuCollectivePermuteRendezvous
    : public xla::Rendezvous<CollectivePermuteParticipantData, std::nullptr_t> {
 public:
  explicit CpuCollectivePermuteRendezvous(const xla::RendezvousKey& k)
      : xla::Rendezvous<CollectivePermuteParticipantData, std::nullptr_t>(k) {}

 protected:
  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
      const CollectivePermuteParticipantData& /*participant*/) override {
    bool primary = InitializationBarrier();

    // Perform all copies from the primary thread.
    if (primary) {
      tensorflow::mutex_lock lock(mu_);

      std::map<int, int> replica_idx_to_participant_idx;
      for (int p_idx = 0; p_idx < participants_.size(); p_idx++) {
        replica_idx_to_participant_idx[participants_[p_idx].replica_id] = p_idx;
      }

      for (auto& p : participants_) {
        for (int dest_replica : p.replica_ids_to_copy_to) {
          auto& dest_p = participants_[xla::FindOrDie(
              replica_idx_to_participant_idx, dest_replica)];
          std::memcpy(dest_p.destination_data.opaque(), p.source_data.opaque(),
                      p.byte_size);

          // Each replica may be copied into only once.
          replica_idx_to_participant_idx.erase(dest_replica);
        }
      }

      // Zero out untouched participants.
      for (auto& replica_p : replica_idx_to_participant_idx) {
        auto& p = participants_[replica_p.second];
        std::memset(p.destination_data.opaque(), 0, p.byte_size);
      }
    }
    return nullptr;
  }
};

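// All-reduce: the primary participant reduces element i of buffer b across
// every participant's source buffer and writes the result into element i of
// buffer b of every participant's destination buffer.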
class CpuAllReduceRendezvous
    : public xla::Rendezvous<xla::AllReduceParticipantData, std::nullptr_t> {
 public:
  explicit CpuAllReduceRendezvous(const xla::RendezvousKey& k)
      : xla::Rendezvous<xla::AllReduceParticipantData, std::nullptr_t>(k) {}

 protected:
  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
      const xla::AllReduceParticipantData& participant) override {
    xla::PrimitiveType datatype = participant.buffers.front().primitive_type;
    bool primary = InitializationBarrier();

    if (primary) {
      switch (datatype) {
        case xla::S8:
          DoAllReduce<xla::S8>(participant);
          break;
        case xla::PRED:
        case xla::U8:
          DoAllReduce<xla::U8>(participant);
          break;
        case xla::S32:
          DoAllReduce<xla::S32>(participant);
          break;
        case xla::U32:
          DoAllReduce<xla::U32>(participant);
          break;
        case xla::S64:
          DoAllReduce<xla::S64>(participant);
          break;
        case xla::U64:
          DoAllReduce<xla::U64>(participant);
          break;
        case xla::F16:
          DoAllReduce<xla::F16>(participant);
          break;
        case xla::F32:
          DoAllReduce<xla::F32>(participant);
          break;
        case xla::F64:
          DoAllReduce<xla::F64>(participant);
          break;
        default:
          LOG(FATAL) << "Unexpected datatype";
      }
    }
    return nullptr;
  }

 private:
  template <xla::PrimitiveType PT>
  void DoAllReduce(xla::AllReduceParticipantData participant) {
    using T = typename xla::primitive_util::PrimitiveTypeToNative<PT>::type;
    tensorflow::mutex_lock lock(mu_);
    CHECK(!participants_.empty());
    xla::ReductionKind reduction_kind = participant.reduction_kind;
    for (const auto& p : participants_) {
      CHECK(p.reduction_kind == reduction_kind);
    }
    int num_participants = participants_.size();

    // participant_idx -> buffer_idx -> buffer.
    std::vector<std::vector<absl::Span<T>>> input_buffers;
    std::vector<std::vector<absl::Span<T>>> output_buffers;
    input_buffers.reserve(num_participants);
    output_buffers.reserve(num_participants);
    const xla::AllReduceParticipantData& first_participant =
        participants_.front();

    int buffers_per_participant = first_participant.buffers.size();
    for (xla::AllReduceParticipantData& p : participants_) {
      CHECK_EQ(p.buffers.size(), buffers_per_participant);

      input_buffers.emplace_back();
      output_buffers.emplace_back();
      std::vector<absl::Span<T>>& participant_input_buffers =
          input_buffers.back();
      std::vector<absl::Span<T>>& participant_output_buffers =
          output_buffers.back();
      participant_input_buffers.reserve(p.buffers.size());
      participant_output_buffers.reserve(p.buffers.size());

      for (int buffer_idx = 0; buffer_idx < buffers_per_participant;
           buffer_idx++) {
        auto& participant_buffer = p.buffers[buffer_idx];
        participant_input_buffers.emplace_back(
            static_cast<T*>(participant_buffer.source_data.opaque()),
            participant_buffer.element_count);
        participant_output_buffers.emplace_back(
            static_cast<T*>(participant_buffer.destination_data.opaque()),
            participant_buffer.element_count);
        CHECK_EQ(participant_buffer.element_count,
                 first_participant.buffers[buffer_idx].element_count);
      }
    }

    for (int buffer_idx = 0; buffer_idx < buffers_per_participant;
         buffer_idx++) {
      int element_count = first_participant.buffers[buffer_idx].element_count;
      for (int idx = 0; idx < element_count; idx++) {
        T out = GetInitialValue<T>(reduction_kind);
        for (int participant_idx = 0; participant_idx < participants_.size();
             participant_idx++) {
          out = PerformReductionStep<T>(
              reduction_kind, out,
              input_buffers[participant_idx][buffer_idx][idx]);
        }
        for (int participant_idx = 0; participant_idx < participants_.size();
             participant_idx++) {
          output_buffers[participant_idx][buffer_idx][idx] = out;
        }
      }
    }
  }

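  // Identity element for each reduction kind: 0 for SUM, 1 for PRODUCT,
  // numeric_limits<T>::max() for MIN, and numeric_limits<T>::min() for MAX.
  // Note that for floating-point T, min() is the smallest positive normal
  // value (lowest() would be the most negative one).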
  template <typename T>
  T GetInitialValue(xla::ReductionKind reduction_kind) {
    switch (reduction_kind) {
      case xla::ReductionKind::SUM:
        return static_cast<T>(0);
      case xla::ReductionKind::PRODUCT:
        return static_cast<T>(1);
      case xla::ReductionKind::MIN:
        return std::numeric_limits<T>::max();
      case xla::ReductionKind::MAX:
        return std::numeric_limits<T>::min();
    }
  }

  template <typename T>
  T PerformReductionStep(xla::ReductionKind reduction_kind, T a, T b) {
    switch (reduction_kind) {
      case xla::ReductionKind::SUM:
        return a + b;
      case xla::ReductionKind::PRODUCT:
        return a * b;
      case xla::ReductionKind::MIN:
        return std::min(a, b);
      case xla::ReductionKind::MAX:
        return std::max(a, b);
    }
  }
};

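// Process-wide rendezvous maps, keyed by RendezvousKey. Entries are
// reference-counted, so a rendezvous object goes away once all of its
// participants have released it.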
xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>&
GlobalAllReduceRendezvousMap() {
  static auto& m =
      *new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>;
  return m;
}

xla::RefcountingHashMap<xla::RendezvousKey, CpuCollectivePermuteRendezvous>&
GlobalCollectivePermuteRendezvousMap() {
  static auto& m = *new xla::RefcountingHashMap<xla::RendezvousKey,
                                                CpuCollectivePermuteRendezvous>;
  return m;
}

xla::RefcountingHashMap<xla::RendezvousKey, CpuAllToAllRendezvous>&
GlobalAllToAllRendezvousMap() {
  static auto& m =
      *new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllToAllRendezvous>;
  return m;
}

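// Builds the key under which all participants of a single collective op meet:
// the run id, the devices participating according to the device assignment
// and replica groups, and whether the op is cross-module (channel id present)
// or cross-replica.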
xla::RendezvousKey GetRendezvousKey(
    const xla::ExecutableRunOptions* run_options,
    std::vector<xla::ReplicaGroup> group, xla::int32 channel_id_present,
    xla::int64 op_id) {
  const xla::DeviceAssignment& device_assignment =
      *run_options->device_assignment();
  int device_ordinal = GetDeviceOrdinal(run_options);
  xla::RendezvousKey::CollectiveOpKind op_kind =
      channel_id_present ? xla::RendezvousKey::kCrossModule
                         : xla::RendezvousKey::kCrossReplica;
  std::vector<xla::GlobalDeviceId> participating_devices =
      xla::GetParticipatingDevices(xla::GlobalDeviceId(device_ordinal),
                                   device_assignment, group,
                                   xla::CollectiveOpGroupMode::kCrossReplica)
          .ValueOrDie();
  int num_local_participants = participating_devices.size();
  return xla::RendezvousKey{run_options->run_id(),
                            std::move(participating_devices),
                            num_local_participants, op_kind, op_id};
}

}  // namespace

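// Entry point for the all-to-all collective: the participant with rank r in
// the replica group copies its i-th source buffer into destination buffer r of
// the participant with rank i (see CpuAllToAllRendezvous above).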
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllToAll(
    const xla::ExecutableRunOptions* run_options, xla::int32 channel_id_present,
    xla::int64 op_id, const void* replica_groups_str,
    xla::int32 replica_groups_str_size, xla::int32 num_buffers,
    xla::int64 buffer_size, void** source_buffers, void** destination_buffers) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  xla::int32 replica_id =
      run_options->device_assignment()
          ->ReplicaIdForDevice(xla::GlobalDeviceId(device_ordinal))
          .ValueOrDie();
  absl::string_view replica_groups_serialized(
      static_cast<const char*>(replica_groups_str), replica_groups_str_size);
  std::vector<xla::ReplicaGroup> group =
      xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
  xla::RendezvousKey rendezvous_key =
      GetRendezvousKey(run_options, group, channel_id_present, op_id);

  AllToAllParticipantData participant(rendezvous_key, device_ordinal,
                                      run_options->stream());
  participant.replica_id = replica_id;
  participant.replica_ids_to_copy_to =
      xla::GetParticipatingIDs(
          replica_id, run_options->device_assignment()->replica_count(), group)
          .ValueOrDie();
  for (int i = 0; i < num_buffers; i++) {
    participant.source_buffers.emplace_back(source_buffers[i], buffer_size);
    participant.destination_buffers.emplace_back(destination_buffers[i],
                                                 buffer_size);
  }
  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
    return absl::make_unique<CpuAllToAllRendezvous>(k);
  };
  TF_CHECK_OK(CpuAllToAllRendezvous::SubmitParticipant(
                  [&] {
                    return GlobalAllToAllRendezvousMap().GetOrCreateIfAbsent(
                        rendezvous_key, make_cpu_rendezvous);
                  },
                  participant)
                  .status());
}

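// Entry point for all-reduce. shape_ptr describes either a single dense array
// (num_buffers == 1) or a tuple with one element per buffer; every buffer is
// reduced element-wise across the replica group.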
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
    const xla::ExecutableRunOptions* run_options,
    const void* replica_groups_str, xla::int32 replica_groups_str_size,
    xla::int32 channel_id_present, xla::int64 op_id, xla::int32 reduction_kind,
    const void* shape_ptr, xla::int32 shape_length, xla::int32 num_buffers,
    void** input_buffers, void** output_buffers) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  absl::string_view replica_groups_serialized(
      static_cast<const char*>(replica_groups_str), replica_groups_str_size);
  std::vector<xla::ReplicaGroup> group =
      xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
  xla::RendezvousKey rendezvous_key =
      GetRendezvousKey(run_options, group, channel_id_present, op_id);
  auto shape_str = ShapeString(shape_ptr, shape_length);
  VLOG(2) << "All-reduce input/output shape: " << shape_str;

  xla::Shape shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length).ValueOrDie();

  CHECK((num_buffers > 1 && shape.IsTuple()) ||
        (num_buffers == 1 && xla::LayoutUtil::IsDenseArray(shape)));

  xla::AllReduceParticipantData participant(rendezvous_key, device_ordinal,
                                            run_options->stream());
  participant.reduction_kind = static_cast<xla::ReductionKind>(reduction_kind);
  for (int i = 0; i < num_buffers; i++) {
    xla::Shape subshape = num_buffers == 1 ? shape : shape.tuple_shapes(i);
    xla::AllReduceParticipantData::Buffer buffer;
    buffer.element_count = xla::ShapeUtil::ElementsIn(subshape);
    buffer.primitive_type = subshape.element_type();
    buffer.source_data = se::DeviceMemoryBase(
        input_buffers[i], xla::ShapeUtil::ByteSizeOf(subshape));
    buffer.destination_data = se::DeviceMemoryBase(
        output_buffers[i], xla::ShapeUtil::ByteSizeOf(subshape));
    participant.buffers.push_back(buffer);
  }

  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
    return absl::make_unique<CpuAllReduceRendezvous>(k);
  };

  TF_CHECK_OK(CpuAllReduceRendezvous::SubmitParticipant(
                  [&] {
                    return GlobalAllReduceRendezvousMap().GetOrCreateIfAbsent(
                        rendezvous_key, make_cpu_rendezvous);
                  },
                  participant)
                  .status());
}

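// Writes the calling replica's id into output_buffer as a 32-bit integer.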
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ReplicaId(
    const xla::ExecutableRunOptions* run_options, void* output_buffer) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  xla::int32 replica_id =
      run_options->device_assignment()
          ->ReplicaIdForDevice(xla::GlobalDeviceId(device_ordinal))
          .ValueOrDie();
  std::memcpy(output_buffer, &replica_id, 4);
}

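// Entry point for collective-permute. source_target_pairs is a serialized,
// comma-separated list of "source=target" replica pairs, e.g. "0=1,1=2,2=0".
// Replicas that are not the target of any pair get their output zeroed (see
// CpuCollectivePermuteRendezvous above).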
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_CollectivePermute(
    const xla::ExecutableRunOptions* run_options, xla::int32 channel_id_present,
    xla::int64 op_id, xla::int32 byte_size, void* input_buffer,
    void* output_buffer, const void* source_target_pairs,
    xla::int32 source_target_pairs_size) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  absl::string_view source_target_pairs_serialized(
      static_cast<const char*>(source_target_pairs), source_target_pairs_size);
  auto pairs = absl::StrSplit(source_target_pairs_serialized, ',');
  xla::int32 replica_id =
      run_options->device_assignment()
          ->ReplicaIdForDevice(xla::GlobalDeviceId(device_ordinal))
          .ValueOrDie();
  std::vector<int> copy_to;
  for (auto& p : pairs) {
    std::vector<std::string> mapping = absl::StrSplit(p, '=');
    CHECK_EQ(mapping.size(), 2);
    int from = std::stoi(mapping[0]);
    int to = std::stoi(mapping[1]);
    if (from == replica_id) {
      copy_to.push_back(to);
    }
  }
  xla::RendezvousKey rendezvous_key =
      GetRendezvousKey(run_options, {}, channel_id_present, op_id);

  CollectivePermuteParticipantData participant(rendezvous_key, device_ordinal,
                                               run_options->stream());
  participant.replica_id = replica_id;
  participant.source_data = se::DeviceMemoryBase(input_buffer, byte_size);
  participant.destination_data = se::DeviceMemoryBase(output_buffer, byte_size);
  participant.replica_ids_to_copy_to = copy_to;
  participant.byte_size = byte_size;

  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
    return absl::make_unique<CpuCollectivePermuteRendezvous>(k);
  };
  TF_CHECK_OK(
      CpuCollectivePermuteRendezvous::SubmitParticipant(
          [&] {
            return GlobalCollectivePermuteRendezvousMap().GetOrCreateIfAbsent(
                rendezvous_key, make_cpu_rendezvous);
          },
          participant)
          .status());
}