/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Generally useful utility functions that are common to (not specific to any
// given part of) the XLA code base.

#ifndef TENSORFLOW_COMPILER_XLA_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_UTIL_H_
#include <algorithm>
#include <array>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

31 #include "absl/algorithm/container.h"
32 #include "absl/base/thread_annotations.h"
33 #include "absl/container/inlined_vector.h"
34 #include "absl/numeric/bits.h"
35 #include "absl/strings/str_cat.h"
36 #include "absl/strings/str_format.h"
37 #include "absl/strings/string_view.h"
38 #include "absl/types/span.h"
39 #include "tensorflow/compiler/xla/status.h"
40 #include "tensorflow/compiler/xla/status_macros.h"
41 #include "tensorflow/compiler/xla/xla_data.pb.h"
42 #include "tensorflow/core/lib/core/errors.h" // IWYU pragma: keep
43 #include "tensorflow/core/lib/math/math_util.h"

namespace xla {

// Converts the unsigned integer n into a mixed-radix representation with the
// given bounds (radices). More precisely, if there are K radices, then the
// returned vector digits has K entries and satisfies
//
//   0 <= digits[i] < bounds[i], for i = 0, ..., K - 1
//
// and FromMixedRadix(digits) == n. The mixed radix representation is unique
// modulo the product of the entries of bounds.
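//
// For example (a sketch, not taken from the implementation; it assumes
// digits[0] is the most significant digit):
//
//   ToMixedRadix(5, /*bounds=*/{2, 3}) == {1, 2}  // since 1 * 3 + 2 == 5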
std::vector<int64_t> ToMixedRadix(int64_t n, absl::Span<const int64_t> bounds);

// Logs the provided status message with a backtrace.
//
// For use by Status-factories, logs a backtrace at the point where the status
// is created, such that we can use --vmodule=util=1 to see all status
// creation backtraces.
Status WithLogBacktrace(const Status& status);

// Ranks greater than 6 are very rare, so use InlinedVector<int64_t, 6> to store
// the bounds and indices. And for the rare cases of ranks greater than 6,
// the InlinedVector will just behave like an std::vector<> and allocate the
// memory to store its values.
inline constexpr int InlineRank() { return 6; }
using DimensionVector = absl::InlinedVector<int64_t, InlineRank()>;
using DimLevelTypeVector = absl::InlinedVector<DimLevelType, InlineRank()>;

// RAII timer that logs with a given label the wall clock time duration in human
// readable form. This differs from base's ElapsedTimer primarily in that it
// spits out the human-readable duration form.
//
// Keeps track of global maximum and cumulative times across all invocations.
//
// By default, the timing traces are only printed at VLOG(1) and above:
//
//   XLA_SCOPED_LOGGING_TIMER("fooing bar");  // nop if !VLOG_IS_ON(1).
//
// but you can control this via:
//
//   XLA_SCOPED_LOGGING_TIMER_LEVEL("fooing bar", 2);  // nop if !VLOG_IS_ON(2)
//
#define XLA_SCOPED_LOGGING_TIMER(label) \
  XLA_SCOPED_LOGGING_TIMER_HELPER(label, 1, __COUNTER__)
#define XLA_SCOPED_LOGGING_TIMER_LEVEL(label, level) \
  XLA_SCOPED_LOGGING_TIMER_HELPER(label, level, __COUNTER__)

// Helper for implementing macros above. Do not use directly.
//
// Forces the evaluation of "counter", which we expect is equal to __COUNTER__.
#define XLA_SCOPED_LOGGING_TIMER_HELPER(label, level, counter) \
  XLA_SCOPED_LOGGING_TIMER_HELPER2(label, level, counter)

// Helper for macros above. Don't use directly.
#define XLA_SCOPED_LOGGING_TIMER_HELPER2(label, level, counter)      \
  static ::xla::TimerStats XLA_TimerStats##counter;                  \
  ::xla::ScopedLoggingTimer XLA_ScopedLoggingTimerInstance##counter( \
      label, /*enabled=*/VLOG_IS_ON(level), __FILE__, __LINE__,      \
      &XLA_TimerStats##counter);

struct TimerStats {
  absl::Mutex stats_mutex;
  double cumulative_secs ABSL_GUARDED_BY(stats_mutex) = 0;
  double max_secs ABSL_GUARDED_BY(stats_mutex) = 0;
  uint64_t times_called ABSL_GUARDED_BY(stats_mutex) = 0;
};

// RAII timer for XLA_SCOPED_LOGGING_TIMER and XLA_SCOPED_LOGGING_TIMER_LEVEL
// macros above. Recommended usage is via the macros so you don't have to give
// the timer a name or worry about calling VLOG_IS_ON yourself.
class ScopedLoggingTimer {
 public:
  // label: Label to display for logging.
  // enabled: Whether this timer should do anything at all.
  // file: Filename to display in logging.
  // line: Line number to display in logging.
  // `timer_stats`: unowned non-null pointer which is used to populate the
  //   global timer statistics.
  ScopedLoggingTimer(absl::string_view label, bool enabled, const char* file,
                     int line, TimerStats* timer_stats);

  // Stop the timer and log the tracked time. Timer is disabled after this
  // function is called.
  void StopAndLog();

  ~ScopedLoggingTimer();

 private:
  const std::string label_;
  const char* const file_;
  const int line_;
  TimerStats* const timer_stats_;
  uint64_t start_micros_;
  bool enabled_;
};

// Given a vector<T>, returns a Span<uint8_t> that points at its
// internals.
//
// Warning: if the vector is updated its storage pointer may change, so use this
// with caution (ideally in limited scopes with temporary lifetimes).
template <typename T>
absl::Span<uint8_t> MutableByteSlice(std::vector<T>* v) {
  return absl::Span<uint8_t>(reinterpret_cast<uint8_t*>(v->data()),
                             v->size() * sizeof(T));
}

// Turns an immutable slice of type T into an immutable slice of bytes with the
// same byte size.
template <typename T>
absl::Span<const uint8_t> CastToByteSlice(absl::Span<const T> slice) {
  return absl::Span<const uint8_t>(
      reinterpret_cast<const uint8_t*>(slice.data()), slice.size() * sizeof(T));
}

// Casts a byte slice to a non-byte type T, checking that the original slice
// length is a multiple of sizeof(T).
template <typename T>
absl::Span<const T> CastByteSlice(absl::Span<const uint8_t> slice) {
  CHECK_EQ(0, slice.size() % sizeof(T));
  return absl::Span<const T>(reinterpret_cast<const T*>(slice.data()),
                             slice.size() / sizeof(T));
}

// Compares two containers for equality. Returns true iff the two containers
// have the same size and all their elements compare equal using their
// operator==. Like std::equal, but forces size equality.
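//
// For example (a hypothetical use, not from the implementation):
//
//   std::vector<int> v = {1, 2, 3};
//   ContainersEqual(v, {1, 2, 3});  // true
//   ContainersEqual(v, {1, 2});     // false: sizes differ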
template <typename Container1T,
          typename ElementType = typename Container1T::value_type>
bool ContainersEqual(const Container1T& c1,
                     std::initializer_list<ElementType> il) {
  absl::Span<const ElementType> c2{il};
  return absl::c_equal(c1, c2);
}

#if defined(__cpp_lib_to_underlying) && __cpp_lib_to_underlying >= 202102L
// std::to_underlying is a function template, so import it with a
// using-declaration (a type alias would not compile).
using std::to_underlying;
#else
// Helper function which implements C++23's std::to_underlying.
template <typename T>
constexpr std::underlying_type_t<T> to_underlying(T value) noexcept {
  return static_cast<std::underlying_type_t<T>>(value);
}
#endif

// Performs a copy of count values from src to dest, using different strides for
// source and destination. The source starting index is src_base, while the
// destination one is dest_base.
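//
// For example (a hypothetical call, not from the implementation), gathering
// every other element of src into the front of dest:
//
//   std::vector<int> src = {10, 20, 30, 40, 50};
//   std::vector<int> dest(4, 0);
//   // Copies src[0], src[2], src[4] into dest[0..2]: dest == {10, 30, 50, 0}.
//   StridedCopy<int, int>(absl::MakeSpan(dest), /*dest_base=*/0,
//                         /*dest_stride=*/1, src, /*src_base=*/0,
//                         /*src_stride=*/2, /*count=*/3);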
template <typename D, typename S>
void StridedCopy(absl::Span<D> dest, int64_t dest_base, int64_t dest_stride,
                 absl::Span<const S> src, int64_t src_base, int64_t src_stride,
                 int64_t count) {
  for (; count > 0; --count, dest_base += dest_stride, src_base += src_stride) {
    dest[dest_base] = static_cast<D>(src[src_base]);
  }
}

// Adds some context information to the error message in a Status. This is
// useful as Statuses are propagated upwards.
Status AddStatus(Status prior, absl::string_view context);
Status AppendStatus(Status prior, absl::string_view context);

// Status error shorthands -- StrFormat's the arguments to be used as an error
// message and returns a status in the canonical error space.
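//
// For example (a hypothetical use, not from the implementation):
//
//   if (dimension < 0) {
//     return InvalidArgument("dimension must be non-negative, got %d",
//                            dimension);
//   }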
template <typename... Args>
Status InvalidArgument(const absl::FormatSpec<Args...>& format,
                       const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::InvalidArgument(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status Unimplemented(const absl::FormatSpec<Args...>& format,
                     const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Unimplemented(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status InternalError(const absl::FormatSpec<Args...>& format,
                     const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Internal(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status FailedPrecondition(const absl::FormatSpec<Args...>& format,
                          const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::FailedPrecondition(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status Cancelled(const absl::FormatSpec<Args...>& format, const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Cancelled(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status ResourceExhausted(const absl::FormatSpec<Args...>& format,
                         const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::ResourceExhausted(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status NotFound(const absl::FormatSpec<Args...>& format, const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::NotFound(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status Unavailable(const absl::FormatSpec<Args...>& format,
                   const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Unavailable(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status Unknown(const absl::FormatSpec<Args...>& format, const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Unknown(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status Internal(const absl::FormatSpec<Args...>& format, const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Internal(absl::StrFormat(format, args...)));
}

template <typename... Args>
Status InvalidArgumentStrCat(Args&&... concat) {
  return InvalidArgument("%s", absl::StrCat(std::forward<Args>(concat)...));
}

template <typename... Args>
Status UnimplementedStrCat(Args&&... concat) {
  return Unimplemented("%s", absl::StrCat(std::forward<Args>(concat)...));
}

template <typename... Args>
Status InternalErrorStrCat(Args&&... concat) {
  return InternalError("%s", absl::StrCat(std::forward<Args>(concat)...));
}

template <typename... Args>
Status ResourceExhaustedStrCat(Args&&... concat) {
  return ResourceExhausted("%s", absl::StrCat(std::forward<Args>(concat)...));
}

// Splits the lines of the original, replaces leading whitespace with the prefix
// given by "indentation", and returns the string joined by newlines again. As a
// side effect, any additional trailing whitespace is removed.
//
// Note: even different amounts of leading whitespace on different lines will be
// uniformly replaced with "indentation".
std::string Reindent(absl::string_view original, absl::string_view indentation);

template <typename Container>
int64_t PositionInContainer(const Container& container, int64_t value) {
  return std::distance(container.begin(), absl::c_find(container, value));
}

// Formats the container as a comma-separated string. StrAppend must support
// appending the elements of the container. Prefix is prepended and suffix is
// appended to the returned string.
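//
// For example (a hypothetical call, not part of the contract above):
//
//   CommaSeparatedString(std::vector<int>{1, 2, 3}, "[", "]") == "[1, 2, 3]"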
template <typename Container>
std::string CommaSeparatedString(const Container& c, const char* prefix = "",
                                 const char* suffix = "") {
  // Not using Join() since the implementation here is simple anyway and this
  // avoids copying the string to append prefix.
  std::string comma_separated = prefix;
  const char* separator = "";
  for (const auto& entry : c) {
    absl::StrAppend(&comma_separated, separator, entry);
    separator = ", ";
  }
  comma_separated += suffix;
  return comma_separated;
}

// Overload needed to allow the container to be an initializer list. The default
// type for T makes an empty initializer list work as well.
template <typename T = int>
std::string CommaSeparatedString(const std::initializer_list<T>& c,
                                 const char* prefix = "",
                                 const char* suffix = "") {
  return CommaSeparatedString<std::initializer_list<T>>(c, prefix, suffix);
}

// Formats the container in the mathematical notation for a vector, e.g. (1, 3,
// 7). StrAppend must support appending the elements of c.
template <typename Container>
std::string VectorString(const Container& c) {
  return CommaSeparatedString(c, "(", ")");
}

// Overload needed to allow the container to be an initializer list. The default
// type for T makes an empty initializer list work as well.
template <typename T = int>
std::string VectorString(const std::initializer_list<T>& c) {
  return VectorString<std::initializer_list<T>>(c);
}

// Returns a string which can losslessly round trip to a bfloat.
std::string RoundTripFpToString(tensorflow::bfloat16 value);

// Returns a string which can losslessly round trip to a fp16.
std::string RoundTripFpToString(Eigen::half value);

// Returns a string which can losslessly round trip to a float.
std::string RoundTripFpToString(float value);

// Returns a string which can losslessly round trip to a double.
std::string RoundTripFpToString(double value);

// Returns a PaddingConfig object that represents no padding for the given rank.
PaddingConfig MakeNoPaddingConfig(int64_t rank);

// Returns a PaddingConfig object where 'padding' contains
// (low edge padding, high edge padding) pairs for each dimension.
PaddingConfig MakeEdgePaddingConfig(
    absl::Span<const std::pair<int64_t, int64_t>> padding);

// Returns true if the padding configuration has at least one dimension with
// non-zero interior padding.
bool HasInteriorPadding(const PaddingConfig& config);

// Imports the templated FloorOfRatio math function from the TensorFlow
// namespace, as it is very commonly used.
template <typename T>
T FloorOfRatio(T dividend, T divisor) {
  return tensorflow::MathUtil::FloorOfRatio<T>(dividend, divisor);
}

// Imports the templated CeilOfRatio math function from the TensorFlow
// namespace, as it is very commonly used.
template <typename T>
T CeilOfRatio(T dividend, T divisor) {
  return tensorflow::MathUtil::CeilOfRatio<T>(dividend, divisor);
}

// Rounds the value up to a multiple of the divisor by first calling CeilOfRatio
// then multiplying by the divisor. For example: RoundUpTo(13, 8) => 16
template <typename T>
T RoundUpTo(T value, T divisor) {
  return CeilOfRatio(value, divisor) * divisor;
}

// Rounds the value down to a multiple of the divisor by first calling
// FloorOfRatio then multiplying by the divisor. For example:
// RoundDownTo(13, 8) => 8
template <typename T>
T RoundDownTo(T value, T divisor) {
  return FloorOfRatio(value, divisor) * divisor;
}

template <typename T>
struct DivMod {
  T quotient;
  T modulo;
};

// Divide `dividend` by `divisor` such that the quotient is rounded towards
// negative infinity. The remainder will have the same sign as `divisor`.
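//
// For example (values worked out by hand from the contract above):
//
//   FloorDivMod(7, 3)  == {/*quotient=*/2, /*modulo=*/1}   // 3 * 2 + 1 == 7
//   FloorDivMod(-7, 3) == {/*quotient=*/-3, /*modulo=*/2}  // 3 * -3 + 2 == -7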
template <typename T>
DivMod<T> FloorDivMod(T dividend, T divisor) {
  DivMod<T> div_mod;
  div_mod.quotient = FloorOfRatio(dividend, divisor);
  div_mod.modulo = dividend - div_mod.quotient * divisor;
  return div_mod;
}

// Given a number of flops executed in an amount of time, produces a string that
// represents the throughput;
// e.g. HumanReadableNumFlops(1e9, 1e9) => 1.00GFLOP/s.
std::string HumanReadableNumFlops(double flops, double nanoseconds);

// Given a number of transcendental ops executed in an amount of time, produces
// a string that represents the throughput;
// e.g. HumanReadableNumTranscendentalOps(1e9, 1e9) => 1.00GTROP/s.
std::string HumanReadableNumTranscendentalOps(double trops, double nanoseconds);

// Split the text into multiple lines and log each line with the given
// severity, filename, and line number.
void LogLines(int sev, absl::string_view text, const char* fname, int lineno);

// Returns a mask with "width" number of least significant bits set.
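//
// For example: LsbMask<uint8_t>(3) == 0b00000111 and LsbMask<uint8_t>(0) == 0.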
template <typename T>
constexpr inline T LsbMask(int width) {
  static_assert(std::is_unsigned<T>::value,
                "T should be an unsigned integer type");
  ABSL_ASSERT(width >= 0);
  ABSL_ASSERT(width <= std::numeric_limits<T>::digits);
  return width == 0
             ? 0
             : static_cast<T>(-1) >> (std::numeric_limits<T>::digits - width);
}

// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0.
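// For example: Log2Floor(7u) == 2 and Log2Floor(8u) == 3.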
template <typename T>
constexpr inline int Log2Floor(T x) {
  static_assert(std::is_unsigned<T>::value,
                "T should be an unsigned integer type");
  return absl::bit_width(x) - 1;
}

// Return ceiling(log2(n)) for positive integer n. Returns -1 iff n == 0.
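// For example: Log2Ceiling(7u) == 3 and Log2Ceiling(8u) == 3.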
template <typename T>
constexpr inline int Log2Ceiling(T x) {
  static_assert(std::is_unsigned<T>::value,
                "T should be an unsigned integer type");
  return x == 0 ? -1 : absl::bit_width(x - 1);
}

// Return the number of sign bits (i.e. the number of leading ones for negative
// numbers and the number of leading zeros for non-negative numbers).
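//
// For example, for int8_t (bit patterns written out by hand):
//
//   CountLeadingSignBits(int8_t{1})  == 7  // 0b00000001
//   CountLeadingSignBits(int8_t{-1}) == 8  // 0b11111111
//   CountLeadingSignBits(int8_t{-4}) == 6  // 0b11111100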
template <typename T>
constexpr inline int CountLeadingSignBits(T x) {
  static_assert(std::is_signed<T>::value, "T should be a signed integer type");
  using UnsignedType = std::make_unsigned_t<T>;
  return x < T{0} ? absl::countl_one<UnsignedType>(x)
                  : absl::countl_zero<UnsignedType>(x);
}

// Returns `value` with only its low `width` bits kept; the remaining (higher)
// bits are cleared to zero.
template <typename T>
constexpr inline T KeepLowerBits(T value, int width) {
  return value & LsbMask<T>(width);
}

// Returns `base` multiplied by itself `exponent` number of times.
//
// Note: returns 1 when `exponent` is zero.
// Precondition: `exponent` is non-negative.
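//
// For example: IPow(2, 10) == 1024 and IPow(3, 0) == 1.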
template <typename T>
constexpr T IPow(T base, int exponent) {
  // A negative `exponent` is indicative of a logic bug for integral `base`.
  // We disallow it for floating-point types for symmetry.
  ABSL_ASSERT(exponent >= 0);
  // We use the right-to-left binary exponentiation algorithm.
  T result{1};
  while (exponent > 0) {
    if ((exponent & 1) != 0) {
      result *= base;
    }
    base *= base;
    exponent >>= 1;
  }
  return result;
}

template <size_t>
struct UnsignedIntegerTypeForSize;

template <>
struct UnsignedIntegerTypeForSize<1> {
  using type = uint8_t;
};

template <>
struct UnsignedIntegerTypeForSize<2> {
  using type = uint16_t;
};

template <>
struct UnsignedIntegerTypeForSize<4> {
  using type = uint32_t;
};

template <>
struct UnsignedIntegerTypeForSize<8> {
  using type = uint64_t;
};

template <size_t N>
struct SignedIntegerTypeForSize {
  using type = std::make_signed_t<typename UnsignedIntegerTypeForSize<N>::type>;
};

// Returns the signed magnitude of T.
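//
// In other words (an informal reading of the code below, not an authoritative
// contract): the bit pattern of `input` is reinterpreted as a signed integer,
// and for negative inputs all bits except the sign bit are flipped. For IEEE
// floats this maps the sign-magnitude bit encoding onto integers whose usual
// ordering matches the floats' ordering, which is useful for comparisons.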
template <typename T>
typename SignedIntegerTypeForSize<sizeof(T)>::type ToSignMagnitude(T input) {
  auto as_bits =
      absl::bit_cast<typename SignedIntegerTypeForSize<sizeof(T)>::type>(input);
  auto sign_mask =
      absl::bit_cast<typename UnsignedIntegerTypeForSize<sizeof(T)>::type>(
          tensorflow::MathUtil::Sign(as_bits));
  return as_bits ^ (sign_mask >> 1);
}

template <typename T>
constexpr int NanPayloadBits() {
  // Floating point types with NaNs have payloads.
  if (!std::numeric_limits<T>::has_quiet_NaN) {
    return 0;
  }
  return std::numeric_limits<T>::digits - 1;
}

template <typename T>
constexpr uint64_t QuietNanWithoutPayload() {
  if (const int bits = NanPayloadBits<T>()) {
    return uint64_t{1} << (bits - 1);
  }
  return 0;
}
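
// For example (worked out by hand from the two definitions above): for float,
// std::numeric_limits<float>::digits == 24, so NanPayloadBits<float>() == 23
// (the full mantissa field, including the quiet bit) and
// QuietNanWithoutPayload<float>() == uint64_t{1} << 22 == 0x400000.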

template <typename T>
constexpr uint64_t NanPayloadBitMask() {
  if (const int bits = NanPayloadBits<T>()) {
    return LsbMask<uint64_t>(bits);
  }
  return 0;
}

template <typename T>
T NanWithSignAndPayload(bool sign, uint64_t nan_payload) {
  using RepT = typename UnsignedIntegerTypeForSize<sizeof(T)>::type;
  const T val = std::numeric_limits<T>::quiet_NaN();
  auto rep = absl::bit_cast<RepT>(val);
  rep &= LsbMask<RepT>(std::numeric_limits<RepT>::digits - 1);
  rep |= uint64_t{sign} << (std::numeric_limits<RepT>::digits - 1);
  constexpr int kPayloadBits = NanPayloadBits<T>();
  if (kPayloadBits > 0) {
    // Clear rep's NaN payload.
    rep &= ~NanPayloadBitMask<T>();
    CHECK_NE(nan_payload, 0);
    rep |= nan_payload;
  }
  return absl::bit_cast<T>(rep);
}

// Utility for performing a static_cast<> on a std::unique_ptr<>.
template <typename Derived, typename Base>
std::unique_ptr<Derived> unique_ptr_static_cast(std::unique_ptr<Base> ptr) {
  return std::unique_ptr<Derived>(static_cast<Derived*>(ptr.release()));
}

int64_t Product(absl::Span<const int64_t> xs);

// Returns the start indices of consecutive non-overlapping subsequences of `a`
// and `b` with the same product, i.e. `(i, j)` so
// • a = {a[0 = i_0], ..., a[i_1 - 1], a[i_1], ... , a[i_2 - 1], ...}
// • b = {b[0 = j_0], ..., b[j_1 - 1], b[j_1], ... , b[j_2 - 1], ...}
// • ∀ k . 0 <= k < CommonFactors(a, b).size - 1 =>
//         a[i_k] × a[i_k + 1] × ... × a[i_(k+1) - 1] =
//         b[j_k] × b[j_k + 1] × ... × b[j_(k+1) - 1]
// where `CommonFactors(a, b)[CommonFactors(a, b).size - 1] = (a.size, b.size)`
//
// If input and output are the same, returns {(0, 0), (1, 1), ..., (a.size,
// b.size)}; otherwise, if the given shapes have non-zero size, returns the
// bounds of the shortest possible such subsequences; else, returns `{(0, 0),
// (a.size, b.size)}`.
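//
// For example (worked out by hand, not from the implementation): with
// a = {2, 3, 4} and b = {6, 4}, we have 2 × 3 == 6 and 4 == 4, so
//
//   CommonFactors({2, 3, 4}, {6, 4}) == {(0, 0), (2, 1), (3, 2)}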
absl::InlinedVector<std::pair<int64_t, int64_t>, 8> CommonFactors(
    absl::Span<const int64_t> a, absl::Span<const int64_t> b);

struct ConvertedDimensionNumbers {
  DimensionVector transformed_from_dimensions;
  DimensionVector untransformed_from_dimensions;
  DimensionVector to_dimensions;
  DimensionVector split_from_dimensions;
  DimensionVector split_from_sizes;
  DimensionVector split_to_dimensions;
};

// Converts an unsorted list of dimensions from one shape's dimension sizes to
// another shape's dimension sizes.
ConvertedDimensionNumbers ConvertDimensionNumbers(
    absl::Span<const int64_t> from_dimensions,
    absl::Span<const int64_t> from_sizes, absl::Span<const int64_t> to_sizes);

// Removes illegal characters from filenames.
std::string SanitizeFileName(std::string file_name);

template <typename C, typename Value>
int64_t FindIndex(const C& c, Value&& value) {
  auto it = absl::c_find(c, std::forward<Value>(value));
  return std::distance(c.begin(), it);
}

template <typename C, typename Value>
void InsertAt(C* c, int64_t index, Value&& value) {
  c->insert(c->begin() + index, std::forward<Value>(value));
}

template <typename C>
void EraseAt(C* c, int64_t index) {
  c->erase(c->begin() + index);
}

template <typename T>
std::vector<T> SpanToVector(absl::Span<const T> slice) {
  return std::vector<T>(slice.begin(), slice.end());
}

template <typename T, size_t N>
std::vector<T> InlinedVectorToVector(
    const absl::InlinedVector<T, N>& inlined_vector) {
  return std::vector<T>(inlined_vector.begin(), inlined_vector.end());
}

// Returns true if `x` fits in 32-bits.
template <typename T>
bool IsInt32(T x) {
  // Following conversion rules: "the value is unchanged if it can be
  // represented in the destination type (and bit-field width); otherwise, the
  // value is implementation-defined."
  return static_cast<int32_t>(x) == x;
}

template <typename T>
Status EraseElementFromVector(std::vector<T>* container, const T& value) {
  // absl::c_find returns a const_iterator which does not seem to work on
  // gcc 4.8.4, and this breaks the ubuntu/xla_gpu build bot.
  auto it = std::find(container->begin(), container->end(), value);
  TF_RET_CHECK(it != container->end());
  container->erase(it);
  return OkStatus();
}

// Utility function which splits a double-precision float (F64) into a pair of
// single-precision floating point numbers. The most significant 49 bits (out of
// the total 53 available) in the mantissa of the F64 is represented as the
// unevaluated sum of two non-overlapping single-precision F32s; the 'high' part
// contains 24 bits in its mantissa, and the 'low' part contains 25 bits in its
// sign bit and its mantissa.
// Note: The resulting representation can still only represent 8-bit exponent
// range that is available in F32s (out of a total of 11 exponent bits in F64s).
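//
// A minimal sketch of the underlying idea (the classic Dekker-style split; an
// assumption about the implementation, not a quote of it):
//
//   float hi = static_cast<float>(x);       // nearest F32 to x
//   float lo = static_cast<float>(x - hi);  // residual, also rounded to F32
//   // x is approximately double{hi} + double{lo}.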
std::pair<float, float> SplitF64ToF32(double x);

class HloInstruction;

// A predicate over HLO instructions.
using HloPredicate = std::function<bool(const HloInstruction*)>;

using Vector3 = std::array<int64_t, 3>;

}  // namespace xla

#define XLA_LOG_LINES(SEV, STRING) \
  ::xla::LogLines(SEV, STRING, __FILE__, __LINE__)

#define XLA_VLOG_LINES(LEVEL, STRING)                                 \
  do {                                                                \
    if (VLOG_IS_ON(LEVEL)) XLA_LOG_LINES(::tensorflow::INFO, STRING); \
  } while (false);

// Utility macro that performs the equivalent of what one would expect
// LOG_LINES(FATAL, X) to do but can be used at the end of a function that
// returns a value without getting a compiler warning that no value is returned.
#define XLA_FATAL_LOG(X)                 \
  XLA_LOG_LINES(::tensorflow::ERROR, X); \
  LOG(FATAL) << "Aborting in " << __FUNCTION__ << " due to previous errors.";

#endif  // TENSORFLOW_COMPILER_XLA_UTIL_H_