1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Generally useful utility functions that are common to (not specific to any
17 // given part of) the XLA code base.
18
19 #ifndef TENSORFLOW_COMPILER_XLA_UTIL_H_
20 #define TENSORFLOW_COMPILER_XLA_UTIL_H_
21
22 #include <algorithm>
23 #include <string>
24 #include <type_traits>
25 #include <vector>
26
27 #include "absl/algorithm/container.h"
28 #include "absl/base/thread_annotations.h"
29 #include "absl/container/inlined_vector.h"
30 #include "absl/strings/str_cat.h"
31 #include "absl/strings/str_format.h"
32 #include "absl/strings/str_join.h"
33 #include "absl/strings/string_view.h"
34 #include "absl/types/span.h"
35 #include "tensorflow/compiler/xla/status.h"
36 #include "tensorflow/compiler/xla/status_macros.h"
37 #include "tensorflow/compiler/xla/types.h"
38 #include "tensorflow/compiler/xla/xla_data.pb.h"
39 #include "tensorflow/core/lib/core/errors.h"
40 #include "tensorflow/core/lib/core/status.h"
41 #include "tensorflow/core/lib/math/math_util.h"
42 #include "tensorflow/core/lib/strings/numbers.h"
43 #include "tensorflow/core/platform/logging.h"
44 #include "tensorflow/core/platform/macros.h"
45 #include "tensorflow/core/platform/mutex.h"
46 #include "tensorflow/core/platform/protobuf.h"
47 #include "tensorflow/core/platform/types.h"
48
49 namespace xla {
50
51 // Logs the provided status message with a backtrace.
52 //
53 // For use by Status-factories, logs a backtrace at the point where the status
54 // is created, such that we can use --vmodule=util=1 to see all status
55 // creation backtraces.
56 Status WithLogBacktrace(const Status& status);
57
// Ranks greater than 8 are very rare, so use InlinedVector<int64, 8> to store
// the bounds and indices. And for the rare cases of ranks greater than 8,
// the InlinedVector will just behave like an std::vector<> and allocate the
// memory to store its values.
static constexpr int kInlineRank = 8;
// Vector of dimension sizes/indices, inline up to kInlineRank elements.
using DimensionVector = absl::InlinedVector<int64, kInlineRank>;
64
65 // RAII timer that logs with a given label the wall clock time duration in human
66 // readable form. This differs from base's ElapsedTimer primarily in that it
67 // spits out the human-readable duration form.
68 //
69 // Keeps track of global maximum and cumulative times across all invocations.
70 //
71 // By default, the timing traces are only printed at VLOG(1) and above:
72 //
73 // XLA_SCOPED_LOGGING_TIMER("fooing bar"); // nop if !VLOG_IS_ON(1).
74 //
75 // but you can control this via:
76 //
77 // XLA_SCOPED_LOGGING_TIMER_LEVEL("fooing bar", 2); // nop if !VLOG_IS_ON(2)
78 //
// Creates a block-scoped RAII timer; the __COUNTER__ suffix makes the
// generated variable names unique so several timers can share one scope.
#define XLA_SCOPED_LOGGING_TIMER(label) \
  XLA_SCOPED_LOGGING_TIMER_HELPER(label, 1, __COUNTER__)
#define XLA_SCOPED_LOGGING_TIMER_LEVEL(label, level) \
  XLA_SCOPED_LOGGING_TIMER_HELPER(label, level, __COUNTER__)

// Helper for implementing macros above. Do not use directly.
//
// Forces the evaluation of "counter", which we expect is equal to __COUNTER__.
#define XLA_SCOPED_LOGGING_TIMER_HELPER(label, level, counter) \
  XLA_SCOPED_LOGGING_TIMER_HELPER2(label, level, counter)

// Helper for macros above. Don't use directly.
// Declares a function-local static TimerStats (shared across invocations of
// the enclosing scope) and an RAII ScopedLoggingTimer that reports into it.
#define XLA_SCOPED_LOGGING_TIMER_HELPER2(label, level, counter)      \
  static ::xla::TimerStats XLA_TimerStats##counter;                  \
  ::xla::ScopedLoggingTimer XLA_ScopedLoggingTimerInstance##counter( \
      label, /*enabled=*/VLOG_IS_ON(level), __FILE__, __LINE__,      \
      &XLA_TimerStats##counter);
96
// Aggregate timing statistics for one XLA_SCOPED_LOGGING_TIMER call site,
// accumulated across every invocation. All counters are protected by
// stats_mutex.
struct TimerStats {
  tensorflow::mutex stats_mutex;
  // Total wall-clock seconds across all invocations.
  double cumulative_secs ABSL_GUARDED_BY(stats_mutex) = 0;
  // Longest single invocation, in seconds.
  double max_secs ABSL_GUARDED_BY(stats_mutex) = 0;
  // Number of completed invocations.
  uint64 times_called ABSL_GUARDED_BY(stats_mutex) = 0;
};
103
// RAII timer for XLA_SCOPED_LOGGING_TIMER and XLA_SCOPED_LOGGING_TIMER_LEVEL
// macros above. Recommended usage is via the macros so you don't have to give
// the timer a name or worry about calling VLOG_IS_ON yourself.
//
// Starts timing at construction and logs (and updates *timer_stats) either
// when StopAndLog() is called or at destruction, whichever happens first.
class ScopedLoggingTimer {
 public:
  // label: Label to display for logging.
  // enabled: Whether this timer should do anything at all.
  // file: Filename to display in logging.
  // line: Line number to display in logging.
  // `timer_stats`: unowned non-null pointer which is used to populate the
  // global timer statistics.
  ScopedLoggingTimer(const std::string& label, bool enabled, const char* file,
                     int line, TimerStats* timer_stats);

  // Stop the timer and log the tracked time. Timer is disabled after this
  // function is called.
  void StopAndLog();

  ~ScopedLoggingTimer();

 private:
  bool enabled_;      // If false, every operation is a no-op.
  const char* file_;  // Filename shown in the log line.
  int line_;          // Line number shown in the log line.
  string label_;
  uint64 start_micros_;      // Wall-clock start time, in microseconds.
  TimerStats* timer_stats_;  // Unowned; aggregates stats across invocations.
};
132
133 // Given a vector<T>, returns a Span<char> that points at its
134 // internals.
135 //
136 // Warning: if the vector is updated its storage pointer may change, so use this
137 // with caution (ideally in limited scopes with temporary lifetimes).
138 template <typename T>
MutableByteSlice(std::vector<T> * v)139 absl::Span<uint8> MutableByteSlice(std::vector<T>* v) {
140 return absl::Span<uint8>(reinterpret_cast<uint8*>(v->data()),
141 v->size() * sizeof(T));
142 }
143
144 // Turns an immutable slice of type T into an immutable slice of bytes with the
145 // same byte size.
146 template <typename T>
CastToByteSlice(absl::Span<const T> slice)147 absl::Span<const uint8> CastToByteSlice(absl::Span<const T> slice) {
148 return absl::Span<const uint8>(reinterpret_cast<const uint8*>(slice.data()),
149 slice.size() * sizeof(T));
150 }
151
152 // Casts a byte slice to a non-byte type T, checking that the original slice
153 // length is a multiple of sizeof(T).
154 template <typename T>
CastByteSlice(absl::Span<const uint8> slice)155 absl::Span<const T> CastByteSlice(absl::Span<const uint8> slice) {
156 CHECK_EQ(0, slice.size() % sizeof(T));
157 return absl::Span<const T>(reinterpret_cast<const T*>(slice.data()),
158 slice.size() / sizeof(T));
159 }
160
161 // Convenience function to force a vector to convert to an immutable slice.
162 template <typename T>
AsSlice(const std::vector<T> & v)163 absl::Span<const T> AsSlice(const std::vector<T>& v) {
164 return absl::Span<const T>(v);
165 }
166
167 // Converts a mutable vector pointer into a Span of the same
168 // type.
169 template <typename T>
AsMutableSlice(std::vector<T> * v)170 absl::Span<T> AsMutableSlice(std::vector<T>* v) {
171 return absl::Span<T>(v->data(), v->size());
172 }
173
// xla::int64 is not the same type as tensorflow::protobuf_int64 in open-source.
// Wrapper function that gives an int64 array slice view of a repeated int64
// protobuf field.
//
// NOTE(review): the reinterpret_cast assumes int64 and protobuf_int64 have
// identical size and representation — presumably both are 64-bit signed
// integers; confirm against the typedefs in xla/types.h.
static inline absl::Span<const int64> AsInt64Slice(
    const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>& v) {
  absl::Span<const tensorflow::protobuf_int64> slice(v);
  return absl::Span<const int64>(reinterpret_cast<const int64*>(slice.data()),
                                 slice.size());
}

// TODO(b/29771030): This nop overload was added to simplify the migration of
// Shape from a proto to a C++ class. Remove after class has been migrated.
static inline absl::Span<const int64> AsInt64Slice(
    absl::Span<const int64> slice) {
  return slice;
}

// As above, but for uint64 types.
static inline absl::Span<const uint64> AsUInt64Slice(
    const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_uint64>& v) {
  absl::Span<const tensorflow::protobuf_uint64> slice(v);
  return absl::Span<const uint64>(reinterpret_cast<const uint64*>(slice.data()),
                                 slice.size());
}
198
// Compares two containers for equality: true iff they have the same size and
// each pair of corresponding elements compares equal via operator==. Like
// std::equal, but additionally enforces size equality.
template <typename Container1T, typename Container2T>
bool ContainersEqual(const Container1T& c1, const Container2T& c2) {
  if (c1.size() != c2.size()) {
    return false;
  }
  return std::equal(std::begin(c1), std::end(c1), std::begin(c2));
}
207
208 template <typename Container1T,
209 typename ElementType = typename Container1T::value_type>
ContainersEqual(const Container1T & c1,std::initializer_list<ElementType> il)210 bool ContainersEqual(const Container1T& c1,
211 std::initializer_list<ElementType> il) {
212 absl::Span<const ElementType> c2{il};
213 return ContainersEqual(c1, c2);
214 }
215
// Compares two containers for equality under the predicate `p`: true iff they
// have the same size and every pair of corresponding elements satisfies p.
// Like std::equal, but additionally enforces size equality.
template <typename Container1T, typename Container2T, class PredicateT>
bool ContainersEqual(const Container1T& c1, const Container2T& c2,
                     PredicateT p) {
  if (c1.size() != c2.size()) {
    return false;
  }
  return std::equal(std::begin(c1), std::end(c1), std::begin(c2), p);
}
225
226 // Performs a copy of count values from src to dest, using different strides for
227 // source and destination. The source starting index is src_base, while the
228 // destination one is dest_base.
229 template <typename D, typename S>
StridedCopy(absl::Span<D> dest,int64 dest_base,int64 dest_stride,absl::Span<const S> src,int64 src_base,int64 src_stride,int64 count)230 void StridedCopy(absl::Span<D> dest, int64 dest_base, int64 dest_stride,
231 absl::Span<const S> src, int64 src_base, int64 src_stride,
232 int64 count) {
233 for (; count > 0; --count, dest_base += dest_stride, src_base += src_stride) {
234 dest[dest_base] = static_cast<D>(src[src_base]);
235 }
236 }
237
238 // Adds some context information to the error message in a
239 // Status. This is useful as Statuses are
240 // propagated upwards.
241 Status AddStatus(Status prior, absl::string_view context);
242 Status AppendStatus(Status prior, absl::string_view context);
243
// Status error shorthands -- StrFormat's the arguments to be used as an error
// message and returns a status in the canonical error space.
//
// Each factory also logs a backtrace of the creation site via
// WithLogBacktrace (visible with --vmodule=util=1), which is why these are
// preferred over calling tensorflow::errors::* directly.
template <typename... Args>
Status InvalidArgument(const absl::FormatSpec<Args...>& format,
                       const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::InvalidArgument(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status Unimplemented(const absl::FormatSpec<Args...>& format,
                     const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Unimplemented(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status InternalError(const absl::FormatSpec<Args...>& format,
                     const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Internal(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status FailedPrecondition(const absl::FormatSpec<Args...>& format,
                          const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::FailedPrecondition(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status Cancelled(const absl::FormatSpec<Args...>& format, const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Cancelled(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status ResourceExhausted(const absl::FormatSpec<Args...>& format,
                         const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::ResourceExhausted(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status NotFound(const absl::FormatSpec<Args...>& format, const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::NotFound(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status Unavailable(const absl::FormatSpec<Args...>& format,
                   const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Unavailable(absl::StrFormat(format, args...)));
}
template <typename... Args>
Status Unknown(const absl::FormatSpec<Args...>& format, const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Unknown(absl::StrFormat(format, args...)));
}
// NOTE: Internal and InternalError both map to tensorflow::errors::Internal.
template <typename... Args>
Status Internal(const absl::FormatSpec<Args...>& format, const Args&... args) {
  return WithLogBacktrace(
      tensorflow::errors::Internal(absl::StrFormat(format, args...)));
}
302
// StrCat-based variants of the status factories: the arguments are
// concatenated (no format string) to form the error message.
template <typename... Args>
Status InvalidArgumentStrCat(Args&&... concat) {
  return InvalidArgument("%s", absl::StrCat(std::forward<Args>(concat)...));
}

template <typename... Args>
Status UnimplementedStrCat(Args&&... concat) {
  return Unimplemented("%s", absl::StrCat(std::forward<Args>(concat)...));
}

template <typename... Args>
Status InternalErrorStrCat(Args&&... concat) {
  return InternalError("%s", absl::StrCat(std::forward<Args>(concat)...));
}

template <typename... Args>
Status ResourceExhaustedStrCat(Args&&... concat) {
  return ResourceExhausted("%s", absl::StrCat(std::forward<Args>(concat)...));
}
322
323 // Splits the lines of the original, replaces leading whitespace with the prefix
324 // given by "indentation", and returns the string joined by newlines again. As a
325 // side effect, any additional trailing whitespace is removed.
326 //
327 // Note: even different amounts of leading whitespace on different lines will be
328 // uniformly replaced with "indentation".
329 string Reindent(absl::string_view original, absl::string_view indentation);
330
331 template <typename Container>
PositionInContainer(const Container & container,int64 value)332 int64 PositionInContainer(const Container& container, int64 value) {
333 return std::distance(container.begin(), absl::c_find(container, value));
334 }
335
336 // Formats the container as a comma-separated string. StrAppend must support
337 // appending the elements of the container. Prefix is prepended and suffix is
338 // appended to the returned string.
339 template <typename Container>
340 string CommaSeparatedString(const Container& c, const char* prefix = "",
341 const char* suffix = "") {
342 // Not using Join() since the implementation here is simple anyway and this
343 // avoids copying the string to append prefix.
344 string comma_separated = prefix;
345 const char* separator = "";
346 for (const auto& entry : c) {
347 absl::StrAppend(&comma_separated, separator, entry);
348 separator = ", ";
349 }
350 comma_separated += suffix;
351 return comma_separated;
352 }
353
// Overload needed to allow the container to be an initializer list. The default
// type for T makes an empty initializer list work as well.
template <typename T = int>
string CommaSeparatedString(const std::initializer_list<T>& c,
                            const char* prefix = "", const char* suffix = "") {
  // Dispatch to the generic container overload above.
  return CommaSeparatedString<std::initializer_list<T>>(c, prefix, suffix);
}

// Formats the container in the mathematical notation for a vector, e.g. (1, 3,
// 7). StrAppend must support appending the elements of c.
template <typename Container>
string VectorString(const Container& c) {
  return CommaSeparatedString(c, "(", ")");
}

// Overload needed to allow the container to be an initializer list. The default
// type for T makes an empty initializer list work as well.
template <typename T = int>
string VectorString(const std::initializer_list<T>& c) {
  return VectorString<std::initializer_list<T>>(c);
}
375
376 // Returns a string which can losslessly round trip to a bfloat.
377 string RoundTripFpToString(tensorflow::bfloat16 value);
378
379 // Returns a string which can losslessly round trip to a fp16.
380 string RoundTripFpToString(Eigen::half value);
381
382 // Returns a string which can losslessly round trip to a float.
383 string RoundTripFpToString(float value);
384
385 // Returns a string which can losslessly round trip to a double.
386 string RoundTripFpToString(double value);
387
388 // Returns a PaddingConfig object that represents no padding for the given rank.
389 PaddingConfig MakeNoPaddingConfig(int64 rank);
390
391 // Returns a PaddingConfig object where 'padding' contains
392 // (low edge padding, high edge padding) pairs for each dimension.
393 PaddingConfig MakeEdgePaddingConfig(
394 absl::Span<const std::pair<int64, int64>> padding);
395
396 // Returns true if the padding configuration has at least one dimension with
397 // non-zero interior padding.
398 bool HasInteriorPadding(const PaddingConfig& config);
399
// Imports the templated FloorOfRatio math function from the TensorFlow
// namespace, as it is very commonly used.
template <typename T>
T FloorOfRatio(T dividend, T divisor) {
  return tensorflow::MathUtil::FloorOfRatio<T>(dividend, divisor);
}

// Imports the templated CeilOfRatio math function from the TensorFlow
// namespace, as it is very commonly used.
template <typename T>
T CeilOfRatio(T dividend, T divisor) {
  return tensorflow::MathUtil::CeilOfRatio<T>(dividend, divisor);
}

// Rounds the value up to a multiple of the divisor by first calling CeilOfRatio
// then multiplying by the divisor. For example: RoundUpToNearest(13, 8) => 16
template <typename T>
T RoundUpToNearest(T value, T divisor) {
  return CeilOfRatio(value, divisor) * divisor;
}

// Rounds the value down to a multiple of the divisor by first calling
// FloorOfRatio then multiplying by the divisor. For example:
// RoundDownToNearest(13, 8) => 8
template <typename T>
T RoundDownToNearest(T value, T divisor) {
  return FloorOfRatio(value, divisor) * divisor;
}
428
429 // Given a number of flops executed in an amount of time, produces a string that
430 // represents the throughput;
431 // e.g. HumanReadableNumFlops(1e9, 1e9) => 1.00GFLOP/s.
432 string HumanReadableNumFlops(double flops, double nanoseconds);
433
434 // Given a number of transcendental ops executed in an amount of time, produces
435 // a string that represents the throughput;
436 // e.g. HumanReadableNumTranscendentalOps(1e9, 1e9) => 1.00GTROP/s.
437 string HumanReadableNumTranscendentalOps(double trops, double nanoseconds);
438
439 // Split the text into multiple lines and log each line with the given
440 // severity, filename, and line number.
441 void LogLines(int sev, absl::string_view text, const char* fname, int lineno);
442
// Returns true iff `x` is an exact power of two. Zero is not a power of two.
// Restricted to unsigned types so the bit trick below is well defined.
template <typename T>
inline bool IsPowerOfTwo(T x) {
  static_assert(!std::numeric_limits<T>::is_signed, "unsigned types only");
  if (x == 0) {
    return false;
  }
  // A power of two has exactly one bit set, so clearing the lowest set bit
  // (x & (x - 1)) leaves zero.
  return (x & (x - 1)) == 0;
}
448
449 // Returns a mask with "bits" number of least significant bits set.
LsbMaskU32(int bits)450 inline uint32 LsbMaskU32(int bits) {
451 CHECK_GE(bits, 0);
452 return (1U << bits) - 1;
453 }
454
// Utility for performing a static_cast<> on a std::unique_ptr<>: takes
// ownership from `ptr` and returns a unique_ptr owning the same object viewed
// as Derived. The caller is responsible for the cast being valid.
template <typename Derived, typename Base>
std::unique_ptr<Derived> unique_ptr_static_cast(std::unique_ptr<Base> ptr) {
  Derived* const derived = static_cast<Derived*>(ptr.release());
  return std::unique_ptr<Derived>(derived);
}
460
461 int64 Product(absl::Span<const int64> xs);
462
463 // Returns the start indices of consecutive non-overlapping subsequences of `a`
464 // and `b` with the same product, i.e. `(i, j)` so
465 // • a = {a[0 = i_0], ..., a[i_1 - 1], a[i_1], ... , a[i_2 - 1], ...}
466 // • b = {b[0 = j_0], ..., b[j_1 - 1], b[j_1], ... , b[j_2 - 1], ...}
467 // • ∀ k . 0 <= k < CommonFactors(a, b).size - 1 =>
468 // a[i_k] × a[i_k + 1] × ... × a[i_(k+1) - 1] =
469 // b[j_k] × b[j_k + 1] × ... × b[j_(k+1) - 1]
470 // where `CommonFactors(a, b)[CommonFactors(a, b).size - 1] = (a.size, b.size)`
471 //
472 // If input and output are the same, return {(0, 0), {1, 1}, ... {a.size,
473 // b.size}}, otherwise if the given shapes have non-zero size, returns the
474 // bounds of the shortest possible such subsequences; else, returns `{(0, 0),
475 // (a.size, b.size)}`.
476 absl::InlinedVector<std::pair<int64, int64>, 8> CommonFactors(
477 absl::Span<const int64> a, absl::Span<const int64> b);
478
// Result of ConvertDimensionNumbers (declared below).
// NOTE(review): field semantics inferred from the names and the
// ConvertDimensionNumbers declaration — presumably the requested "from"
// dimensions are partitioned into those that map onto the "to" shape
// (transformed) and those that do not (untransformed); confirm against the
// definition in util.cc.
struct ConvertedDimensionNumbers {
  DimensionVector transformed_from_dimensions;
  DimensionVector untransformed_from_dimensions;
  DimensionVector to_dimensions;
};
484
485 // Convert and unsorted list of dimensions from one shapes dimension sizes to
486 // another shapes dimensions sizes.
487 ConvertedDimensionNumbers ConvertDimensionNumbers(
488 absl::Span<const int64> from_dimensions, absl::Span<const int64> from_sizes,
489 absl::Span<const int64> to_sizes);
490
491 // Removes illegal characters from filenames.
492 string SanitizeFileName(string file_name);
493
// Returns the index of the first element of `c` equal to `value`, or c.size()
// when no element matches.
template <typename C, typename Value>
int64 FindIndex(const C& c, Value&& value) {
  auto it = absl::c_find(c, std::forward<Value>(value));
  return std::distance(c.begin(), it);
}

// Inserts `value` into `c` at position `index` (0 <= index <= c.size()),
// shifting later elements right.
template <typename C, typename Value>
void InsertAt(C* c, int64 index, Value&& value) {
  c->insert(c->begin() + index, std::forward<Value>(value));
}

// Removes the element of `c` at position `index` (0 <= index < c.size()).
template <typename C>
void EraseAt(C* c, int64 index) {
  c->erase(c->begin() + index);
}

// Copies the elements of a span into a freshly allocated std::vector.
template <typename T>
std::vector<T> SpanToVector(absl::Span<const T> slice) {
  return std::vector<T>(slice.begin(), slice.end());
}

// Copies the elements of an InlinedVector into a std::vector.
template <typename T, size_t N>
std::vector<T> InlinedVectorToVector(
    const absl::InlinedVector<T, N>& inlined_vector) {
  return std::vector<T>(inlined_vector.begin(), inlined_vector.end());
}
520
// Returns true if `x` fits in 32-bits.
template <typename T>
bool IsInt32(T x) {
  // Following conversion rules: "the value is unchanged if it can be
  // represented in the destination type (and bit-field width); otherwise, the
  // value is implementation-defined."
  // So the round-trip through int32 compares equal exactly when x was
  // representable.
  return static_cast<int32>(x) == x;
}

// Removes the first element of `container` equal to `value`. Returns a
// RET_CHECK error Status if no such element exists; OK otherwise.
template <typename T>
Status EraseElementFromVector(std::vector<T>* container, const T& value) {
  // absl::c_find returns a const_iterator which does not seem to work on
  // gcc 4.8.4, and this breaks the ubuntu/xla_gpu build bot.
  auto it = std::find(container->begin(), container->end(), value);
  TF_RET_CHECK(it != container->end());
  container->erase(it);
  return Status::OK();
}
539
540 // Utility function which splits a double-precision float (F64) into a pair of
541 // single-precision floating point numbers. The most significant 49 bits (out of
542 // the total 53 available) in the mantissa of the F64 is represented as the
543 // unevaluated sum of two non-overlapping single-precision F32s; the 'high' part
544 // contains 24 bits in its mantissa, and the 'low' part contains 25 bits in its
545 // sign bit and its mantissa.
546 // Note: The resulting representation can still only represent 8-bit exponent
547 // range that is available in F32s (out of a total of 11 exponent bits in F64s).
548 std::pair<float, float> SplitF64ToF32(double x);
549
550 // MakeCleanup(f) returns an RAII cleanup object that calls 'f' in its
551 // destructor. The easiest way to use MakeCleanup is with a lambda argument,
552 // capturing the return value in an 'auto' local variable. Most users will not
553 // need more sophisticated syntax than that.
554 //
555 // Example:
556 // void func() {
557 // auto resource = acquire_resource();
558 // auto cleanup = MakeCleanup([&] { release_resource(resource); });
559 // TF_RETURN_IF_ERROR(...); // phew, calls release_resource!
560 // }
561 //
562 // You can use Cleanup<F> directly, instead of using MakeCleanup and auto,
563 // but there's rarely a reason to do that.
564 //
565 // You can call 'release()' on a Cleanup object to cancel the cleanup
566 //
567 // You probably do not want to capture by reference in the cleanup lambda a
568 // variable that is returned by the function. This can lead to disabling of RVO
569 // at best, and undefined behavior at worst.
// RAII holder for a cleanup callable F; runs the callable in the destructor
// unless release() has been called. See the comment block above MakeCleanup
// for usage. Move-only.
template <typename F>
class Cleanup {
 public:
  // A default-constructed Cleanup holds no callable and starts out released,
  // so its destructor is a no-op.
  Cleanup() : released_(true), f_() {}

  template <typename G>
  explicit Cleanup(G&& f) : f_(std::forward<G>(f)) {}

  // Move constructor: takes over the source's callable and released state;
  // src.release() marks the source released so it will not run the cleanup.
  Cleanup(Cleanup&& src) : released_(src.is_released()), f_(src.release()) {}

  // Implicitly move-constructible from any compatible Cleanup<G>. The source
  // will be released as if src.release() were called. A moved-from Cleanup can
  // be safely destroyed or reassigned.
  template <typename G>
  Cleanup(Cleanup<G>&& src) : released_(src.is_released()), f_(src.release()) {}

  // Assignment to a Cleanup object behaves like destroying it and making a new
  // one in its place, analogous to unique_ptr semantics: any pending cleanup
  // runs first, then this object adopts the source's state.
  Cleanup& operator=(Cleanup&& src) {
    if (!released_) std::move(f_)();
    released_ = src.released_;
    f_ = src.release();
    return *this;
  }

  ~Cleanup() {
    if (!released_) std::move(f_)();
  }

  // Releases the cleanup function instead of running it. Hint: use
  // c.release()() to run early.
  F release() {
    released_ = true;
    return std::move(f_);
  }

  // True when the cleanup will not run (released, moved-from, or
  // default-constructed).
  bool is_released() const { return released_; }

 private:
  static_assert(!std::is_reference<F>::value, "F must not be a reference");

  bool released_ = false;
  F f_;
};
614
// Factory for Cleanup: deduces and decays the callable type so lambdas and
// function references work. The int&... pack bars callers from explicitly
// specifying the template arguments.
template <int&... ExplicitParameterBarrier, typename F,
          typename DecayF = typename std::decay<F>::type>
ABSL_MUST_USE_RESULT Cleanup<DecayF> MakeCleanup(F&& f) {
  return Cleanup<DecayF>(std::forward<F>(f));
}
620
621 } // namespace xla
622
// Logs each line of STRING at the given severity, tagged with the call site.
#define XLA_LOG_LINES(SEV, STRING) \
  ::xla::LogLines(SEV, STRING, __FILE__, __LINE__)

// As above, but only when VLOG(LEVEL) is enabled. Wrapped in do/while so it
// behaves as a single statement; the trailing semicolon is intentionally
// omitted from the macro body so that callers' `XLA_VLOG_LINES(...);` does
// not expand to a double semicolon and unbraced if/else call sites compile.
#define XLA_VLOG_LINES(LEVEL, STRING)                                 \
  do {                                                                \
    if (VLOG_IS_ON(LEVEL)) XLA_LOG_LINES(::tensorflow::INFO, STRING); \
  } while (false)
630
// Utility macro that performs the equivalent of what one would expect
// LOG_LINES(FATAL, X) to do but can be used at the end of a function that
// returns a value without getting a compiler warning that no value is returned.
//
// NOTE(review): this deliberately expands to two statements with no do/while
// wrapper -- the bare LOG(FATAL) at the end is what lets the compiler see
// that control cannot continue. Use only as a full statement, never as the
// body of an unbraced if/else.
#define XLA_FATAL_LOG(X)                 \
  XLA_LOG_LINES(::tensorflow::ERROR, X); \
  LOG(FATAL) << "Aborting in " << __FUNCTION__ << " due to previous errors.";
637
638 #endif // TENSORFLOW_COMPILER_XLA_UTIL_H_
639