• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Generally useful utility functions that are common to (not specific to any
17 // given part of) the XLA code base.
18 
19 #ifndef TENSORFLOW_COMPILER_XLA_UTIL_H_
20 #define TENSORFLOW_COMPILER_XLA_UTIL_H_
21 
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <initializer_list>
#include <limits>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/inlined_vector.h"
#include "absl/numeric/bits.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"  // IWYU pragma: keep
#include "tensorflow/core/lib/math/math_util.h"
44 
45 namespace xla {
46 
47 // Converts the unsigned integer n into a mixed-radix representation with the
48 // given bounds (radices). More precisely, if there are K radices, then the
49 // returned vector digits has K entries and satisfies
50 //
51 //   0 <= digits[i] < bounds[i],  for i = 0, ..., K - 1
52 //
53 // and FromMixedRadix(digits) == n. The mixed radix representation is unique
54 // modulo the product of the entries of bounds.
55 std::vector<int64_t> ToMixedRadix(int64_t n, absl::Span<const int64_t> bounds);
56 
57 // Logs the provided status message with a backtrace.
58 //
59 // For use by Status-factories, logs a backtrace at the point where the status
60 // is created, such that we can use --vmodule=util=1 to see all status
61 // creation backtraces.
62 Status WithLogBacktrace(const Status& status);
63 
// Ranks greater than 6 are very rare, so use InlinedVector<int64_t, 6> to store
// the bounds and indices. And for the rare cases of ranks greater than 6,
// the InlinedVector will just behave like an std::vector<> and allocate the
// memory to store its values.
//
// Inline capacity shared by the dimension-related vector aliases below.
inline constexpr int InlineRank() { return 6; }
// Per-dimension bounds/indices, inline up to InlineRank() entries.
using DimensionVector = absl::InlinedVector<int64_t, InlineRank()>;
// Per-dimension level types (DimLevelType comes from xla_data.pb.h).
using DimLevelTypeVector = absl::InlinedVector<DimLevelType, InlineRank()>;
71 
72 // RAII timer that logs with a given label the wall clock time duration in human
73 // readable form. This differs from base's ElapsedTimer primarily in that it
74 // spits out the human-readable duration form.
75 //
76 // Keeps track of global maximum and cumulative times across all invocations.
77 //
78 // By default, the timing traces are only printed at VLOG(1) and above:
79 //
80 //   XLA_SCOPED_LOGGING_TIMER("fooing bar");  // nop if !VLOG_IS_ON(1).
81 //
82 // but you can control this via:
83 //
84 //   XLA_SCOPED_LOGGING_TIMER_LEVEL("fooing bar", 2);  // nop if !VLOG_IS_ON(2)
85 //
#define XLA_SCOPED_LOGGING_TIMER(label) \
  XLA_SCOPED_LOGGING_TIMER_HELPER(label, 1, __COUNTER__)
#define XLA_SCOPED_LOGGING_TIMER_LEVEL(label, level) \
  XLA_SCOPED_LOGGING_TIMER_HELPER(label, level, __COUNTER__)

// Helper for implementing macros above.  Do not use directly.
//
// Forces the evaluation of "counter", which we expect is equal to __COUNTER__.
#define XLA_SCOPED_LOGGING_TIMER_HELPER(label, level, counter) \
  XLA_SCOPED_LOGGING_TIMER_HELPER2(label, level, counter)

// Helper for macros above.  Don't use directly.
// Declares a function-local static TimerStats (shared by all executions of the
// enclosing scope) and an RAII ScopedLoggingTimer bound to it; `counter` makes
// the identifiers unique so several timers can coexist in one scope.
#define XLA_SCOPED_LOGGING_TIMER_HELPER2(label, level, counter)      \
  static ::xla::TimerStats XLA_TimerStats##counter;                  \
  ::xla::ScopedLoggingTimer XLA_ScopedLoggingTimerInstance##counter( \
      label, /*enabled=*/VLOG_IS_ON(level), __FILE__, __LINE__,      \
      &XLA_TimerStats##counter);
103 
// Aggregate timing statistics for one XLA_SCOPED_LOGGING_TIMER call site,
// accumulated across all of its invocations. All fields are guarded by
// stats_mutex.
struct TimerStats {
  absl::Mutex stats_mutex;
  // Total seconds accumulated across completed timer runs.
  double cumulative_secs ABSL_GUARDED_BY(stats_mutex) = 0;
  // Longest single timer run observed, in seconds.
  double max_secs ABSL_GUARDED_BY(stats_mutex) = 0;
  // Number of timer runs folded into the statistics above.
  uint64_t times_called ABSL_GUARDED_BY(stats_mutex) = 0;
};
110 
111 // RAII timer for XLA_SCOPED_LOGGING_TIMER and XLA_SCOPED_LOGGING_TIMER_LEVEL
112 // macros above.  Recommended usage is via the macros so you don't have to give
113 // the timer a name or worry about calling VLOG_IS_ON yourself.
class ScopedLoggingTimer {
 public:
  // Starts the timer.
  //
  // label: Label to display for logging.
  // enabled: Whether this timer should do anything at all.
  // file: Filename to display in logging.
  // line: Line number to display in logging.
  // `timer_stats`: unowned non-null pointer which is used to populate the
  // global timer statistics.
  ScopedLoggingTimer(absl::string_view label, bool enabled, const char* file,
                     int line, TimerStats* timer_stats);

  // Stop the timer and log the tracked time. Timer is disabled after this
  // function is called.
  void StopAndLog();

  // NOTE(review): presumably invokes StopAndLog() if still enabled — confirm
  // against the definition in util.cc.
  ~ScopedLoggingTimer();

 private:
  const std::string label_;
  const char* const file_;
  const int line_;
  TimerStats* const timer_stats_;  // Unowned; must outlive this timer.
  uint64_t start_micros_;          // Start timestamp, in microseconds.
  bool enabled_;                   // Cleared once StopAndLog() has run.
};
139 
140 // Given a vector<T>, returns a Span<char> that points at its
141 // internals.
142 //
143 // Warning: if the vector is updated its storage pointer may change, so use this
144 // with caution (ideally in limited scopes with temporary lifetimes).
template <typename T>
absl::Span<uint8_t> MutableByteSlice(std::vector<T>* v) {
  // Reinterpret the vector's contiguous storage as raw bytes; no copy is made.
  return absl::Span<uint8_t>(reinterpret_cast<uint8_t*>(v->data()),
                             v->size() * sizeof(T));
}
150 
151 // Turns an immutable slice of type T into an immutable slice of bytes with the
152 // same byte size.
template <typename T>
absl::Span<const uint8_t> CastToByteSlice(absl::Span<const T> slice) {
  // Views the same memory as bytes; the returned span aliases `slice`.
  return absl::Span<const uint8_t>(
      reinterpret_cast<const uint8_t*>(slice.data()), slice.size() * sizeof(T));
}
158 
159 // Casts a byte slice to a non-byte type T, checking that the original slice
160 // length is a multiple of sizeof(T).
template <typename T>
absl::Span<const T> CastByteSlice(absl::Span<const uint8_t> slice) {
  // The byte length must be an exact multiple of sizeof(T); CHECK-fails
  // otherwise.
  CHECK_EQ(0, slice.size() % sizeof(T));
  return absl::Span<const T>(reinterpret_cast<const T*>(slice.data()),
                             slice.size() / sizeof(T));
}
167 
168 // Compares two containers for equality. Returns true iff the two containers
169 // have the same size and all their elements compare equal using their
170 // operator==. Like std::equal, but forces size equality.
// Compares a container against an initializer list for equality. Returns true
// iff both have the same size and all corresponding elements compare equal
// with operator==. Like std::equal, but forces size equality.
template <typename Container1T,
          typename ElementType = typename Container1T::value_type>
bool ContainersEqual(const Container1T& c1,
                     std::initializer_list<ElementType> il) {
  // The four-iterator std::equal overload also verifies the lengths match.
  return std::equal(c1.begin(), c1.end(), il.begin(), il.end());
}
178 
#if defined(__cpp_lib_to_underlying) && __cpp_lib_to_underlying >= 202102L
// C++23 and later: re-export the standard implementation into this namespace.
// (A using-declaration is required here; the previous alias-declaration
// `using to_underlying = std::to_underlying;` does not compile, because a
// function template cannot be the target of a type alias.)
using std::to_underlying;
#else
// Helper function which implements C++23's std::to_underlying: converts an
// enum value to its underlying integer type.
template <typename T>
constexpr std::underlying_type_t<T> to_underlying(T value) noexcept {
  return static_cast<std::underlying_type_t<T>>(value);
}
#endif
188 
189 // Performs a copy of count values from src to dest, using different strides for
190 // source and destination. The source starting index is src_base, while the
191 // destination one is dest_base.
192 template <typename D, typename S>
StridedCopy(absl::Span<D> dest,int64_t dest_base,int64_t dest_stride,absl::Span<const S> src,int64_t src_base,int64_t src_stride,int64_t count)193 void StridedCopy(absl::Span<D> dest, int64_t dest_base, int64_t dest_stride,
194                  absl::Span<const S> src, int64_t src_base, int64_t src_stride,
195                  int64_t count) {
196   for (; count > 0; --count, dest_base += dest_stride, src_base += src_stride) {
197     dest[dest_base] = static_cast<D>(src[src_base]);
198   }
199 }
200 
201 // Adds some context information to the error message in a
202 // Status.  This is useful as Statuses are
203 // propagated upwards.
204 Status AddStatus(Status prior, absl::string_view context);
205 Status AppendStatus(Status prior, absl::string_view context);
206 
// Status error shorthands -- StrFormat's the arguments to be used as an error
// message and returns a status in the canonical error space.
//
// Every factory routes through WithLogBacktrace, so a backtrace is logged at
// the point of status creation (visible with --vmodule=util=1).
template <typename... Args>
Status InvalidArgument(const absl::FormatSpec<Args...>& format,
                       const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::InvalidArgument(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status Unimplemented(const absl::FormatSpec<Args...>& format,
                     const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Unimplemented(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status InternalError(const absl::FormatSpec<Args...>& format,
                     const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Internal(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status FailedPrecondition(const absl::FormatSpec<Args...>& format,
                          const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::FailedPrecondition(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status Cancelled(const absl::FormatSpec<Args...>& format, const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Cancelled(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status ResourceExhausted(const absl::FormatSpec<Args...>& format,
                         const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::ResourceExhausted(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status NotFound(const absl::FormatSpec<Args...>& format, const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::NotFound(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status Unavailable(const absl::FormatSpec<Args...>& format,
                   const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Unavailable(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status Unknown(const absl::FormatSpec<Args...>& format, const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Unknown(absl::StrFormat(format, args...)));
}
// NOTE: Internal() duplicates InternalError() above -- both produce
// tensorflow::errors::Internal. Kept for backward compatibility; prefer one
// consistently within a module.
template <typename... Args>
Status Internal(const absl::FormatSpec<Args...>& format, const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Internal(absl::StrFormat(format, args...)));
}
265 
// As InvalidArgument above, but absl::StrCat's the arguments instead of
// StrFormat-ing them.
template <typename... Args>
Status InvalidArgumentStrCat(Args&&... concat) {
  return InvalidArgument("%s", absl::StrCat(std::forward<Args>(concat)...));
}

// As Unimplemented above, but StrCat's the arguments.
template <typename... Args>
Status UnimplementedStrCat(Args&&... concat) {
  return Unimplemented("%s", absl::StrCat(std::forward<Args>(concat)...));
}

// As InternalError above, but StrCat's the arguments.
template <typename... Args>
Status InternalErrorStrCat(Args&&... concat) {
  return InternalError("%s", absl::StrCat(std::forward<Args>(concat)...));
}

// As ResourceExhausted above, but StrCat's the arguments.
template <typename... Args>
Status ResourceExhaustedStrCat(Args&&... concat) {
  return ResourceExhausted("%s", absl::StrCat(std::forward<Args>(concat)...));
}
285 
286 // Splits the lines of the original, replaces leading whitespace with the prefix
287 // given by "indentation", and returns the string joined by newlines again. As a
288 // side effect, any additional trailing whitespace is removed.
289 //
290 // Note: even different amounts of leading whitespace on different lines will be
291 // uniformly replaced with "indentation".
292 std::string Reindent(absl::string_view original, absl::string_view indentation);
293 
// Returns the index of the first element in `container` equal to `value`, or
// container.size() when `value` is absent.
template <typename Container>
int64_t PositionInContainer(const Container& container, int64_t value) {
  const auto hit = std::find(container.begin(), container.end(), value);
  return std::distance(container.begin(), hit);
}
298 
299 // Formats the container as a comma-separated string. StrAppend must support
300 // appending the elements of the container. Prefix is prepended and suffix is
301 // appended to the returned string.
302 template <typename Container>
303 std::string CommaSeparatedString(const Container& c, const char* prefix = "",
304                                  const char* suffix = "") {
305   // Not using Join() since the implementation here is simple anyway and this
306   // avoids copying the string to append prefix.
307   std::string comma_separated = prefix;
308   const char* separator = "";
309   for (const auto& entry : c) {
310     absl::StrAppend(&comma_separated, separator, entry);
311     separator = ", ";
312   }
313   comma_separated += suffix;
314   return comma_separated;
315 }
316 
// Overload needed to allow the container to be an initializer list. The default
// type for T makes an empty initializer list work as well.
template <typename T = int>
std::string CommaSeparatedString(const std::initializer_list<T>& c,
                                 const char* prefix = "",
                                 const char* suffix = "") {
  // Forwards to the generic container overload above.
  return CommaSeparatedString<std::initializer_list<T>>(c, prefix, suffix);
}

// Formats the container in the mathematical notation for a vector, e.g. (1, 3,
// 7). StrAppend must support appending the elements of c.
template <typename Container>
std::string VectorString(const Container& c) {
  return CommaSeparatedString(c, "(", ")");
}

// Overload needed to allow the container to be an initializer list. The default
// type for T makes an empty initializer list work as well.
template <typename T = int>
std::string VectorString(const std::initializer_list<T>& c) {
  // Forwards to the generic container overload above.
  return VectorString<std::initializer_list<T>>(c);
}
339 
340 // Returns a string which can losslessly round trip to a bfloat.
341 std::string RoundTripFpToString(tensorflow::bfloat16 value);
342 
343 // Returns a string which can losslessly round trip to a fp16.
344 std::string RoundTripFpToString(Eigen::half value);
345 
346 // Returns a string which can losslessly round trip to a float.
347 std::string RoundTripFpToString(float value);
348 
349 // Returns a string which can losslessly round trip to a double.
350 std::string RoundTripFpToString(double value);
351 
352 // Returns a PaddingConfig object that represents no padding for the given rank.
353 PaddingConfig MakeNoPaddingConfig(int64_t rank);
354 
355 // Returns a PaddingConfig object where 'padding' contains
356 // (low edge padding, high edge padding) pairs for each dimension.
357 PaddingConfig MakeEdgePaddingConfig(
358     absl::Span<const std::pair<int64_t, int64_t>> padding);
359 
360 // Returns true if the padding configuration has at least one dimension with
361 // non-zero interior padding.
362 bool HasInteriorPadding(const PaddingConfig& config);
363 
// Imports the templated FloorOfRatio math function from the TensorFlow
// namespace, as it is very commonly used. Rounds the quotient towards
// negative infinity (see FloorDivMod below), unlike C++ '/' which truncates
// towards zero.
template <typename T>
T FloorOfRatio(T dividend, T divisor) {
  return tensorflow::MathUtil::FloorOfRatio<T>(dividend, divisor);
}

// Imports the templated CeilOfRatio math function from the TensorFlow
// namespace, as it is very commonly used.
template <typename T>
T CeilOfRatio(T dividend, T divisor) {
  return tensorflow::MathUtil::CeilOfRatio<T>(dividend, divisor);
}

// Rounds the value up to a multiple of the divisor by first calling CeilOfRatio
// then multiplying by the divisor. For example: RoundUpTo(13, 8) => 16
template <typename T>
T RoundUpTo(T value, T divisor) {
  return CeilOfRatio(value, divisor) * divisor;
}

// Rounds the value down to a multiple of the divisor by first calling
// FloorOfRatio then multiplying by the divisor. For example:
// RoundDownTo(13, 8) => 8
template <typename T>
T RoundDownTo(T value, T divisor) {
  return FloorOfRatio(value, divisor) * divisor;
}
392 
// Quotient/remainder pair returned by FloorDivMod below.
template <typename T>
struct DivMod {
  T quotient;
  T modulo;
};

// Divide `dividend` by `divisor` such that the quotient is rounded towards
// negative infinity. The remainder will have the same sign as `divisor`.
template <typename T>
DivMod<T> FloorDivMod(T dividend, T divisor) {
  DivMod<T> div_mod;
  div_mod.quotient = FloorOfRatio(dividend, divisor);
  // Reconstructs the remainder from the floored quotient so that
  // dividend == quotient * divisor + modulo.
  div_mod.modulo = dividend - div_mod.quotient * divisor;
  return div_mod;
}
408 
409 // Given a number of flops executed in an amount of time, produces a string that
410 // represents the throughput;
411 // e.g. HumanReadableNumFlops(1e9, 1e9) => 1.00GFLOP/s.
412 std::string HumanReadableNumFlops(double flops, double nanoseconds);
413 
414 // Given a number of transcendental ops executed in an amount of time, produces
415 // a string that represents the throughput;
416 // e.g. HumanReadableNumTranscendentalOps(1e9, 1e9) => 1.00GTROP/s.
417 std::string HumanReadableNumTranscendentalOps(double trops, double nanoseconds);
418 
419 // Split the text into multiple lines and log each line with the given
420 // severity, filename, and line number.
421 void LogLines(int sev, absl::string_view text, const char* fname, int lineno);
422 
423 // Returns a mask with "width" number of least significant bits set.
424 template <typename T>
LsbMask(int width)425 constexpr inline T LsbMask(int width) {
426   static_assert(std::is_unsigned<T>::value,
427                 "T should be an unsigned integer type");
428   ABSL_ASSERT(width >= 0);
429   ABSL_ASSERT(width <= std::numeric_limits<T>::digits);
430   return width == 0
431              ? 0
432              : static_cast<T>(-1) >> (std::numeric_limits<T>::digits - width);
433 }
434 
// Return floor(log2(n)) for positive integer n.  Returns -1 iff n == 0.
template <typename T>
constexpr inline int Log2Floor(T x) {
  static_assert(std::is_unsigned<T>::value,
                "T should be an unsigned integer type");
  // absl::bit_width(x) == 1 + floor(log2(x)) for x > 0, and 0 for x == 0,
  // which yields the -1 sentinel for zero automatically.
  return absl::bit_width(x) - 1;
}

// Return ceiling(log2(n)) for positive integer n.  Returns -1 iff n == 0.
template <typename T>
constexpr inline int Log2Ceiling(T x) {
  static_assert(std::is_unsigned<T>::value,
                "T should be an unsigned integer type");
  // For x > 0, ceil(log2(x)) == bit_width(x - 1); x == 0 is special-cased.
  return x == 0 ? -1 : absl::bit_width(x - 1);
}
450 
// Return the number of sign bits (i.e. the number of leading ones for negative
// numbers and the number of leading zeros for non-negative numbers).
template <typename T>
constexpr inline int CountLeadingSignBits(T x) {
  static_assert(std::is_signed<T>::value, "T should be a signed integer type");
  using UnsignedType = std::make_unsigned_t<T>;
  // The signed-to-unsigned conversion is well-defined (modular) and keeps the
  // bit pattern, so counting leading ones/zeros on the unsigned value works.
  return x < T{0} ? absl::countl_one<UnsignedType>(x)
                  : absl::countl_zero<UnsignedType>(x);
}
460 
// Returns `value` with the low `width` bits set and the remaining bits set to
// zero.
template <typename T>
constexpr inline T KeepLowerBits(T value, int width) {
  // LsbMask enforces (via its own assertions) that T is unsigned and that
  // 0 <= width <= digits.
  return value & LsbMask<T>(width);
}
467 
468 // Returns `base` multiplied by itself `exponent` number of times.
469 //
470 // Note: returns 1 when `exponent` is zero.
471 // Precondition: `exponent` is non-negative.
472 template <typename T>
IPow(T base,int exponent)473 constexpr T IPow(T base, int exponent) {
474   // A negative `exponent` is indicative of a logic bug for integral `base`.
475   // We disallow it for floating-point types for symmetry.
476   ABSL_ASSERT(exponent >= 0);
477   // We use the right-to-left binary exponentiation algorithm.
478   T result{1};
479   while (exponent > 0) {
480     if ((exponent & 1) != 0) {
481       result *= base;
482     }
483     base *= base;
484     exponent >>= 1;
485   }
486   return result;
487 }
488 
// Maps a byte size (1/2/4/8) to the unsigned integer type of that size.
// The primary template is intentionally undefined: any other size is a
// compile-time error.
template <size_t>
struct UnsignedIntegerTypeForSize;

template <>
struct UnsignedIntegerTypeForSize<1> {
  using type = uint8_t;
};

template <>
struct UnsignedIntegerTypeForSize<2> {
  using type = uint16_t;
};

template <>
struct UnsignedIntegerTypeForSize<4> {
  using type = uint32_t;
};

template <>
struct UnsignedIntegerTypeForSize<8> {
  using type = uint64_t;
};

// Maps a byte size (1/2/4/8) to the signed integer type of that size.
template <size_t N>
struct SignedIntegerTypeForSize {
  using type = std::make_signed_t<typename UnsignedIntegerTypeForSize<N>::type>;
};
516 
// Returns the signed magnitude of T.
//
// Reinterprets `input`'s bits as a same-sized signed integer and flips all
// non-sign bits of negative values; presumably this converts two's-complement
// ordering to sign-magnitude ordering (e.g. for totally ordering float bit
// patterns) — confirm against callers.
template <typename T>
typename SignedIntegerTypeForSize<sizeof(T)>::type ToSignMagnitude(T input) {
  auto as_bits =
      absl::bit_cast<typename SignedIntegerTypeForSize<sizeof(T)>::type>(input);
  // Sign(as_bits) is -1, 0 or +1; bit_cast to unsigned and shifting right by
  // one yields 0 for non-negative inputs and 0b0111...1 for negative inputs,
  // so the XOR below flips every bit except the sign bit when negative.
  auto sign_mask =
      absl::bit_cast<typename UnsignedIntegerTypeForSize<sizeof(T)>::type>(
          tensorflow::MathUtil::Sign(as_bits));
  return as_bits ^ (sign_mask >> 1);
}
527 
// Number of NaN payload bits for T: mantissa width minus the quiet bit for
// floating point types with quiet NaNs, and zero for every other type.
template <typename T>
constexpr int NanPayloadBits() {
  return std::numeric_limits<T>::has_quiet_NaN
             ? std::numeric_limits<T>::digits - 1
             : 0;
}

// The representation (as an integer) of a quiet NaN whose payload bits are all
// zero: only the most-significant payload (quiet) bit is set. Zero for types
// without NaN payloads.
template <typename T>
constexpr uint64_t QuietNanWithoutPayload() {
  constexpr int kPayloadBits = NanPayloadBits<T>();
  return kPayloadBits > 0 ? uint64_t{1} << (kPayloadBits - 1) : 0;
}
544 
545 template <typename T>
546 constexpr uint64_t NanPayloadBitMask() {
547   if (const int bits = NanPayloadBits<T>()) {
548     return LsbMask<uint64_t>(bits);
549   }
550   return 0;
551 }
552 
// Builds a NaN of type T with the given sign bit and NaN payload.
// `nan_payload` must be nonzero for types with payload bits (CHECK-enforced);
// presumably an all-zero payload would not encode a NaN — confirm against the
// IEEE-754 encoding of T.
template <typename T>
T NanWithSignAndPayload(bool sign, uint64_t nan_payload) {
  using RepT = typename UnsignedIntegerTypeForSize<sizeof(T)>::type;
  const T val = std::numeric_limits<T>::quiet_NaN();
  auto rep = absl::bit_cast<RepT>(val);
  // Clear the sign bit of the canonical quiet NaN, then set the requested one.
  rep &= LsbMask<RepT>(std::numeric_limits<RepT>::digits - 1);
  rep |= uint64_t{sign} << (std::numeric_limits<RepT>::digits - 1);
  constexpr int kPayloadBits = NanPayloadBits<T>();
  if (kPayloadBits > 0) {
    // Clear rep's NaN payload.
    rep &= ~NanPayloadBitMask<T>();
    CHECK_NE(nan_payload, 0);
    rep |= nan_payload;
  }
  return absl::bit_cast<T>(rep);
}
569 
// Utility for performing a static_cast<> on a std::unique_ptr<>: transfers
// ownership from `ptr` to the returned pointer of the derived type.
template <typename Derived, typename Base>
std::unique_ptr<Derived> unique_ptr_static_cast(std::unique_ptr<Base> ptr) {
  // Release ownership first, then rewrap the downcast raw pointer.
  Derived* const raw = static_cast<Derived*>(ptr.release());
  return std::unique_ptr<Derived>(raw);
}
575 
576 int64_t Product(absl::Span<const int64_t> xs);
577 
578 // Returns the start indices of consecutive non-overlapping subsequences of `a`
579 // and `b` with the same product, i.e. `(i, j)` so
580 // • a = {a[0 = i_0], ..., a[i_1 - 1], a[i_1], ... , a[i_2 - 1], ...}
581 // • b = {b[0 = j_0], ..., b[j_1 - 1], b[j_1], ... , b[j_2 - 1], ...}
582 // • ∀ k . 0 <= k < CommonFactors(a, b).size - 1 =>
583 //         a[i_k] × a[i_k + 1] × ... × a[i_(k+1) - 1] =
584 //         b[j_k] × b[j_k + 1] × ... × b[j_(k+1) - 1]
585 // where `CommonFactors(a, b)[CommonFactors(a, b).size - 1] = (a.size, b.size)`
586 //
587 // If input and output are the same, return {(0, 0), {1, 1}, ... {a.size,
588 // b.size}}, otherwise if the given shapes have non-zero size, returns the
589 // bounds of the shortest possible such subsequences; else, returns `{(0, 0),
590 // (a.size, b.size)}`.
591 absl::InlinedVector<std::pair<int64_t, int64_t>, 8> CommonFactors(
592     absl::Span<const int64_t> a, absl::Span<const int64_t> b);
593 
// Result of ConvertDimensionNumbers below: partitions the requested
// `from_dimensions` into those that map cleanly onto the target shape
// (`transformed_from_dimensions`) and those that do not
// (`untransformed_from_dimensions`), together with the corresponding target
// dimensions and any dimensions that had to be split. Exact field semantics
// are defined by the implementation in util.cc — confirm there before relying
// on them.
struct ConvertedDimensionNumbers {
  DimensionVector transformed_from_dimensions;
  DimensionVector untransformed_from_dimensions;
  DimensionVector to_dimensions;
  DimensionVector split_from_dimensions;
  DimensionVector split_from_sizes;
  DimensionVector split_to_dimensions;
};
602 
603 // Convert and unsorted list of dimensions from one shapes dimension sizes to
604 // another shapes dimensions sizes.
605 ConvertedDimensionNumbers ConvertDimensionNumbers(
606     absl::Span<const int64_t> from_dimensions,
607     absl::Span<const int64_t> from_sizes, absl::Span<const int64_t> to_sizes);
608 
609 // Removes illegal characters from filenames.
610 std::string SanitizeFileName(std::string file_name);
611 
// Returns the index of the first element in `c` equal to `value`, or c.size()
// when no element matches.
template <typename C, typename Value>
int64_t FindIndex(const C& c, Value&& value) {
  const auto match = std::find(c.begin(), c.end(), std::forward<Value>(value));
  return std::distance(c.begin(), match);
}
617 
// Inserts `value` into `*c` immediately before the element at `index`
// (index == c->size() appends).
template <typename C, typename Value>
void InsertAt(C* c, int64_t index, Value&& value) {
  c->insert(std::next(c->begin(), index), std::forward<Value>(value));
}
622 
// Removes the element of `*c` at position `index`.
template <typename C>
void EraseAt(C* c, int64_t index) {
  c->erase(std::next(c->begin(), index));
}
627 
// Copies the elements of `slice` into a freshly allocated std::vector.
template <typename T>
std::vector<T> SpanToVector(absl::Span<const T> slice) {
  return std::vector<T>(slice.begin(), slice.end());
}
632 
// Copies the elements of `inlined_vector` into a std::vector.
template <typename T, size_t N>
std::vector<T> InlinedVectorToVector(
    const absl::InlinedVector<T, N>& inlined_vector) {
  return std::vector<T>(inlined_vector.begin(), inlined_vector.end());
}
638 
// Returns true if `x` fits in 32-bits.
template <typename T>
bool IsInt32(T x) {
  // Following conversion rules: "the value is unchanged if it can be
  // represented in the destination type (and bit-field width); otherwise, the
  // value is implementation-defined." So the round-trip through int32_t is
  // exact precisely when `x` is representable as an int32_t.
  const auto narrowed = static_cast<int32_t>(x);
  return narrowed == x;
}
647 
// Removes the first occurrence of `value` from `*container`. Returns a
// non-OK status (via TF_RET_CHECK) if `value` is not present.
template <typename T>
Status EraseElementFromVector(std::vector<T>* container, const T& value) {
  // absl::c_find returns a const_iterator which does not seem to work on
  // gcc 4.8.4, and this breaks the ubuntu/xla_gpu build bot.
  auto it = std::find(container->begin(), container->end(), value);
  TF_RET_CHECK(it != container->end());
  container->erase(it);
  return OkStatus();
}
657 
658 // Utility function which splits a double-precision float (F64) into a pair of
659 // single-precision floating point numbers. The most significant 49 bits (out of
660 // the total 53 available) in the mantissa of the F64 is represented as the
661 // unevaluated sum of two non-overlapping single-precision F32s; the 'high' part
662 // contains 24 bits in its mantissa, and the 'low' part contains 25 bits in its
663 // sign bit and its mantissa.
664 // Note: The resulting representation can still only represent 8-bit exponent
665 // range that is available in F32s (out of a total of 11 exponent bits in F64s).
666 std::pair<float, float> SplitF64ToF32(double x);
667 
668 class HloInstruction;
669 
670 // A predicate over HLO instruction.
671 using HloPredicate = std::function<bool(const HloInstruction*)>;
672 
673 using Vector3 = std::array<int64_t, 3>;
674 
675 }  // namespace xla
676 
// Logs the multi-line STRING at the given severity, one log call per line,
// attributed to the caller's file and line.
#define XLA_LOG_LINES(SEV, STRING) \
  ::xla::LogLines(SEV, STRING, __FILE__, __LINE__)

// As XLA_LOG_LINES at INFO severity, but a no-op unless VLOG_IS_ON(LEVEL).
// NOTE: the stray semicolon previously trailing `while (false)` has been
// removed so the macro expands to a single statement (the do/while(0) idiom);
// with it, `if (c) XLA_VLOG_LINES(1, s); else ...` failed to compile. Call
// sites supply their own terminating semicolon.
#define XLA_VLOG_LINES(LEVEL, STRING)                                 \
  do {                                                                \
    if (VLOG_IS_ON(LEVEL)) XLA_LOG_LINES(::tensorflow::INFO, STRING); \
  } while (false)
684 
// Utility macro that performs the equivalent of what one would expect
// LOG_LINES(FATAL, X) to do but can be used at the end of a function that
// returns a value without getting a compiler warning that no value is returned.
// Logs X line-by-line at ERROR severity, then aborts via LOG(FATAL).
#define XLA_FATAL_LOG(X)                 \
  XLA_LOG_LINES(::tensorflow::ERROR, X); \
  LOG(FATAL) << "Aborting in " << __FUNCTION__ << " due to previous errors.";
691 
692 #endif  // TENSORFLOW_COMPILER_XLA_UTIL_H_
693