1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/literal_comparison.h"

#include <unistd.h>

#include <array>
#include <cmath>
#include <limits>
#include <set>
#include <vector>

#include "absl/base/casts.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/env.h"
29
30 using absl::StrAppend;
31 using absl::StrAppendFormat;
32 using absl::StrCat;
33
34 namespace xla {
35 namespace literal_comparison {
36 namespace {
37
38 // Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be
39 // able to transparently access the raw 16-bit value contained within.
40 template <typename T>
GetRawValue(T val)41 T GetRawValue(T val) {
42 return val;
43 }
GetRawValue(Eigen::half val)44 uint16 GetRawValue(Eigen::half val) {
45 return Eigen::numext::bit_cast<uint16>(val);
46 }
47
48 // Helper function for comparing a floating point type, FloatT, bitwise equal
49 // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
50 // -- on miscompare, a nice error message is given in the AssertionFailure.
51 template <typename FloatT, typename UnsignedT>
CompareFloatsBitwiseEqual(FloatT lhs,FloatT rhs,absl::Span<const int64> multi_index)52 bool CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs,
53 absl::Span<const int64> multi_index) {
54 auto ulhs = absl::bit_cast<UnsignedT>(GetRawValue(lhs));
55 auto urhs = absl::bit_cast<UnsignedT>(GetRawValue(rhs));
56 return ulhs == urhs;
57 }
58
59 // Templated comparator that specializes for float equality comparison with the
60 // bitwise helper above (this is the un-specialized fallback, to just use the
61 // default gunit implementation).
62 template <typename NativeT>
CompareEqual(NativeT lhs,NativeT rhs,absl::Span<const int64> multi_index)63 bool CompareEqual(NativeT lhs, NativeT rhs,
64 absl::Span<const int64> multi_index) {
65 return lhs == rhs;
66 }
67
68 // Specializations for floating types that do bitwise comparisons when equality
69 // comparison is requested.
70 template <>
CompareEqual(bfloat16 lhs,bfloat16 rhs,absl::Span<const int64> multi_index)71 bool CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs,
72 absl::Span<const int64> multi_index) {
73 return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs, multi_index);
74 }
75 template <>
CompareEqual(Eigen::half lhs,Eigen::half rhs,absl::Span<const int64> multi_index)76 bool CompareEqual<Eigen::half>(Eigen::half lhs, Eigen::half rhs,
77 absl::Span<const int64> multi_index) {
78 return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs, multi_index);
79 }
80 template <>
CompareEqual(float lhs,float rhs,absl::Span<const int64> multi_index)81 bool CompareEqual<float>(float lhs, float rhs,
82 absl::Span<const int64> multi_index) {
83 return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs, multi_index);
84 }
85 template <>
CompareEqual(double lhs,double rhs,absl::Span<const int64> multi_index)86 bool CompareEqual<double>(double lhs, double rhs,
87 absl::Span<const int64> multi_index) {
88 return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs, multi_index);
89 }
90 template <>
CompareEqual(complex64 lhs,complex64 rhs,absl::Span<const int64> multi_index)91 bool CompareEqual<complex64>(complex64 lhs, complex64 rhs,
92 absl::Span<const int64> multi_index) {
93 return CompareEqual<float>(lhs.real(), rhs.real(), multi_index) &&
94 CompareEqual<float>(lhs.imag(), rhs.imag(), multi_index);
95 }
96 template <>
CompareEqual(complex128 lhs,complex128 rhs,absl::Span<const int64> multi_index)97 bool CompareEqual<complex128>(complex128 lhs, complex128 rhs,
98 absl::Span<const int64> multi_index) {
99 return CompareEqual<double>(lhs.real(), rhs.real(), multi_index) &&
100 CompareEqual<double>(lhs.imag(), rhs.imag(), multi_index);
101 }
102
103 template <typename NativeT, typename UnsignedT>
MakeBitwiseErrorStatus(NativeT lhs,NativeT rhs,absl::Span<const int64> multi_index)104 Status MakeBitwiseErrorStatus(NativeT lhs, NativeT rhs,
105 absl::Span<const int64> multi_index) {
106 auto ulhs = absl::bit_cast<UnsignedT>(GetRawValue(lhs));
107 auto urhs = absl::bit_cast<UnsignedT>(GetRawValue(rhs));
108 auto lhs_double = static_cast<double>(lhs);
109 auto rhs_double = static_cast<double>(rhs);
110 return InvalidArgument(
111 "floating values are not bitwise-equal; and equality testing "
112 "was requested: %s=%s=%a vs %s=%s=%a at array index %s",
113 StrCat(absl::Hex(ulhs)), RoundTripFpToString(lhs), lhs_double,
114 StrCat(absl::Hex(urhs)), RoundTripFpToString(rhs), rhs_double,
115 LiteralUtil::MultiIndexAsString(multi_index));
116 }
117
118 template <typename NativeT>
MakeErrorStatus(NativeT lhs,NativeT rhs,absl::Span<const int64> multi_index)119 Status MakeErrorStatus(NativeT lhs, NativeT rhs,
120 absl::Span<const int64> multi_index) {
121 return InvalidArgument(
122 "first mismatch at array index %s:\n expected value: %s\n actual "
123 "value: %s",
124 LiteralUtil::MultiIndexAsString(multi_index), StrCat(lhs), StrCat(rhs));
125 }
126
127 template <>
MakeErrorStatus(bfloat16 lhs,bfloat16 rhs,absl::Span<const int64> multi_index)128 Status MakeErrorStatus(bfloat16 lhs, bfloat16 rhs,
129 absl::Span<const int64> multi_index) {
130 return MakeBitwiseErrorStatus<bfloat16, uint16>(lhs, rhs, multi_index);
131 }
132 template <>
MakeErrorStatus(Eigen::half lhs,Eigen::half rhs,absl::Span<const int64> multi_index)133 Status MakeErrorStatus(Eigen::half lhs, Eigen::half rhs,
134 absl::Span<const int64> multi_index) {
135 return MakeBitwiseErrorStatus<Eigen::half, uint16>(lhs, rhs, multi_index);
136 }
137 template <>
MakeErrorStatus(float lhs,float rhs,absl::Span<const int64> multi_index)138 Status MakeErrorStatus(float lhs, float rhs,
139 absl::Span<const int64> multi_index) {
140 return MakeBitwiseErrorStatus<float, uint32>(lhs, rhs, multi_index);
141 }
142 template <>
MakeErrorStatus(double lhs,double rhs,absl::Span<const int64> multi_index)143 Status MakeErrorStatus(double lhs, double rhs,
144 absl::Span<const int64> multi_index) {
145 return MakeBitwiseErrorStatus<double, uint64>(lhs, rhs, multi_index);
146 }
147 template <>
MakeErrorStatus(complex64 lhs,complex64 rhs,absl::Span<const int64> multi_index)148 Status MakeErrorStatus(complex64 lhs, complex64 rhs,
149 absl::Span<const int64> multi_index) {
150 if (!CompareEqual<float>(lhs.real(), rhs.real(), multi_index)) {
151 return MakeErrorStatus(lhs.real(), rhs.real(), multi_index);
152 }
153 return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index);
154 }
155 template <>
MakeErrorStatus(complex128 lhs,complex128 rhs,absl::Span<const int64> multi_index)156 Status MakeErrorStatus(complex128 lhs, complex128 rhs,
157 absl::Span<const int64> multi_index) {
158 if (!CompareEqual<double>(lhs.real(), rhs.real(), multi_index)) {
159 return MakeErrorStatus(lhs.real(), rhs.real(), multi_index);
160 }
161 return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index);
162 }
163
164 // A recursive function which iterates through every index of expected and
165 // actual literal and compares their values elementwise. Returns true if all
166 // elements are equal. Mismatched must either be:
167 // - a literal of booleans that has the same shape as expected and actual. In
168 // this case, each index in mismatched will be set to true if expected does
169 // not equal actual at that index and false if there are equal.
170 // - nullptr. In this case, the function will return once any mismatch is
171 // found between expected and actual.
172 template <typename NativeT>
Equal(LiteralSlice expected,LiteralSlice actual,absl::Span<int64> multi_index,int64_t dimension,Literal * mismatched=nullptr)173 Status Equal(LiteralSlice expected, LiteralSlice actual,
174 absl::Span<int64> multi_index, int64_t dimension,
175 Literal* mismatched = nullptr) {
176 if (dimension == expected.shape().dimensions_size()) {
177 NativeT expected_value = expected.Get<NativeT>(multi_index);
178 NativeT actual_value = actual.Get<NativeT>(multi_index);
179 bool result =
180 CompareEqual<NativeT>(expected_value, actual_value, multi_index);
181 if (mismatched) {
182 mismatched->Set<bool>(multi_index, !result);
183 }
184 return result ? Status::OK()
185 : MakeErrorStatus<NativeT>(expected_value, actual_value,
186 multi_index);
187 }
188
189 Status result;
190 for (int64_t i = 0; i < expected.shape().dimensions(dimension); ++i) {
191 multi_index[dimension] = i;
192 if (mismatched != nullptr) {
193 result.Update(Equal<NativeT>(expected, actual, multi_index, dimension + 1,
194 mismatched));
195 } else {
196 TF_RETURN_IF_ERROR(Equal<NativeT>(expected, actual, multi_index,
197 dimension + 1, mismatched));
198 }
199 }
200 return result;
201 }
202
203 // Gets the total element count. For tuples, this is not the count of tuple
204 // elements, but the sum of elements of each tuple element.
RecursiveElementCount(const Shape & shape)205 int64 RecursiveElementCount(const Shape& shape) {
206 if (shape.IsTuple()) {
207 const int64_t tuple_elements = ShapeUtil::TupleElementCount(shape);
208 int64_t total = 0;
209 for (int64_t i = 0; i < tuple_elements; ++i) {
210 total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i));
211 }
212 return total;
213 } else if (shape.IsArray()) {
214 return ShapeUtil::ElementsIn(shape);
215 } else {
216 return 0;
217 }
218 }
219
220 // Returns whether the given value is infinity.
221 template <typename NativeT>
IsInf(NativeT val)222 bool IsInf(NativeT val) {
223 return Eigen::numext::isinf(val);
224 }
225 // Returns whether the given value is nan.
226 template <typename NativeT>
IsNan(NativeT value)227 bool IsNan(NativeT value) {
228 return Eigen::numext::isnan(value);
229 }
230
231 // Converts the given floating-point value to a string.
FpValueToString(bfloat16 value)232 string FpValueToString(bfloat16 value) {
233 return absl::StrFormat("%10.4g", static_cast<double>(value));
234 }
235
FpValueToString(half value)236 string FpValueToString(half value) {
237 return absl::StrFormat("%11.5g", static_cast<double>(value));
238 }
239
FpValueToString(float value)240 string FpValueToString(float value) {
241 return absl::StrFormat("%15.9g", static_cast<double>(value));
242 }
243
FpValueToString(double value)244 string FpValueToString(double value) {
245 return absl::StrFormat("%24.17g", value);
246 }
247
FpValueToString(complex64 value)248 string FpValueToString(complex64 value) {
249 return absl::StrCat(FpValueToString(value.real()), " + ",
250 FpValueToString(value.imag()));
251 }
252
FpValueToString(complex128 value)253 string FpValueToString(complex128 value) {
254 return absl::StrCat(FpValueToString(value.real()), " + ",
255 FpValueToString(value.imag()));
256 }
257
258 // A wrapper of std::abs to include data types that are not supported by
259 // std::abs, in particular, bfloat16 and half.
260 template <typename NativeT>
FpAbsoluteValue(NativeT value)261 double FpAbsoluteValue(NativeT value) {
262 return std::abs(value);
263 }
264
265 template <>
FpAbsoluteValue(bfloat16 value)266 double FpAbsoluteValue(bfloat16 value) {
267 return FpAbsoluteValue<float>(static_cast<float>(value));
268 }
269
270 template <>
FpAbsoluteValue(half value)271 double FpAbsoluteValue(half value) {
272 return FpAbsoluteValue<float>(static_cast<float>(value));
273 }
274
275 // Helper class for comparing floating-point literals within an error bound.
276 template <typename NativeT>
277 class NearComparator {
278 public:
279 // Compares the two array literals elementwise and returns a comparison
280 // result. The comparison is ok() if all actual and expected elements are
281 // within the given error bound. In case of error, the status contains a
282 // detailed message about the discrepancy.
Compare(const LiteralSlice & expected,const LiteralSlice & actual,const ShapeIndex & shape_index,ErrorSpec error,bool detailed_message,const MiscompareCallback & miscompare_callback)283 static Status Compare(const LiteralSlice& expected,
284 const LiteralSlice& actual,
285 const ShapeIndex& shape_index, ErrorSpec error,
286 bool detailed_message,
287 const MiscompareCallback& miscompare_callback) {
288 NearComparator<NativeT> comparator(expected, actual, shape_index, error,
289 detailed_message, miscompare_callback);
290 return comparator.Run();
291 }
292
293 private:
294 // Data structure encapsulating metadata about a single element mismatch.
295 struct Mismatch {
296 NativeT actual;
297 NativeT expected;
298 double rel_error;
299 double abs_error;
300
301 // The linear index of the failure within the shape. This linear index is
302 // from the 'actual' literal.
303 int64 linear_index;
304
operator <xla::literal_comparison::__anon7c9aba5a0111::NearComparator::Mismatch305 bool operator<(const Mismatch& other) const {
306 return rel_error < other.rel_error;
307 }
308
ToStringxla::literal_comparison::__anon7c9aba5a0111::NearComparator::Mismatch309 string ToString(const Shape& shape) const {
310 return absl::StrFormat(
311 "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g",
312 FpValueToString(actual), FpValueToString(expected),
313 LiteralUtil::MultiIndexAsString(
314 IndexUtil::LinearIndexToMultidimensionalIndex(shape,
315 linear_index)),
316 rel_error, abs_error);
317 }
318 };
319
NearComparator(const LiteralSlice & expected,const LiteralSlice & actual,const ShapeIndex & shape_index,ErrorSpec error,bool detailed_message,const MiscompareCallback & miscompare_callback)320 NearComparator(const LiteralSlice& expected, const LiteralSlice& actual,
321 const ShapeIndex& shape_index, ErrorSpec error,
322 bool detailed_message,
323 const MiscompareCallback& miscompare_callback)
324 : expected_(expected),
325 actual_(actual),
326 shape_index_(shape_index),
327 error_(error),
328 detailed_message_(detailed_message),
329 miscompare_callback_(miscompare_callback),
330 abs_value_buckets_(kAbsValueBucketBounds.size() - 1, {0, 0}),
331 abs_error_buckets_(kErrorBucketBounds.size(), 0),
332 rel_error_buckets_(kErrorBucketBounds.size(), 0) {}
333
334 // Runs the comparison between expected and actual literals.
Run()335 Status Run() {
336 // If the shapes mismatch, we simply fail the expectation instead of
337 // printing out data, as it's a type error rather than a value error.
338 TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape()));
339 if (!expected_.shape().IsArray()) {
340 return InvalidArgument("Expected array shape; got %s.",
341 ShapeUtil::HumanString(expected_.shape()));
342 }
343
344 mismatches_ = Literal(ShapeUtil::ChangeElementType(actual_.shape(), PRED));
345 mismatches_.PopulateWithValue(false);
346
347 CompareLiterals();
348
349 if (num_mismatches_ == 0) {
350 return Status::OK();
351 } else if (!VLOG_IS_ON(1) && miscompare_callback_ != nullptr) {
352 miscompare_callback_(expected_, actual_, mismatches_, shape_index_);
353 }
354 return InvalidArgument("%s", ErrorMessage());
355 }
356
357 // Insert the given absolute value into the absolute value bucket vector. The
358 // bounds of the buckets are given by kAbsValueBucketBounds.
UpdateAbsValueBucket(NativeT value,bool is_mismatch)359 void UpdateAbsValueBucket(NativeT value, bool is_mismatch) {
360 // Adjust the bucket containing the absolute values of the 'actual'
361 // elements.
362 const double abs_value = FpAbsoluteValue(value);
363 for (int i = 0; i < abs_value_buckets_.size(); ++i) {
364 if (i == abs_value_buckets_.size() - 1 ||
365 (abs_value >= kAbsValueBucketBounds[i] &&
366 abs_value < kAbsValueBucketBounds[i + 1])) {
367 // The first value of the pair is the count of elements in the bucket,
368 // the second is the count of mismatches in the bucket.
369 abs_value_buckets_[i].first++;
370 if (is_mismatch) {
371 abs_value_buckets_[i].second++;
372 }
373 return;
374 }
375 }
376 }
377
378 // Insert the given error into the given error bucket vector.
UpdateErrorBucket(double error,absl::Span<int64> error_buckets)379 void UpdateErrorBucket(double error, absl::Span<int64> error_buckets) {
380 CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size());
381 for (int i = 0; i < error_buckets.size(); ++i) {
382 if (error >= kErrorBucketBounds[i]) {
383 error_buckets[i]++;
384 }
385 }
386 }
387
388 // Compares the two given elements from the expected and actual literals at
389 // the given literal_index and keeps track of various mismatch statistics.
390 template <typename T>
CompareValues(T expected,T actual,int64_t linear_index)391 void CompareValues(T expected, T actual, int64_t linear_index) {
392 double abs_error;
393 double rel_error;
394 if (CompareEqual<T>(expected, actual, {linear_index})) {
395 abs_error = 0;
396 rel_error = 0;
397 } else if (IsNan(expected) || IsNan(actual)) {
398 if ((!error_.relaxed_nans && IsNan(expected) != IsNan(actual)) ||
399 (error_.relaxed_nans && !IsNan(expected) && IsNan(actual))) {
400 num_nan_mismatches_++;
401 // A nan mismatch is considered to have infinite error. rel_error is
402 // used for sorting a std::set of the top mismatches, and a nan value
403 // here will result in undefined behavior because nan's do not satisfy
404 // the strict weak ordering requirement of std containers.
405 abs_error = std::numeric_limits<float>::infinity();
406 rel_error = std::numeric_limits<float>::infinity();
407 } else {
408 abs_error = 0;
409 rel_error = 0;
410 }
411 } else if (IsInf(actual) && !IsInf(expected) && error_.fewer_infs_ok) {
412 // `fewer_infs_ok` gives us the option of comparing as though `actual`
413 // were float_max/min rather than inf.
414 T actual_finite = actual > T{0} ? std::numeric_limits<T>::max()
415 : std::numeric_limits<T>::lowest();
416 abs_error = FpAbsoluteValue(actual_finite - expected);
417
418 // Avoid division by 0 even though it's well-defined because ubsan can be
419 // configured to treat this as a fatal error.
420 if (expected != T{0}) {
421 rel_error = abs_error / FpAbsoluteValue(expected);
422 } else {
423 rel_error = std::numeric_limits<float>::infinity();
424 }
425 } else if (IsInf(expected) || IsInf(actual)) {
426 // If either the expected or actual value is infinity but not both,
427 // then both absolute and relative error are regarded as infinity.
428 CHECK(!CompareEqual(expected, actual, {linear_index}));
429 abs_error = std::numeric_limits<float>::infinity();
430 rel_error = std::numeric_limits<float>::infinity();
431 } else {
432 abs_error = FpAbsoluteValue(actual - expected);
433
434 // Avoid division by 0 even though it's well-defined because ubsan can be
435 // configured to treat this as a fatal error.
436 if (expected != T{0}) {
437 rel_error = abs_error / FpAbsoluteValue(expected);
438 } else {
439 rel_error = std::numeric_limits<float>::infinity();
440 }
441 }
442 const bool is_abs_mismatch = abs_error > error_.abs;
443 const bool is_rel_mismatch = rel_error > error_.rel;
444 const bool is_mismatch = is_abs_mismatch && is_rel_mismatch;
445
446 // Update the error of the relative bucket only if the *absolute* error
447 // bound is exceeded and vice versa.
448 if (is_abs_mismatch) {
449 num_abs_mismatches_++;
450 UpdateErrorBucket(rel_error, absl::MakeSpan(rel_error_buckets_));
451 }
452 if (is_rel_mismatch) {
453 num_rel_mismatches_++;
454 UpdateErrorBucket(abs_error, absl::MakeSpan(abs_error_buckets_));
455 }
456
457 UpdateAbsValueBucket(actual, is_mismatch);
458
459 if (!is_mismatch) {
460 return;
461 }
462
463 num_mismatches_++;
464
465 // Keep track of the kTopRelativeErrorCount relative error mismatches.
466 if (top_rel_mismatches_.size() < kTopRelativeErrorCount ||
467 rel_error > top_rel_mismatches_.begin()->rel_error) {
468 Mismatch mismatch = {actual, expected, rel_error, abs_error,
469 linear_index};
470 top_rel_mismatches_.insert(mismatch);
471 if (top_rel_mismatches_.size() > kTopRelativeErrorCount) {
472 top_rel_mismatches_.erase(top_rel_mismatches_.begin());
473 }
474 }
475
476 mismatches_.data<bool>()[linear_index] = true;
477 }
478
479 // For complex types, we compare real and imaginary parts individually.
CompareValues(complex64 expected,complex64 actual,int64_t linear_index)480 void CompareValues(complex64 expected, complex64 actual,
481 int64_t linear_index) {
482 const auto both_parts_mismatch = num_mismatches_ + 2;
483 CompareValues<float>(expected.real(), actual.real(), linear_index);
484 CompareValues<float>(expected.imag(), actual.imag(), linear_index);
485 if (num_mismatches_ == both_parts_mismatch) {
486 // The mismatch counter had been incremented by each CompareValues() call,
487 // which means that both real and imaginary parts of the passed-in complex
488 // values are different. However, the counter should reflect a single
489 // mismatch between these complex values.
490 num_mismatches_--;
491 }
492 }
493
CompareValues(complex128 expected,complex128 actual,int64_t linear_index)494 void CompareValues(complex128 expected, complex128 actual,
495 int64_t linear_index) {
496 const auto both_parts_mismatch = num_mismatches_ + 2;
497 CompareValues<double>(expected.real(), actual.real(), linear_index);
498 CompareValues<double>(expected.imag(), actual.imag(), linear_index);
499 if (num_mismatches_ == both_parts_mismatch) {
500 // The mismatch counter had been incremented by each CompareValues() call,
501 // which means that both real and imaginary parts of the passed-in complex
502 // values are different. However, the counter should reflect a single
503 // mismatch between these complex values.
504 num_mismatches_--;
505 }
506 }
507
508 // Compares the two literals elementwise.
CompareLiterals()509 void CompareLiterals() {
510 // Fast path optimization for the case were layouts match.
511 if (LayoutUtil::Equal(actual_.shape().layout(),
512 expected_.shape().layout())) {
513 absl::Span<const NativeT> expected_data = expected_.data<NativeT>();
514 absl::Span<const NativeT> actual_data = actual_.data<NativeT>();
515 const int64_t len = expected_data.size();
516 for (int64_t i = 0; i < len; ++i) {
517 CompareValues(expected_data[i], actual_data[i], i);
518 }
519 return;
520 }
521 std::vector<int64> multi_index(actual_.shape().rank(), 0);
522 CompareLiteralsSlow(0, &multi_index);
523 }
524
525 // Slow path for CompareLiterals when 'actual' and 'expected' literals have
526 // different layouts. In this case, multidimensional indices are constructed
527 // and indexed for each element.
CompareLiteralsSlow(int64_t dimension,std::vector<int64> * multi_index)528 void CompareLiteralsSlow(int64_t dimension, std::vector<int64>* multi_index) {
529 if (dimension == multi_index->size()) {
530 CompareValues(expected_.Get<NativeT>(*multi_index),
531 actual_.Get<NativeT>(*multi_index),
532 IndexUtil::MultidimensionalIndexToLinearIndex(
533 actual_.shape(), *multi_index));
534 } else {
535 for (int64_t i = 0; i < expected_.shape().dimensions(dimension); ++i) {
536 (*multi_index)[dimension] = i;
537 CompareLiteralsSlow(dimension + 1, multi_index);
538 }
539 }
540 }
541
542 // Returns an error message string with a detailed breakdown of the
543 // mismatches. Called after calling Run().
ErrorMessage()544 string ErrorMessage() {
545 string out;
546 int64_t element_count = ShapeUtil::ElementsIn(actual_.shape());
547
548 auto percent_string = [](float a, float b) {
549 float pct = b == 0.0 ? 0.0 : 100.0 * a / b;
550 return absl::StrFormat("%0.4f%%", pct);
551 };
552
553 StrAppendFormat(
554 &out,
555 "\nMismatch count %d (%s) in shape %s (%d elements), abs bound "
556 "%g, rel bound %g\n",
557 num_mismatches_, percent_string(num_mismatches_, element_count),
558 ShapeUtil::HumanString(actual_.shape()),
559 ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel);
560 if (num_nan_mismatches_ > 0) {
561 StrAppend(&out, "nan mismatches ", num_nan_mismatches_, "\n");
562 }
563 StrAppendFormat(&out, "Top relative error mismatches:\n");
564 for (auto it = top_rel_mismatches_.rbegin();
565 it != top_rel_mismatches_.rend(); ++it) {
566 StrAppend(&out, " ", it->ToString(actual_.shape()), "\n");
567 }
568
569 if (!detailed_message_) {
570 return out;
571 }
572
573 StrAppend(&out, "Absolute magnitude breakdown of actual values:\n");
574 CHECK_EQ(abs_value_buckets_.size() + 1, kAbsValueBucketBounds.size());
575 for (int i = 0; i < abs_value_buckets_.size(); ++i) {
576 const int64_t bucket_size = abs_value_buckets_[i].first;
577 const int64_t bucket_mismatches = abs_value_buckets_[i].second;
578 string mismatch_str =
579 bucket_mismatches > 0
580 ? absl::StrFormat(", mismatches %d", bucket_mismatches)
581 : "";
582 StrAppendFormat(&out, " %-6g <= x < %-6g : %7d (%9s)%s\n",
583 kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1],
584 bucket_size, percent_string(bucket_size, element_count),
585 mismatch_str);
586 }
587
588 auto print_accum_buckets = [&](const string& header, int64_t total,
589 absl::Span<const int64> buckets) {
590 StrAppend(&out, header, ":\n");
591 StrAppendFormat(&out, " < %-6g : %7d (%s)\n", kErrorBucketBounds[0],
592 total - buckets[0],
593 percent_string(total - buckets[0], total));
594 CHECK_EQ(buckets.size(), kErrorBucketBounds.size());
595 for (int i = 0; i < kErrorBucketBounds.size(); ++i) {
596 StrAppendFormat(&out, " >= %-6g : %7d (%s)\n", kErrorBucketBounds[i],
597 buckets[i], percent_string(buckets[i], total));
598 }
599 };
600 StrAppendFormat(&out, "Elements exceeding abs error bound %g: %d (%s)\n",
601 error_.abs, num_abs_mismatches_,
602 percent_string(num_abs_mismatches_, element_count));
603 print_accum_buckets(
604 "Relative error breakdown of elements exceeding abs error bound",
605 num_abs_mismatches_, rel_error_buckets_);
606 StrAppendFormat(&out, "Elements exceeding rel error bound %g: %d (%s)\n",
607 error_.rel, num_rel_mismatches_,
608 percent_string(num_rel_mismatches_, element_count));
609 print_accum_buckets(
610 "Absolute error breakdown of elements exceeding rel error bound",
611 num_rel_mismatches_, abs_error_buckets_);
612 return out;
613 }
614
615 // 'actual' and 'expected' literals being compared.
616 LiteralSlice expected_;
617 LiteralSlice actual_;
618
619 // The shape index of the LiteralSlice that is being compared.
620 ShapeIndex shape_index_;
621
622 // The error bounds of the comparison.
623 ErrorSpec error_;
624
625 // Whether to include detailed breakdown of mismatches in the error message.
626 bool detailed_message_;
627
628 // Callback to invoke on miscompare.
629 MiscompareCallback miscompare_callback_;
630
631 // Number of element mismatches encountered so far.
632 int64 num_mismatches_ = 0;
633
634 // Number of elements with a nan mismatch.
635 int64 num_nan_mismatches_ = 0;
636
637 // Number of elements which exceed the absolute/relative error bound.
638 int64 num_abs_mismatches_ = 0;
639 int64 num_rel_mismatches_ = 0;
640
641 // A Literal containing which elements did not match in the expected and
642 // actual literals. mismatches_ contains PREDs and is of the same sizes as
643 // the comparison literals.
644 Literal mismatches_;
645
646 // The number of mismatches to report in the output, sorted by relative error
647 // magnitude.
648 static constexpr int64_t kTopRelativeErrorCount = 5;
649
650 // The set of mismatches with the largest relative error. The size of this set
651 // is bounded by kTopRelativeErrorCount.
652 std::multiset<Mismatch> top_rel_mismatches_;
653
654 // Actual values are bucketed by absolute value. kAbsValueBucketBounds is the
655 // bounds of these buckets. abs_value_buckets_ contains a pair for each
656 // bucket: the element count and failure count.
657 static constexpr std::array<float, 7> kAbsValueBucketBounds = {
658 0.0, 0.0001, 0.001, 0.01, 0.1, 1, std::numeric_limits<float>::infinity()};
659 std::vector<std::pair<int64, int64>> abs_value_buckets_;
660
661 // Buckets for relative and absolute errors. The relative error buckets only
662 // contains those elements which exceed the *absolute* error bound, and vice
663 // versa. This makes it easy to see the effect of adjusting the relative (or
664 // absolute) error bound on the success of the comparison. kErrorBucketBounds
665 // are the lower bounds of the buckets in both vectors. The error buckets are
666 // a cumulative distribution so an error value may appear in more than one
667 // bucket. For example an error value of 0.003 may appear in the buckets
668 // bounded by 0.01, 0.1, and 1.0.
669 static constexpr std::array<float, 5> kErrorBucketBounds = {0.0001, 0.001,
670 0.01, 0.1, 1};
671 std::vector<int64> abs_error_buckets_;
672 std::vector<int64> rel_error_buckets_;
673 };
674
675 template <typename NativeT>
676 constexpr std::array<float, 7> NearComparator<NativeT>::kAbsValueBucketBounds;
677 template <typename NativeT>
678 constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
679
EqualHelper(const LiteralSlice & expected,const LiteralSlice & actual,const ShapeIndex & shape_index,const MiscompareCallback & miscompare_callback)680 Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual,
681 const ShapeIndex& shape_index,
682 const MiscompareCallback& miscompare_callback) {
683 TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
684
685 Status result;
686 if (expected.shape().IsTuple()) {
687 ShapeIndex next_index = shape_index;
688 for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
689 next_index.push_back(i);
690 Status tuple_result =
691 EqualHelper(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}),
692 next_index, miscompare_callback);
693 if (miscompare_callback) {
694 result.Update(tuple_result);
695 } else {
696 TF_RETURN_IF_ERROR(tuple_result);
697 }
698 next_index.pop_back();
699 }
700 } else {
701 std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
702 auto index = absl::MakeSpan(multi_index);
703
704 Shape unequal_shape = ShapeUtil::MakeShape(PrimitiveType::PRED,
705 expected.shape().dimensions());
706 Literal miscompared(unequal_shape);
707 Literal* miscompared_ptr =
708 (miscompare_callback == nullptr ? nullptr : &miscompared);
709
710 switch (expected.shape().element_type()) {
711 case PRED:
712 result = Equal<bool>(expected, actual, index, 0, miscompared_ptr);
713 break;
714 case S8:
715 result = Equal<int8>(expected, actual, index, 0, miscompared_ptr);
716 break;
717 case S16:
718 result = Equal<int16>(expected, actual, index, 0, miscompared_ptr);
719 break;
720 case S32:
721 result = Equal<int32>(expected, actual, index, 0, miscompared_ptr);
722 break;
723 case S64:
724 result = Equal<int64>(expected, actual, index, 0, miscompared_ptr);
725 break;
726 case U8:
727 result = Equal<uint8>(expected, actual, index, 0, miscompared_ptr);
728 break;
729 case U16:
730 result = Equal<uint16>(expected, actual, index, 0, miscompared_ptr);
731 break;
732 case U32:
733 result = Equal<uint32>(expected, actual, index, 0, miscompared_ptr);
734 break;
735 case U64:
736 result = Equal<uint64>(expected, actual, index, 0, miscompared_ptr);
737 break;
738 case BF16:
739 result = Equal<bfloat16>(expected, actual, index, 0, miscompared_ptr);
740 break;
741 case F16:
742 result = Equal<half>(expected, actual, index, 0, miscompared_ptr);
743 break;
744 case F32:
745 result = Equal<float>(expected, actual, index, 0, miscompared_ptr);
746 break;
747 case F64:
748 result = Equal<double>(expected, actual, index, 0, miscompared_ptr);
749 break;
750 case C64:
751 result = Equal<complex64>(expected, actual, index, 0, miscompared_ptr);
752 break;
753 case C128:
754 result = Equal<complex128>(expected, actual, index, 0, miscompared_ptr);
755 break;
756 case TOKEN:
757 // Tokens have no on-device representation and are trivially equal.
758 return Status::OK();
759 default:
760 LOG(FATAL) << "Unsupported primitive type: "
761 << PrimitiveType_Name(expected.shape().element_type());
762 }
763
764 if (!result.ok() && miscompare_callback) {
765 miscompare_callback(expected, actual, LiteralSlice(miscompared),
766 shape_index);
767 }
768 }
769
770 return result;
771 }
772
773 // Helper function for comparing two literals for nearness. Handles tuple-shapes
774 // via recursion. shape_index is the ShapeIndex of expected (or actual)
775 // currently being compared.
NearHelper(const LiteralSlice & expected,const LiteralSlice & actual,const ShapeIndex & shape_index,const ErrorSpec & error,absl::optional<bool> detailed_message,const MiscompareCallback & miscompare_callback)776 Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
777 const ShapeIndex& shape_index, const ErrorSpec& error,
778 absl::optional<bool> detailed_message,
779 const MiscompareCallback& miscompare_callback) {
780 TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
781
782 if (expected.shape().IsTuple()) {
783 Status return_status;
784 for (int64_t i = 0; i < ShapeUtil::TupleElementCount(expected.shape());
785 ++i) {
786 const auto expected_element = LiteralSlice(expected, {i});
787 const auto actual_element = LiteralSlice(actual, {i});
788 ShapeIndex element_index = shape_index;
789 element_index.push_back(i);
790 Status element_result =
791 NearHelper(expected_element, actual_element, element_index, error,
792 detailed_message, miscompare_callback);
793 if (!element_result.ok()) {
794 element_result = InvalidArgument("Array at shape index %s, %s",
795 element_index.ToString(),
796 element_result.error_message());
797 if (return_status.ok()) {
798 return_status = element_result;
799 } else {
800 return_status =
801 AppendStatus(return_status, element_result.error_message());
802 }
803 }
804 }
805 if (!return_status.ok() && shape_index.empty()) {
806 // Emit a top-level error message containing the top-level shape in case
807 // of mismatch.
808 int64_t total_elements = RecursiveElementCount(actual.shape());
809 return_status =
810 InvalidArgument("\nMismatches in shape %s (%d elements):\n%s",
811 ShapeUtil::HumanString(actual.shape()),
812 total_elements, return_status.error_message());
813 }
814 return return_status;
815 }
816
817 if (ShapeUtil::ElementIsFloating(expected.shape()) ||
818 ShapeUtil::ElementIsComplex(expected.shape())) {
819 bool use_detailed_message = detailed_message.value_or(
820 ShapeUtil::ElementsIn(expected.shape()) >= 64);
821 switch (expected.shape().element_type()) {
822 case BF16:
823 return NearComparator<bfloat16>::Compare(expected, actual, shape_index,
824 error, use_detailed_message,
825 miscompare_callback);
826 break;
827 case F16:
828 return NearComparator<half>::Compare(expected, actual, shape_index,
829 error, use_detailed_message,
830 miscompare_callback);
831 break;
832 case F32:
833 return NearComparator<float>::Compare(expected, actual, shape_index,
834 error, use_detailed_message,
835 miscompare_callback);
836 break;
837 case F64:
838 return NearComparator<double>::Compare(expected, actual, shape_index,
839 error, use_detailed_message,
840 miscompare_callback);
841 break;
842 case C64:
843 return NearComparator<complex64>::Compare(expected, actual, shape_index,
844 error, use_detailed_message,
845 miscompare_callback);
846 break;
847 case C128:
848 return NearComparator<complex128>::Compare(
849 expected, actual, shape_index, error, use_detailed_message,
850 miscompare_callback);
851 break;
852 default:
853 LOG(FATAL) << "Unsupported primitive type in near comparator: "
854 << PrimitiveType_Name(expected.shape().element_type())
855 << ". Must be floating-point type.";
856 }
857 }
858
859 // Non-floating point, non-tuple literal.
860 return EqualHelper(expected, actual, shape_index, miscompare_callback);
861 }
862
863 } // namespace
864
EqualShapes(const Shape & expected,const Shape & actual)865 Status EqualShapes(const Shape& expected, const Shape& actual) {
866 if (expected.element_type() != actual.element_type()) {
867 return InvalidArgument("element type mismatch, want: %s got %s",
868 ShapeUtil::HumanString(expected),
869 ShapeUtil::HumanString(actual));
870 }
871 if (expected.IsTuple()) {
872 if (ShapeUtil::TupleElementCount(expected) !=
873 ShapeUtil::TupleElementCount(actual)) {
874 return InvalidArgument(
875 "want tuple element count: %d got tuple element count: %d",
876 ShapeUtil::TupleElementCount(expected),
877 ShapeUtil::TupleElementCount(actual));
878 }
879 for (int i = 0; i < expected.tuple_shapes_size(); ++i) {
880 Status result =
881 EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i));
882 if (!result.ok()) {
883 return AppendStatus(result, StrCat("mismatch in tuple index", i));
884 }
885 }
886 } else if (expected.IsArray()) {
887 if (expected.rank() != actual.rank()) {
888 return InvalidArgument("want rank of %s got rank of %s",
889 ShapeUtil::HumanString(expected),
890 ShapeUtil::HumanString(actual));
891 }
892 if (expected.element_type() != actual.element_type()) {
893 return InvalidArgument("mismatch in primitive type %s vs %s",
894 PrimitiveType_Name(expected.element_type()),
895 PrimitiveType_Name(actual.element_type()));
896 }
897 if (expected.dimensions_size() != actual.dimensions_size()) {
898 return InvalidArgument("want dimensions_size %d got dimensions_size %d",
899 expected.dimensions_size(),
900 actual.dimensions_size());
901 }
902 for (int i = 0; i < expected.dimensions_size(); ++i) {
903 if (expected.dimensions(i) != actual.dimensions(i)) {
904 return InvalidArgument(
905 "mismatch in dimension #%d expected: %s actual: %s", i,
906 ShapeUtil::HumanString(expected), ShapeUtil::HumanString(actual));
907 }
908 }
909 }
910 // Non-array, non-tuple shapes are trivially equivalent.
911 return Status::OK();
912 }
913
914 namespace {
915
916 // If result is an error, extend the error message with the expected and actual
917 // literals.
EmitLiteralsInErrorMessage(const Status & result,const LiteralSlice & expected,const LiteralSlice & actual)918 Status EmitLiteralsInErrorMessage(const Status& result,
919 const LiteralSlice& expected,
920 const LiteralSlice& actual) {
921 if (result.ok()) {
922 return result;
923 }
924 return InvalidArgument("%s\n\nExpected literal:\n%s\n\nActual literal:\n%s",
925 result.error_message(), ToStringTruncated(expected),
926 ToStringTruncated(actual));
927 }
928
929 } // namespace
930
Equal(const LiteralSlice & expected,const LiteralSlice & actual)931 Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
932 VLOG(1) << "expected:";
933 XLA_VLOG_LINES(1, expected.ToString());
934 VLOG(1) << "actual:";
935 XLA_VLOG_LINES(1, actual.ToString());
936 Status result = EqualHelper(expected, actual, {}, nullptr);
937 return EmitLiteralsInErrorMessage(result, expected, actual);
938 }
939
Near(const LiteralSlice & expected,const LiteralSlice & actual,const ErrorSpec & error,absl::optional<bool> detailed_message,const MiscompareCallback & miscompare_callback)940 Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
941 const ErrorSpec& error, absl::optional<bool> detailed_message,
942 const MiscompareCallback& miscompare_callback) {
943 VLOG(1) << "Expected literal:";
944 XLA_VLOG_LINES(1, expected.ToString());
945 VLOG(1) << "Actual literal:";
946 XLA_VLOG_LINES(1, actual.ToString());
947 Status result = NearHelper(expected, actual, /*shape_index=*/{}, error,
948 detailed_message, miscompare_callback);
949 return EmitLiteralsInErrorMessage(result, expected, actual);
950 }
951
ToStringTruncated(const LiteralSlice & literal)952 string ToStringTruncated(const LiteralSlice& literal) {
953 return RecursiveElementCount(literal.shape()) < 1000
954 ? literal.ToString()
955 : "[TRUNCATED, Literal with more than 1000 values]";
956 }
957
958 } // namespace literal_comparison
959 } // namespace xla
960