/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Generally useful utility functions that are common to (not specific to any
// given part of) the XLA code base.

#ifndef TENSORFLOW_COMPILER_XLA_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_UTIL_H_

#include <algorithm>
#include <limits>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// Logs the provided status message with a backtrace.
//
// For use by Status factories: logs a backtrace at the point where the status
// is created, so that --vmodule=util=1 shows all status-creation backtraces.
Status WithLogBacktrace(const Status& status);

// Ranks greater than 8 are very rare, so use InlinedVector<int64, 8> to store
// the bounds and indices. For the rare cases of ranks greater than 8, the
// InlinedVector will just behave like an std::vector<> and allocate the
// memory to store its values.
static constexpr int kInlineRank = 8;
using DimensionVector = absl::InlinedVector<int64, kInlineRank>;

// RAII timer that logs, with a given label, the wall-clock time duration in
// human-readable form. This differs from base's ElapsedTimer primarily in that
// it spits out the human-readable duration form.
//
// Keeps track of global maximum and cumulative times across all invocations.
//
// By default, the timing traces are only printed at VLOG(1) and above:
//
//   XLA_SCOPED_LOGGING_TIMER("fooing bar");  // nop if !VLOG_IS_ON(1).
//
// but you can control this via:
//
//   XLA_SCOPED_LOGGING_TIMER_LEVEL("fooing bar", 2);  // nop if !VLOG_IS_ON(2)
//
#define XLA_SCOPED_LOGGING_TIMER(label) \
  XLA_SCOPED_LOGGING_TIMER_HELPER(label, 1, __COUNTER__)
#define XLA_SCOPED_LOGGING_TIMER_LEVEL(label, level) \
  XLA_SCOPED_LOGGING_TIMER_HELPER(label, level, __COUNTER__)

// Helper for implementing macros above. Do not use directly.
//
// Forces the evaluation of "counter", which we expect is equal to __COUNTER__.
#define XLA_SCOPED_LOGGING_TIMER_HELPER(label, level, counter) \
  XLA_SCOPED_LOGGING_TIMER_HELPER2(label, level, counter)

// Helper for macros above. Don't use directly.
#define XLA_SCOPED_LOGGING_TIMER_HELPER2(label, level, counter)      \
  static ::xla::TimerStats XLA_TimerStats##counter;                  \
  ::xla::ScopedLoggingTimer XLA_ScopedLoggingTimerInstance##counter( \
      label, /*enabled=*/VLOG_IS_ON(level), __FILE__, __LINE__,      \
      &XLA_TimerStats##counter);

struct TimerStats {
  tensorflow::mutex stats_mutex;
  double cumulative_secs ABSL_GUARDED_BY(stats_mutex) = 0;
  double max_secs ABSL_GUARDED_BY(stats_mutex) = 0;
  uint64 times_called ABSL_GUARDED_BY(stats_mutex) = 0;
};

// RAII timer for XLA_SCOPED_LOGGING_TIMER and XLA_SCOPED_LOGGING_TIMER_LEVEL
// macros above. Recommended usage is via the macros so you don't have to give
// the timer a name or worry about calling VLOG_IS_ON yourself.
class ScopedLoggingTimer {
 public:
  // label: Label to display for logging.
  // enabled: Whether this timer should do anything at all.
  // file: Filename to display in logging.
  // line: Line number to display in logging.
  // `timer_stats`: unowned non-null pointer which is used to populate the
  // global timer statistics.
  ScopedLoggingTimer(const std::string& label, bool enabled, const char* file,
                     int line, TimerStats* timer_stats);

  // Stops the timer and logs the tracked time. The timer is disabled after
  // this function is called.
  void StopAndLog();

  ~ScopedLoggingTimer();

 private:
  bool enabled_;
  const char* file_;
  int line_;
  string label_;
  uint64 start_micros_;
  TimerStats* timer_stats_;
};

// Given a vector<T>, returns a Span<uint8> that points at its internals.
//
// Warning: if the vector is updated its storage pointer may change, so use
// this with caution (ideally in limited scopes with temporary lifetimes).
template <typename T>
absl::Span<uint8> MutableByteSlice(std::vector<T>* v) {
  return absl::Span<uint8>(reinterpret_cast<uint8*>(v->data()),
                           v->size() * sizeof(T));
}

// Turns an immutable slice of type T into an immutable slice of bytes with the
// same byte size.
template <typename T>
absl::Span<const uint8> CastToByteSlice(absl::Span<const T> slice) {
  return absl::Span<const uint8>(reinterpret_cast<const uint8*>(slice.data()),
                                 slice.size() * sizeof(T));
}

// Casts a byte slice to a non-byte type T, checking that the original slice
// length is a multiple of sizeof(T).
template <typename T>
absl::Span<const T> CastByteSlice(absl::Span<const uint8> slice) {
  CHECK_EQ(0, slice.size() % sizeof(T));
  return absl::Span<const T>(reinterpret_cast<const T*>(slice.data()),
                             slice.size() / sizeof(T));
}
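
// Illustrative example (a sketch, not part of the API surface): round-tripping
// a vector's storage through a byte view. `v` is a hypothetical local.
//
//   std::vector<float> v = {1.0f, 2.0f};
//   absl::Span<const uint8> bytes =
//       CastToByteSlice(absl::Span<const float>(v));
//   absl::Span<const float> floats = CastByteSlice<float>(bytes);
//   // bytes.size() == v.size() * sizeof(float); floats.size() == v.size().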

// Convenience function to force a vector to convert to an immutable slice.
template <typename T>
absl::Span<const T> AsSlice(const std::vector<T>& v) {
  return absl::Span<const T>(v);
}

// Converts a mutable vector pointer into a Span of the same type.
template <typename T>
absl::Span<T> AsMutableSlice(std::vector<T>* v) {
  return absl::Span<T>(v->data(), v->size());
}

// xla::int64 is not the same type as tensorflow::protobuf_int64 in
// open-source. Wrapper function that gives an int64 array slice view of a
// repeated int64 protobuf field.
static inline absl::Span<const int64> AsInt64Slice(
    const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>& v) {
  absl::Span<const tensorflow::protobuf_int64> slice(v);
  return absl::Span<const int64>(reinterpret_cast<const int64*>(slice.data()),
                                 slice.size());
}

// TODO(b/29771030): This nop overload was added to simplify the migration of
// Shape from a proto to a C++ class. Remove after class has been migrated.
static inline absl::Span<const int64> AsInt64Slice(
    absl::Span<const int64> slice) {
  return slice;
}

// As above, but for uint64 types.
static inline absl::Span<const uint64> AsUInt64Slice(
    const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_uint64>& v) {
  absl::Span<const tensorflow::protobuf_uint64> slice(v);
  return absl::Span<const uint64>(
      reinterpret_cast<const uint64*>(slice.data()), slice.size());
}

// Compares two containers for equality. Returns true iff the two containers
// have the same size and all their elements compare equal using their
// operator==. Like std::equal, but forces size equality.
template <typename Container1T, typename Container2T>
bool ContainersEqual(const Container1T& c1, const Container2T& c2) {
  return ((c1.size() == c2.size()) &&
          std::equal(std::begin(c1), std::end(c1), std::begin(c2)));
}

template <typename Container1T,
          typename ElementType = typename Container1T::value_type>
bool ContainersEqual(const Container1T& c1,
                     std::initializer_list<ElementType> il) {
  absl::Span<const ElementType> c2{il};
  return ContainersEqual(c1, c2);
}

// Compares two containers for equality. Returns true iff the two containers
// have the same size and all their elements compare equal using the predicate
// p. Like std::equal, but forces size equality.
template <typename Container1T, typename Container2T, class PredicateT>
bool ContainersEqual(const Container1T& c1, const Container2T& c2,
                     PredicateT p) {
  return ((c1.size() == c2.size()) &&
          std::equal(std::begin(c1), std::end(c1), std::begin(c2), p));
}
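
// Illustrative example (a sketch; values chosen for exposition):
//
//   std::vector<int64> a = {1, 2, 3};
//   absl::InlinedVector<int64, 4> b = {1, 2, 3};
//   ContainersEqual(a, b);        // true
//   ContainersEqual(a, {1, 2});   // false: sizes differ
//   ContainersEqual(a, b, [](int64 x, int64 y) { return x <= y; });  // true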

// Performs a copy of count values from src to dest, using different strides
// for source and destination. The source starting index is src_base, while the
// destination one is dest_base.
template <typename D, typename S>
void StridedCopy(absl::Span<D> dest, int64_t dest_base, int64_t dest_stride,
                 absl::Span<const S> src, int64_t src_base, int64_t src_stride,
                 int64_t count) {
  for (; count > 0; --count, dest_base += dest_stride, src_base += src_stride) {
    dest[dest_base] = static_cast<D>(src[src_base]);
  }
}
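
// Illustrative example (a sketch): gathering every other element of `src`
// into the front of `dest`.
//
//   std::vector<float> src = {0, 1, 2, 3, 4, 5};
//   std::vector<float> dest(3, 0);
//   StridedCopy(absl::MakeSpan(dest), /*dest_base=*/0, /*dest_stride=*/1,
//               AsSlice(src), /*src_base=*/0, /*src_stride=*/2, /*count=*/3);
//   // dest is now {0, 2, 4}.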

// Adds some context information to the error message in a Status. This is
// useful as Statuses are propagated upwards.
Status AddStatus(Status prior, absl::string_view context);
Status AppendStatus(Status prior, absl::string_view context);

// Status error shorthands -- StrFormat's the arguments to be used as an error
// message and returns a status in the canonical error space.
template <typename... Args>
Status InvalidArgument(const absl::FormatSpec<Args...>& format,
                       const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::InvalidArgument(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status Unimplemented(const absl::FormatSpec<Args...>& format,
                     const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Unimplemented(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status InternalError(const absl::FormatSpec<Args...>& format,
                     const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Internal(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status FailedPrecondition(const absl::FormatSpec<Args...>& format,
                          const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::FailedPrecondition(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status Cancelled(const absl::FormatSpec<Args...>& format, const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Cancelled(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status ResourceExhausted(const absl::FormatSpec<Args...>& format,
                         const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::ResourceExhausted(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status NotFound(const absl::FormatSpec<Args...>& format, const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::NotFound(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status Unavailable(const absl::FormatSpec<Args...>& format,
                   const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Unavailable(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status Unknown(const absl::FormatSpec<Args...>& format, const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Unknown(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status Internal(const absl::FormatSpec<Args...>& format, const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Internal(absl::StrFormat(format, args...)));
}

template <typename... Args>
Status InvalidArgumentStrCat(Args&&... concat) {
  return InvalidArgument("%s", absl::StrCat(std::forward<Args>(concat)...));
}

template <typename... Args>
Status UnimplementedStrCat(Args&&... concat) {
  return Unimplemented("%s", absl::StrCat(std::forward<Args>(concat)...));
}

template <typename... Args>
Status InternalErrorStrCat(Args&&... concat) {
  return InternalError("%s", absl::StrCat(std::forward<Args>(concat)...));
}

template <typename... Args>
Status ResourceExhaustedStrCat(Args&&... concat) {
  return ResourceExhausted("%s", absl::StrCat(std::forward<Args>(concat)...));
}
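
// Illustrative example (a sketch; `dim` and `op_name` are placeholders): the
// StrFormat-style shorthands type-check their format arguments, while the
// *StrCat variants simply concatenate theirs.
//
//   Status s1 = InvalidArgument("expected rank %d, got %d", 2, dim);
//   Status s2 = UnimplementedStrCat("op ", op_name, " is not supported");
//   // Both log a creation backtrace when --vmodule=util=1 is enabled.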

// Splits the lines of the original, replaces leading whitespace with the
// prefix given by "indentation", and returns the string joined by newlines
// again. As a side effect, any additional trailing whitespace is removed.
//
// Note: even different amounts of leading whitespace on different lines will
// be uniformly replaced with "indentation".
string Reindent(absl::string_view original, absl::string_view indentation);
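
// Illustrative example (a sketch of the behavior described above):
//
//   Reindent("  foo\n    bar", /*indentation=*/"> ");
//   // Returns "> foo\n> bar".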

template <typename Container>
int64 PositionInContainer(const Container& container, int64_t value) {
  return std::distance(container.begin(), absl::c_find(container, value));
}

// Formats the container as a comma-separated string. StrAppend must support
// appending the elements of the container. Prefix is prepended and suffix is
// appended to the returned string.
template <typename Container>
string CommaSeparatedString(const Container& c, const char* prefix = "",
                            const char* suffix = "") {
  // Not using Join() since the implementation here is simple anyway and this
  // avoids copying the string to append prefix.
  string comma_separated = prefix;
  const char* separator = "";
  for (const auto& entry : c) {
    absl::StrAppend(&comma_separated, separator, entry);
    separator = ", ";
  }
  comma_separated += suffix;
  return comma_separated;
}

// Overload needed to allow the container to be an initializer list. The
// default type for T makes an empty initializer list work as well.
template <typename T = int>
string CommaSeparatedString(const std::initializer_list<T>& c,
                            const char* prefix = "", const char* suffix = "") {
  return CommaSeparatedString<std::initializer_list<T>>(c, prefix, suffix);
}

// Formats the container in the mathematical notation for a vector, e.g. (1, 3,
// 7). StrAppend must support appending the elements of c.
template <typename Container>
string VectorString(const Container& c) {
  return CommaSeparatedString(c, "(", ")");
}

// Overload needed to allow the container to be an initializer list. The
// default type for T makes an empty initializer list work as well.
template <typename T = int>
string VectorString(const std::initializer_list<T>& c) {
  return VectorString<std::initializer_list<T>>(c);
}
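
// Illustrative example (a sketch):
//
//   std::vector<int64> dims = {1, 2, 3};
//   CommaSeparatedString(dims, "[", "]");  // "[1, 2, 3]"
//   VectorString(dims);                    // "(1, 2, 3)"
//   VectorString({});                      // "()"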

// Returns a string which can losslessly round trip to a bfloat.
string RoundTripFpToString(tensorflow::bfloat16 value);

// Returns a string which can losslessly round trip to a fp16.
string RoundTripFpToString(Eigen::half value);

// Returns a string which can losslessly round trip to a float.
string RoundTripFpToString(float value);

// Returns a string which can losslessly round trip to a double.
string RoundTripFpToString(double value);

// Returns a PaddingConfig object that represents no padding for the given
// rank.
PaddingConfig MakeNoPaddingConfig(int64_t rank);

// Returns a PaddingConfig object where 'padding' contains
// (low edge padding, high edge padding) pairs for each dimension.
PaddingConfig MakeEdgePaddingConfig(
    absl::Span<const std::pair<int64, int64>> padding);

// Returns true if the padding configuration has at least one dimension with
// non-zero interior padding.
bool HasInteriorPadding(const PaddingConfig& config);

// Imports the templated FloorOfRatio math function from the TensorFlow
// namespace, as it is very commonly used.
template <typename T>
T FloorOfRatio(T dividend, T divisor) {
  return tensorflow::MathUtil::FloorOfRatio<T>(dividend, divisor);
}

// Imports the templated CeilOfRatio math function from the TensorFlow
// namespace, as it is very commonly used.
template <typename T>
T CeilOfRatio(T dividend, T divisor) {
  return tensorflow::MathUtil::CeilOfRatio<T>(dividend, divisor);
}

// Rounds the value up to a multiple of the divisor by first calling
// CeilOfRatio then multiplying by the divisor. For example:
// RoundUpToNearest(13, 8) => 16
template <typename T>
T RoundUpToNearest(T value, T divisor) {
  return CeilOfRatio(value, divisor) * divisor;
}

// Rounds the value down to a multiple of the divisor by first calling
// FloorOfRatio then multiplying by the divisor. For example:
// RoundDownToNearest(13, 8) => 8
template <typename T>
T RoundDownToNearest(T value, T divisor) {
  return FloorOfRatio(value, divisor) * divisor;
}

// Given a number of flops executed in an amount of time, produces a string
// that represents the throughput;
// e.g. HumanReadableNumFlops(1e9, 1e9) => 1.00GFLOP/s.
string HumanReadableNumFlops(double flops, double nanoseconds);

// Given a number of transcendental ops executed in an amount of time, produces
// a string that represents the throughput;
// e.g. HumanReadableNumTranscendentalOps(1e9, 1e9) => 1.00GTROP/s.
string HumanReadableNumTranscendentalOps(double trops, double nanoseconds);

// Split the text into multiple lines and log each line with the given
// severity, filename, and line number.
void LogLines(int sev, absl::string_view text, const char* fname, int lineno);

template <typename T>
inline bool IsPowerOfTwo(T x) {
  static_assert(!std::numeric_limits<T>::is_signed, "unsigned types only");
  return x != 0 && (x & (x - 1)) == 0;
}

// Returns a mask with "width" number of least significant bits set.
template <typename T>
inline T LsbMask(int width) {
  static_assert(std::is_unsigned<T>::value,
                "T should be an unsigned integer type");
  CHECK_GE(width, 0) << "Unsupported width " << width;
  CHECK_LE(width, std::numeric_limits<T>::digits)
      << "Unsupported width " << width;
  return width == 0
             ? 0
             : static_cast<T>(-1) >> (std::numeric_limits<T>::digits - width);
}

// Returns the value with every bit except the lower 'width' bits set to zero.
template <typename T>
inline T ClearUpperBits(T value, int width) {
  return value & LsbMask<T>(width);
}
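
// Illustrative example (a sketch):
//
//   LsbMask<uint32>(4);                 // 0x0000000F
//   LsbMask<uint32>(32);                // 0xFFFFFFFF
//   ClearUpperBits<uint32>(0xABCD, 8);  // 0x000000CD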

template <size_t>
struct UnsignedIntegerTypeForSize;

template <>
struct UnsignedIntegerTypeForSize<1> {
  using type = uint8_t;
};

template <>
struct UnsignedIntegerTypeForSize<2> {
  using type = uint16_t;
};

template <>
struct UnsignedIntegerTypeForSize<4> {
  using type = uint32_t;
};

template <>
struct UnsignedIntegerTypeForSize<8> {
  using type = uint64_t;
};

template <typename T>
constexpr int NanPayloadBits() {
  // Floating point types with NaNs have payloads.
  if (!std::numeric_limits<T>::has_quiet_NaN) {
    return 0;
  }
  return std::numeric_limits<T>::digits - 1;
}

template <typename T>
constexpr uint64_t QuietNanWithoutPayload() {
  if (const int bits = NanPayloadBits<T>()) {
    return uint64_t{1} << (bits - 1);
  }
  return 0;
}

template <typename T>
constexpr uint64_t NanPayloadBitMask() {
  if (const int bits = NanPayloadBits<T>()) {
    return LsbMask<uint64_t>(bits);
  }
  return 0;
}

template <typename T>
T NanWithSignAndPayload(bool sign, uint64_t nan_payload) {
  using RepT = typename UnsignedIntegerTypeForSize<sizeof(T)>::type;
  const T val = std::numeric_limits<T>::quiet_NaN();
  auto rep = absl::bit_cast<RepT>(val);
  rep &= LsbMask<RepT>(std::numeric_limits<RepT>::digits - 1);
  rep |= uint64_t{sign} << (std::numeric_limits<RepT>::digits - 1);
  constexpr int kPayloadBits = NanPayloadBits<T>();
  if (kPayloadBits > 0) {
    // Clear rep's NaN payload.
    rep &= ~NanPayloadBitMask<T>();
    CHECK_NE(nan_payload, 0);
    rep |= nan_payload;
  }
  return absl::bit_cast<T>(rep);
}
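
// Illustrative example (a sketch): building a negative quiet NaN that carries
// an explicit payload.
//
//   float nan = NanWithSignAndPayload<float>(
//       /*sign=*/true, /*nan_payload=*/QuietNanWithoutPayload<float>());
//   // The payload must be non-zero and fit within NanPayloadBitMask<T>().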

// Utility for performing a static_cast<> on a std::unique_ptr<>.
template <typename Derived, typename Base>
std::unique_ptr<Derived> unique_ptr_static_cast(std::unique_ptr<Base> ptr) {
  return std::unique_ptr<Derived>(static_cast<Derived*>(ptr.release()));
}
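
// Illustrative example (a sketch; `Base` and `Derived` are placeholder types
// where `Derived` publicly inherits from `Base`):
//
//   std::unique_ptr<Base> base = std::make_unique<Derived>();
//   std::unique_ptr<Derived> derived =
//       unique_ptr_static_cast<Derived>(std::move(base));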

int64 Product(absl::Span<const int64> xs);

// Returns the start indices of consecutive non-overlapping subsequences of `a`
// and `b` with the same product, i.e. `(i, j)` so
// • a = {a[0 = i_0], ..., a[i_1 - 1], a[i_1], ... , a[i_2 - 1], ...}
// • b = {b[0 = j_0], ..., b[j_1 - 1], b[j_1], ... , b[j_2 - 1], ...}
// • ∀ k . 0 <= k < CommonFactors(a, b).size - 1 =>
//         a[i_k] × a[i_k + 1] × ... × a[i_(k+1) - 1] =
//         b[j_k] × b[j_k + 1] × ... × b[j_(k+1) - 1]
// where `CommonFactors(a, b)[CommonFactors(a, b).size - 1] = (a.size, b.size)`
//
// If input and output are the same, returns {(0, 0), (1, 1), ..., (a.size,
// b.size)}; otherwise, if the given shapes have non-zero size, returns the
// bounds of the shortest possible such subsequences; else, returns {(0, 0),
// (a.size, b.size)}.
absl::InlinedVector<std::pair<int64, int64>, 8> CommonFactors(
    absl::Span<const int64> a, absl::Span<const int64> b);
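
// Illustrative example (a sketch, following the contract above): for
// a = {2, 3, 4} and b = {6, 4}, the shortest such subsequences are
// 2 × 3 = 6 and 4 = 4, so CommonFactors(a, b) returns
// {(0, 0), (2, 1), (3, 2)}.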

struct ConvertedDimensionNumbers {
  DimensionVector transformed_from_dimensions;
  DimensionVector untransformed_from_dimensions;
  DimensionVector to_dimensions;
};

// Converts an unsorted list of dimensions from one shape's dimension sizes to
// another shape's dimension sizes.
ConvertedDimensionNumbers ConvertDimensionNumbers(
    absl::Span<const int64> from_dimensions, absl::Span<const int64> from_sizes,
    absl::Span<const int64> to_sizes);

// Removes illegal characters from filenames.
string SanitizeFileName(string file_name);

template <typename C, typename Value>
int64 FindIndex(const C& c, Value&& value) {
  auto it = absl::c_find(c, std::forward<Value>(value));
  return std::distance(c.begin(), it);
}

template <typename C, typename Value>
void InsertAt(C* c, int64_t index, Value&& value) {
  c->insert(c->begin() + index, std::forward<Value>(value));
}

template <typename C>
void EraseAt(C* c, int64_t index) {
  c->erase(c->begin() + index);
}

template <typename T>
std::vector<T> SpanToVector(absl::Span<const T> slice) {
  return std::vector<T>(slice.begin(), slice.end());
}

template <typename T, size_t N>
std::vector<T> InlinedVectorToVector(
    const absl::InlinedVector<T, N>& inlined_vector) {
  return std::vector<T>(inlined_vector.begin(), inlined_vector.end());
}

// Returns true if `x` fits in 32 bits.
template <typename T>
bool IsInt32(T x) {
  // Following conversion rules: "the value is unchanged if it can be
  // represented in the destination type (and bit-field width); otherwise, the
  // value is implementation-defined."
  return static_cast<int32>(x) == x;
}

template <typename T>
Status EraseElementFromVector(std::vector<T>* container, const T& value) {
  // absl::c_find returns a const_iterator which does not seem to work on
  // gcc 4.8.4, and this breaks the ubuntu/xla_gpu build bot.
  auto it = std::find(container->begin(), container->end(), value);
  TF_RET_CHECK(it != container->end());
  container->erase(it);
  return Status::OK();
}

// Utility function which splits a double-precision float (F64) into a pair of
// single-precision floating point numbers. The most significant 49 bits (out
// of the total 53 available) in the mantissa of the F64 is represented as the
// unevaluated sum of two non-overlapping single-precision F32s; the 'high'
// part contains 24 bits in its mantissa, and the 'low' part contains 25 bits
// in its sign bit and its mantissa.
// Note: The resulting representation can still only represent the 8-bit
// exponent range that is available in F32s (out of a total of 11 exponent
// bits in F64s).
std::pair<float, float> SplitF64ToF32(double x);
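
// Illustrative example (a sketch; `x` is a placeholder double): the two pieces
// form an unevaluated sum.
//
//   std::pair<float, float> parts = SplitF64ToF32(x);
//   // static_cast<double>(parts.first) + parts.second approximates x to the
//   // 49 mantissa bits described above (for x within the F32 exponent range).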

// MakeCleanup(f) returns an RAII cleanup object that calls 'f' in its
// destructor. The easiest way to use MakeCleanup is with a lambda argument,
// capturing the return value in an 'auto' local variable. Most users will not
// need more sophisticated syntax than that.
//
// Example:
//   void func() {
//     auto resource = acquire_resource();
//     auto cleanup = MakeCleanup([&] { release_resource(resource); });
//     TF_RETURN_IF_ERROR(...);  // phew, calls release_resource!
//   }
//
// You can use Cleanup<F> directly, instead of using MakeCleanup and auto,
// but there's rarely a reason to do that.
//
// You can call 'release()' on a Cleanup object to cancel the cleanup.
//
// You probably do not want to capture by reference in the cleanup lambda a
// variable that is returned by the function. This can lead to disabling of RVO
// at best, and undefined behavior at worst.
template <typename F>
class Cleanup {
 public:
  Cleanup() : released_(true), f_() {}

  template <typename G>
  explicit Cleanup(G&& f) : f_(std::forward<G>(f)) {}

  Cleanup(Cleanup&& src) : released_(src.is_released()), f_(src.release()) {}

  // Implicitly move-constructible from any compatible Cleanup<G>. The source
  // will be released as if src.release() were called. A moved-from Cleanup can
  // be safely destroyed or reassigned.
  template <typename G>
  Cleanup(Cleanup<G>&& src) : released_(src.is_released()), f_(src.release()) {}

  // Assignment to a Cleanup object behaves like destroying it and making a new
  // one in its place, analogous to unique_ptr semantics.
  Cleanup& operator=(Cleanup&& src) {
    if (!released_) std::move(f_)();
    released_ = src.released_;
    f_ = src.release();
    return *this;
  }

  ~Cleanup() {
    if (!released_) std::move(f_)();
  }

  // Releases the cleanup function instead of running it. Hint: use
  // c.release()() to run early.
  F release() {
    released_ = true;
    return std::move(f_);
  }

  bool is_released() const { return released_; }

 private:
  static_assert(!std::is_reference<F>::value, "F must not be a reference");

  bool released_ = false;
  F f_;
};

template <int&... ExplicitParameterBarrier, typename F,
          typename DecayF = typename std::decay<F>::type>
ABSL_MUST_USE_RESULT Cleanup<DecayF> MakeCleanup(F&& f) {
  return Cleanup<DecayF>(std::forward<F>(f));
}

}  // namespace xla

#define XLA_LOG_LINES(SEV, STRING) \
  ::xla::LogLines(SEV, STRING, __FILE__, __LINE__)

#define XLA_VLOG_LINES(LEVEL, STRING)                                 \
  do {                                                                \
    if (VLOG_IS_ON(LEVEL)) XLA_LOG_LINES(::tensorflow::INFO, STRING); \
  } while (false);

// Utility macro that performs the equivalent of what one would expect
// LOG_LINES(FATAL, X) to do but can be used at the end of a function that
// returns a value without getting a compiler warning that no value is
// returned.
#define XLA_FATAL_LOG(X)                 \
  XLA_LOG_LINES(::tensorflow::ERROR, X); \
  LOG(FATAL) << "Aborting in " << __FUNCTION__ << " due to previous errors.";

#endif  // TENSORFLOW_COMPILER_XLA_UTIL_H_