• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/literal_comparison.h"
17 
18 #include <unistd.h>
19 
20 #include <cmath>
21 #include <vector>
22 
23 #include "absl/base/casts.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_format.h"
26 #include "tensorflow/compiler/xla/literal_util.h"
27 #include "tensorflow/compiler/xla/util.h"
28 #include "tensorflow/core/platform/env.h"
29 
30 using absl::StrAppend;
31 using absl::StrAppendFormat;
32 using absl::StrCat;
33 
34 namespace xla {
35 namespace literal_comparison {
36 namespace {
37 
// Eigen::half does not satisfy the absl::bit_cast contract, so callers obtain
// a bit-castable representation through GetRawValue. This generic form hands
// the value back unchanged; the Eigen::half overload returns the raw 16 bits.
template <typename T>
T GetRawValue(T val) {
  const T unchanged = val;
  return unchanged;
}
GetRawValue(Eigen::half val)44 uint16 GetRawValue(Eigen::half val) {
45   return Eigen::numext::bit_cast<uint16>(val);
46 }
47 
48 // Helper function for comparing a floating point type, FloatT, bitwise equal
49 // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
50 // -- on miscompare, a nice error message is given in the AssertionFailure.
51 template <typename FloatT, typename UnsignedT>
CompareFloatsBitwiseEqual(FloatT lhs,FloatT rhs,absl::Span<const int64> multi_index)52 bool CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs,
53                                absl::Span<const int64> multi_index) {
54   auto ulhs = absl::bit_cast<UnsignedT>(GetRawValue(lhs));
55   auto urhs = absl::bit_cast<UnsignedT>(GetRawValue(rhs));
56   return ulhs == urhs;
57 }
58 
59 // Templated comparator that specializes for float equality comparison with the
60 // bitwise helper above (this is the un-specialized fallback, to just use the
61 // default gunit implementation).
62 template <typename NativeT>
CompareEqual(NativeT lhs,NativeT rhs,absl::Span<const int64> multi_index)63 bool CompareEqual(NativeT lhs, NativeT rhs,
64                   absl::Span<const int64> multi_index) {
65   return lhs == rhs;
66 }
67 
68 // Specializations for floating types that do bitwise comparisons when equality
69 // comparison is requested.
70 template <>
CompareEqual(bfloat16 lhs,bfloat16 rhs,absl::Span<const int64> multi_index)71 bool CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs,
72                             absl::Span<const int64> multi_index) {
73   return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs, multi_index);
74 }
75 template <>
CompareEqual(Eigen::half lhs,Eigen::half rhs,absl::Span<const int64> multi_index)76 bool CompareEqual<Eigen::half>(Eigen::half lhs, Eigen::half rhs,
77                                absl::Span<const int64> multi_index) {
78   return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs, multi_index);
79 }
80 template <>
CompareEqual(float lhs,float rhs,absl::Span<const int64> multi_index)81 bool CompareEqual<float>(float lhs, float rhs,
82                          absl::Span<const int64> multi_index) {
83   return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs, multi_index);
84 }
85 template <>
CompareEqual(double lhs,double rhs,absl::Span<const int64> multi_index)86 bool CompareEqual<double>(double lhs, double rhs,
87                           absl::Span<const int64> multi_index) {
88   return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs, multi_index);
89 }
90 template <>
CompareEqual(complex64 lhs,complex64 rhs,absl::Span<const int64> multi_index)91 bool CompareEqual<complex64>(complex64 lhs, complex64 rhs,
92                              absl::Span<const int64> multi_index) {
93   return CompareEqual<float>(lhs.real(), rhs.real(), multi_index) &&
94          CompareEqual<float>(lhs.imag(), rhs.imag(), multi_index);
95 }
96 template <>
CompareEqual(complex128 lhs,complex128 rhs,absl::Span<const int64> multi_index)97 bool CompareEqual<complex128>(complex128 lhs, complex128 rhs,
98                               absl::Span<const int64> multi_index) {
99   return CompareEqual<double>(lhs.real(), rhs.real(), multi_index) &&
100          CompareEqual<double>(lhs.imag(), rhs.imag(), multi_index);
101 }
102 
103 template <typename NativeT, typename UnsignedT>
MakeBitwiseErrorStatus(NativeT lhs,NativeT rhs,absl::Span<const int64> multi_index)104 Status MakeBitwiseErrorStatus(NativeT lhs, NativeT rhs,
105                               absl::Span<const int64> multi_index) {
106   auto ulhs = absl::bit_cast<UnsignedT>(GetRawValue(lhs));
107   auto urhs = absl::bit_cast<UnsignedT>(GetRawValue(rhs));
108   auto lhs_double = static_cast<double>(lhs);
109   auto rhs_double = static_cast<double>(rhs);
110   return InvalidArgument(
111       "floating values are not bitwise-equal; and equality testing "
112       "was requested: %s=%s=%a vs %s=%s=%a at array index %s",
113       StrCat(absl::Hex(ulhs)), RoundTripFpToString(lhs), lhs_double,
114       StrCat(absl::Hex(urhs)), RoundTripFpToString(rhs), rhs_double,
115       LiteralUtil::MultiIndexAsString(multi_index));
116 }
117 
118 template <typename NativeT>
MakeErrorStatus(NativeT lhs,NativeT rhs,absl::Span<const int64> multi_index)119 Status MakeErrorStatus(NativeT lhs, NativeT rhs,
120                        absl::Span<const int64> multi_index) {
121   return InvalidArgument(
122       "first mismatch at array index %s:\n  expected value: %s\n  actual "
123       "value:   %s",
124       LiteralUtil::MultiIndexAsString(multi_index), StrCat(lhs), StrCat(rhs));
125 }
126 
127 template <>
MakeErrorStatus(bfloat16 lhs,bfloat16 rhs,absl::Span<const int64> multi_index)128 Status MakeErrorStatus(bfloat16 lhs, bfloat16 rhs,
129                        absl::Span<const int64> multi_index) {
130   return MakeBitwiseErrorStatus<bfloat16, uint16>(lhs, rhs, multi_index);
131 }
132 template <>
MakeErrorStatus(Eigen::half lhs,Eigen::half rhs,absl::Span<const int64> multi_index)133 Status MakeErrorStatus(Eigen::half lhs, Eigen::half rhs,
134                        absl::Span<const int64> multi_index) {
135   return MakeBitwiseErrorStatus<Eigen::half, uint16>(lhs, rhs, multi_index);
136 }
137 template <>
MakeErrorStatus(float lhs,float rhs,absl::Span<const int64> multi_index)138 Status MakeErrorStatus(float lhs, float rhs,
139                        absl::Span<const int64> multi_index) {
140   return MakeBitwiseErrorStatus<float, uint32>(lhs, rhs, multi_index);
141 }
142 template <>
MakeErrorStatus(double lhs,double rhs,absl::Span<const int64> multi_index)143 Status MakeErrorStatus(double lhs, double rhs,
144                        absl::Span<const int64> multi_index) {
145   return MakeBitwiseErrorStatus<double, uint64>(lhs, rhs, multi_index);
146 }
147 template <>
MakeErrorStatus(complex64 lhs,complex64 rhs,absl::Span<const int64> multi_index)148 Status MakeErrorStatus(complex64 lhs, complex64 rhs,
149                        absl::Span<const int64> multi_index) {
150   if (!CompareEqual<float>(lhs.real(), rhs.real(), multi_index)) {
151     return MakeErrorStatus(lhs.real(), rhs.real(), multi_index);
152   }
153   return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index);
154 }
155 template <>
MakeErrorStatus(complex128 lhs,complex128 rhs,absl::Span<const int64> multi_index)156 Status MakeErrorStatus(complex128 lhs, complex128 rhs,
157                        absl::Span<const int64> multi_index) {
158   if (!CompareEqual<double>(lhs.real(), rhs.real(), multi_index)) {
159     return MakeErrorStatus(lhs.real(), rhs.real(), multi_index);
160   }
161   return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index);
162 }
163 
164 // A recursive function which iterates through every index of expected and
165 // actual literal and compares their values elementwise. Returns true if all
166 // elements are equal. Mismatched must either be:
167 //    - a literal of booleans that has the same shape as expected and actual. In
168 //      this case, each index in mismatched will be set to true if expected does
169 //      not equal actual at that index and false if there are equal.
170 //    - nullptr. In this case, the function will return once any mismatch is
171 //      found between expected and actual.
172 template <typename NativeT>
Equal(LiteralSlice expected,LiteralSlice actual,absl::Span<int64> multi_index,int64_t dimension,Literal * mismatched=nullptr)173 Status Equal(LiteralSlice expected, LiteralSlice actual,
174              absl::Span<int64> multi_index, int64_t dimension,
175              Literal* mismatched = nullptr) {
176   if (dimension == expected.shape().dimensions_size()) {
177     NativeT expected_value = expected.Get<NativeT>(multi_index);
178     NativeT actual_value = actual.Get<NativeT>(multi_index);
179     bool result =
180         CompareEqual<NativeT>(expected_value, actual_value, multi_index);
181     if (mismatched) {
182       mismatched->Set<bool>(multi_index, !result);
183     }
184     return result ? Status::OK()
185                   : MakeErrorStatus<NativeT>(expected_value, actual_value,
186                                              multi_index);
187   }
188 
189   Status result;
190   for (int64_t i = 0; i < expected.shape().dimensions(dimension); ++i) {
191     multi_index[dimension] = i;
192     if (mismatched != nullptr) {
193       result.Update(Equal<NativeT>(expected, actual, multi_index, dimension + 1,
194                                    mismatched));
195     } else {
196       TF_RETURN_IF_ERROR(Equal<NativeT>(expected, actual, multi_index,
197                                         dimension + 1, mismatched));
198     }
199   }
200   return result;
201 }
202 
203 // Gets the total element count.  For tuples, this is not the count of tuple
204 // elements, but the sum of elements of each tuple element.
RecursiveElementCount(const Shape & shape)205 int64 RecursiveElementCount(const Shape& shape) {
206   if (shape.IsTuple()) {
207     const int64_t tuple_elements = ShapeUtil::TupleElementCount(shape);
208     int64_t total = 0;
209     for (int64_t i = 0; i < tuple_elements; ++i) {
210       total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i));
211     }
212     return total;
213   } else if (shape.IsArray()) {
214     return ShapeUtil::ElementsIn(shape);
215   } else {
216     return 0;
217   }
218 }
219 
220 // Returns whether the given value is infinity.
221 template <typename NativeT>
IsInf(NativeT val)222 bool IsInf(NativeT val) {
223   return Eigen::numext::isinf(val);
224 }
225 // Returns whether the given value is nan.
226 template <typename NativeT>
IsNan(NativeT value)227 bool IsNan(NativeT value) {
228   return Eigen::numext::isnan(value);
229 }
230 
231 // Converts the given floating-point value to a string.
FpValueToString(bfloat16 value)232 string FpValueToString(bfloat16 value) {
233   return absl::StrFormat("%10.4g", static_cast<double>(value));
234 }
235 
FpValueToString(half value)236 string FpValueToString(half value) {
237   return absl::StrFormat("%11.5g", static_cast<double>(value));
238 }
239 
FpValueToString(float value)240 string FpValueToString(float value) {
241   return absl::StrFormat("%15.9g", static_cast<double>(value));
242 }
243 
FpValueToString(double value)244 string FpValueToString(double value) {
245   return absl::StrFormat("%24.17g", value);
246 }
247 
FpValueToString(complex64 value)248 string FpValueToString(complex64 value) {
249   return absl::StrCat(FpValueToString(value.real()), " + ",
250                       FpValueToString(value.imag()));
251 }
252 
FpValueToString(complex128 value)253 string FpValueToString(complex128 value) {
254   return absl::StrCat(FpValueToString(value.real()), " + ",
255                       FpValueToString(value.imag()));
256 }
257 
// Absolute value as a double. The primary template forwards to std::abs;
// specializations exist for the types std::abs does not support, in
// particular bfloat16 and half.
template <typename NativeT>
double FpAbsoluteValue(NativeT value) {
  const double magnitude = std::abs(value);
  return magnitude;
}
264 
265 template <>
FpAbsoluteValue(bfloat16 value)266 double FpAbsoluteValue(bfloat16 value) {
267   return FpAbsoluteValue<float>(static_cast<float>(value));
268 }
269 
270 template <>
FpAbsoluteValue(half value)271 double FpAbsoluteValue(half value) {
272   return FpAbsoluteValue<float>(static_cast<float>(value));
273 }
274 
275 // Helper class for comparing floating-point literals within an error bound.
276 template <typename NativeT>
277 class NearComparator {
278  public:
279   // Compares the two array literals elementwise and returns a comparison
280   // result. The comparison is ok() if all actual and expected elements are
281   // within the given error bound. In case of error, the status contains a
282   // detailed message about the discrepancy.
Compare(const LiteralSlice & expected,const LiteralSlice & actual,const ShapeIndex & shape_index,ErrorSpec error,bool detailed_message,const MiscompareCallback & miscompare_callback)283   static Status Compare(const LiteralSlice& expected,
284                         const LiteralSlice& actual,
285                         const ShapeIndex& shape_index, ErrorSpec error,
286                         bool detailed_message,
287                         const MiscompareCallback& miscompare_callback) {
288     NearComparator<NativeT> comparator(expected, actual, shape_index, error,
289                                        detailed_message, miscompare_callback);
290     return comparator.Run();
291   }
292 
293  private:
294   // Data structure encapsulating metadata about a single element mismatch.
295   struct Mismatch {
296     NativeT actual;
297     NativeT expected;
298     double rel_error;
299     double abs_error;
300 
301     // The linear index of the failure within the shape. This linear index is
302     // from the 'actual' literal.
303     int64 linear_index;
304 
operator <xla::literal_comparison::__anon7c9aba5a0111::NearComparator::Mismatch305     bool operator<(const Mismatch& other) const {
306       return rel_error < other.rel_error;
307     }
308 
ToStringxla::literal_comparison::__anon7c9aba5a0111::NearComparator::Mismatch309     string ToString(const Shape& shape) const {
310       return absl::StrFormat(
311           "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g",
312           FpValueToString(actual), FpValueToString(expected),
313           LiteralUtil::MultiIndexAsString(
314               IndexUtil::LinearIndexToMultidimensionalIndex(shape,
315                                                             linear_index)),
316           rel_error, abs_error);
317     }
318   };
319 
NearComparator(const LiteralSlice & expected,const LiteralSlice & actual,const ShapeIndex & shape_index,ErrorSpec error,bool detailed_message,const MiscompareCallback & miscompare_callback)320   NearComparator(const LiteralSlice& expected, const LiteralSlice& actual,
321                  const ShapeIndex& shape_index, ErrorSpec error,
322                  bool detailed_message,
323                  const MiscompareCallback& miscompare_callback)
324       : expected_(expected),
325         actual_(actual),
326         shape_index_(shape_index),
327         error_(error),
328         detailed_message_(detailed_message),
329         miscompare_callback_(miscompare_callback),
330         abs_value_buckets_(kAbsValueBucketBounds.size() - 1, {0, 0}),
331         abs_error_buckets_(kErrorBucketBounds.size(), 0),
332         rel_error_buckets_(kErrorBucketBounds.size(), 0) {}
333 
334   // Runs the comparison between expected and actual literals.
Run()335   Status Run() {
336     // If the shapes mismatch, we simply fail the expectation instead of
337     // printing out data, as it's a type error rather than a value error.
338     TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape()));
339     if (!expected_.shape().IsArray()) {
340       return InvalidArgument("Expected array shape; got %s.",
341                              ShapeUtil::HumanString(expected_.shape()));
342     }
343 
344     mismatches_ = Literal(ShapeUtil::ChangeElementType(actual_.shape(), PRED));
345     mismatches_.PopulateWithValue(false);
346 
347     CompareLiterals();
348 
349     if (num_mismatches_ == 0) {
350       return Status::OK();
351     } else if (!VLOG_IS_ON(1) && miscompare_callback_ != nullptr) {
352       miscompare_callback_(expected_, actual_, mismatches_, shape_index_);
353     }
354     return InvalidArgument("%s", ErrorMessage());
355   }
356 
357   // Insert the given absolute value into the absolute value bucket vector. The
358   // bounds of the buckets are given by kAbsValueBucketBounds.
UpdateAbsValueBucket(NativeT value,bool is_mismatch)359   void UpdateAbsValueBucket(NativeT value, bool is_mismatch) {
360     // Adjust the bucket containing the absolute values of the 'actual'
361     // elements.
362     const double abs_value = FpAbsoluteValue(value);
363     for (int i = 0; i < abs_value_buckets_.size(); ++i) {
364       if (i == abs_value_buckets_.size() - 1 ||
365           (abs_value >= kAbsValueBucketBounds[i] &&
366            abs_value < kAbsValueBucketBounds[i + 1])) {
367         // The first value of the pair is the count of elements in the bucket,
368         // the second is the count of mismatches in the bucket.
369         abs_value_buckets_[i].first++;
370         if (is_mismatch) {
371           abs_value_buckets_[i].second++;
372         }
373         return;
374       }
375     }
376   }
377 
378   // Insert the given error into the given error bucket vector.
UpdateErrorBucket(double error,absl::Span<int64> error_buckets)379   void UpdateErrorBucket(double error, absl::Span<int64> error_buckets) {
380     CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size());
381     for (int i = 0; i < error_buckets.size(); ++i) {
382       if (error >= kErrorBucketBounds[i]) {
383         error_buckets[i]++;
384       }
385     }
386   }
387 
388   // Compares the two given elements from the expected and actual literals at
389   // the given literal_index and keeps track of various mismatch statistics.
390   template <typename T>
CompareValues(T expected,T actual,int64_t linear_index)391   void CompareValues(T expected, T actual, int64_t linear_index) {
392     double abs_error;
393     double rel_error;
394     if (CompareEqual<T>(expected, actual, {linear_index})) {
395       abs_error = 0;
396       rel_error = 0;
397     } else if (IsNan(expected) || IsNan(actual)) {
398       if ((!error_.relaxed_nans && IsNan(expected) != IsNan(actual)) ||
399           (error_.relaxed_nans && !IsNan(expected) && IsNan(actual))) {
400         num_nan_mismatches_++;
401         // A nan mismatch is considered to have infinite error. rel_error is
402         // used for sorting a std::set of the top mismatches, and a nan value
403         // here will result in undefined behavior because nan's do not satisfy
404         // the strict weak ordering requirement of std containers.
405         abs_error = std::numeric_limits<float>::infinity();
406         rel_error = std::numeric_limits<float>::infinity();
407       } else {
408         abs_error = 0;
409         rel_error = 0;
410       }
411     } else if (IsInf(actual) && !IsInf(expected) && error_.fewer_infs_ok) {
412       // `fewer_infs_ok` gives us the option of comparing as though `actual`
413       // were float_max/min rather than inf.
414       T actual_finite = actual > T{0} ? std::numeric_limits<T>::max()
415                                       : std::numeric_limits<T>::lowest();
416       abs_error = FpAbsoluteValue(actual_finite - expected);
417 
418       // Avoid division by 0 even though it's well-defined because ubsan can be
419       // configured to treat this as a fatal error.
420       if (expected != T{0}) {
421         rel_error = abs_error / FpAbsoluteValue(expected);
422       } else {
423         rel_error = std::numeric_limits<float>::infinity();
424       }
425     } else if (IsInf(expected) || IsInf(actual)) {
426       // If either the expected or actual value is infinity but not both,
427       // then both absolute and relative error are regarded as infinity.
428       CHECK(!CompareEqual(expected, actual, {linear_index}));
429       abs_error = std::numeric_limits<float>::infinity();
430       rel_error = std::numeric_limits<float>::infinity();
431     } else {
432       abs_error = FpAbsoluteValue(actual - expected);
433 
434       // Avoid division by 0 even though it's well-defined because ubsan can be
435       // configured to treat this as a fatal error.
436       if (expected != T{0}) {
437         rel_error = abs_error / FpAbsoluteValue(expected);
438       } else {
439         rel_error = std::numeric_limits<float>::infinity();
440       }
441     }
442     const bool is_abs_mismatch = abs_error > error_.abs;
443     const bool is_rel_mismatch = rel_error > error_.rel;
444     const bool is_mismatch = is_abs_mismatch && is_rel_mismatch;
445 
446     // Update the error of the relative bucket only if the *absolute* error
447     // bound is exceeded and vice versa.
448     if (is_abs_mismatch) {
449       num_abs_mismatches_++;
450       UpdateErrorBucket(rel_error, absl::MakeSpan(rel_error_buckets_));
451     }
452     if (is_rel_mismatch) {
453       num_rel_mismatches_++;
454       UpdateErrorBucket(abs_error, absl::MakeSpan(abs_error_buckets_));
455     }
456 
457     UpdateAbsValueBucket(actual, is_mismatch);
458 
459     if (!is_mismatch) {
460       return;
461     }
462 
463     num_mismatches_++;
464 
465     // Keep track of the kTopRelativeErrorCount relative error mismatches.
466     if (top_rel_mismatches_.size() < kTopRelativeErrorCount ||
467         rel_error > top_rel_mismatches_.begin()->rel_error) {
468       Mismatch mismatch = {actual, expected, rel_error, abs_error,
469                            linear_index};
470       top_rel_mismatches_.insert(mismatch);
471       if (top_rel_mismatches_.size() > kTopRelativeErrorCount) {
472         top_rel_mismatches_.erase(top_rel_mismatches_.begin());
473       }
474     }
475 
476     mismatches_.data<bool>()[linear_index] = true;
477   }
478 
479   // For complex types, we compare real and imaginary parts individually.
CompareValues(complex64 expected,complex64 actual,int64_t linear_index)480   void CompareValues(complex64 expected, complex64 actual,
481                      int64_t linear_index) {
482     const auto both_parts_mismatch = num_mismatches_ + 2;
483     CompareValues<float>(expected.real(), actual.real(), linear_index);
484     CompareValues<float>(expected.imag(), actual.imag(), linear_index);
485     if (num_mismatches_ == both_parts_mismatch) {
486       // The mismatch counter had been incremented by each CompareValues() call,
487       // which means that both real and imaginary parts of the passed-in complex
488       // values are different. However, the counter should reflect a single
489       // mismatch between these complex values.
490       num_mismatches_--;
491     }
492   }
493 
CompareValues(complex128 expected,complex128 actual,int64_t linear_index)494   void CompareValues(complex128 expected, complex128 actual,
495                      int64_t linear_index) {
496     const auto both_parts_mismatch = num_mismatches_ + 2;
497     CompareValues<double>(expected.real(), actual.real(), linear_index);
498     CompareValues<double>(expected.imag(), actual.imag(), linear_index);
499     if (num_mismatches_ == both_parts_mismatch) {
500       // The mismatch counter had been incremented by each CompareValues() call,
501       // which means that both real and imaginary parts of the passed-in complex
502       // values are different. However, the counter should reflect a single
503       // mismatch between these complex values.
504       num_mismatches_--;
505     }
506   }
507 
508   // Compares the two literals elementwise.
CompareLiterals()509   void CompareLiterals() {
510     // Fast path optimization for the case were layouts match.
511     if (LayoutUtil::Equal(actual_.shape().layout(),
512                           expected_.shape().layout())) {
513       absl::Span<const NativeT> expected_data = expected_.data<NativeT>();
514       absl::Span<const NativeT> actual_data = actual_.data<NativeT>();
515       const int64_t len = expected_data.size();
516       for (int64_t i = 0; i < len; ++i) {
517         CompareValues(expected_data[i], actual_data[i], i);
518       }
519       return;
520     }
521     std::vector<int64> multi_index(actual_.shape().rank(), 0);
522     CompareLiteralsSlow(0, &multi_index);
523   }
524 
525   // Slow path for CompareLiterals when 'actual' and 'expected' literals have
526   // different layouts. In this case, multidimensional indices are constructed
527   // and indexed for each element.
CompareLiteralsSlow(int64_t dimension,std::vector<int64> * multi_index)528   void CompareLiteralsSlow(int64_t dimension, std::vector<int64>* multi_index) {
529     if (dimension == multi_index->size()) {
530       CompareValues(expected_.Get<NativeT>(*multi_index),
531                     actual_.Get<NativeT>(*multi_index),
532                     IndexUtil::MultidimensionalIndexToLinearIndex(
533                         actual_.shape(), *multi_index));
534     } else {
535       for (int64_t i = 0; i < expected_.shape().dimensions(dimension); ++i) {
536         (*multi_index)[dimension] = i;
537         CompareLiteralsSlow(dimension + 1, multi_index);
538       }
539     }
540   }
541 
542   // Returns an error message string with a detailed breakdown of the
543   // mismatches. Called after calling Run().
ErrorMessage()544   string ErrorMessage() {
545     string out;
546     int64_t element_count = ShapeUtil::ElementsIn(actual_.shape());
547 
548     auto percent_string = [](float a, float b) {
549       float pct = b == 0.0 ? 0.0 : 100.0 * a / b;
550       return absl::StrFormat("%0.4f%%", pct);
551     };
552 
553     StrAppendFormat(
554         &out,
555         "\nMismatch count %d (%s) in shape %s (%d elements), abs bound "
556         "%g, rel bound %g\n",
557         num_mismatches_, percent_string(num_mismatches_, element_count),
558         ShapeUtil::HumanString(actual_.shape()),
559         ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel);
560     if (num_nan_mismatches_ > 0) {
561       StrAppend(&out, "nan mismatches ", num_nan_mismatches_, "\n");
562     }
563     StrAppendFormat(&out, "Top relative error mismatches:\n");
564     for (auto it = top_rel_mismatches_.rbegin();
565          it != top_rel_mismatches_.rend(); ++it) {
566       StrAppend(&out, "  ", it->ToString(actual_.shape()), "\n");
567     }
568 
569     if (!detailed_message_) {
570       return out;
571     }
572 
573     StrAppend(&out, "Absolute magnitude breakdown of actual values:\n");
574     CHECK_EQ(abs_value_buckets_.size() + 1, kAbsValueBucketBounds.size());
575     for (int i = 0; i < abs_value_buckets_.size(); ++i) {
576       const int64_t bucket_size = abs_value_buckets_[i].first;
577       const int64_t bucket_mismatches = abs_value_buckets_[i].second;
578       string mismatch_str =
579           bucket_mismatches > 0
580               ? absl::StrFormat(", mismatches %d", bucket_mismatches)
581               : "";
582       StrAppendFormat(&out, "  %-6g <= x < %-6g : %7d (%9s)%s\n",
583                       kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1],
584                       bucket_size, percent_string(bucket_size, element_count),
585                       mismatch_str);
586     }
587 
588     auto print_accum_buckets = [&](const string& header, int64_t total,
589                                    absl::Span<const int64> buckets) {
590       StrAppend(&out, header, ":\n");
591       StrAppendFormat(&out, "  <  %-6g : %7d (%s)\n", kErrorBucketBounds[0],
592                       total - buckets[0],
593                       percent_string(total - buckets[0], total));
594       CHECK_EQ(buckets.size(), kErrorBucketBounds.size());
595       for (int i = 0; i < kErrorBucketBounds.size(); ++i) {
596         StrAppendFormat(&out, "  >= %-6g : %7d (%s)\n", kErrorBucketBounds[i],
597                         buckets[i], percent_string(buckets[i], total));
598       }
599     };
600     StrAppendFormat(&out, "Elements exceeding abs error bound %g: %d (%s)\n",
601                     error_.abs, num_abs_mismatches_,
602                     percent_string(num_abs_mismatches_, element_count));
603     print_accum_buckets(
604         "Relative error breakdown of elements exceeding abs error bound",
605         num_abs_mismatches_, rel_error_buckets_);
606     StrAppendFormat(&out, "Elements exceeding rel error bound %g: %d (%s)\n",
607                     error_.rel, num_rel_mismatches_,
608                     percent_string(num_rel_mismatches_, element_count));
609     print_accum_buckets(
610         "Absolute error breakdown of elements exceeding rel error bound",
611         num_rel_mismatches_, abs_error_buckets_);
612     return out;
613   }
614 
615   // 'actual' and 'expected' literals being compared.
616   LiteralSlice expected_;
617   LiteralSlice actual_;
618 
619   // The shape index of the LiteralSlice that is being compared.
620   ShapeIndex shape_index_;
621 
622   // The error bounds of the comparison.
623   ErrorSpec error_;
624 
625   // Whether to include detailed breakdown of mismatches in the error message.
626   bool detailed_message_;
627 
628   // Callback to invoke on miscompare.
629   MiscompareCallback miscompare_callback_;
630 
631   // Number of element mismatches encountered so far.
632   int64 num_mismatches_ = 0;
633 
634   // Number of elements with a nan mismatch.
635   int64 num_nan_mismatches_ = 0;
636 
637   // Number of elements which exceed the absolute/relative error bound.
638   int64 num_abs_mismatches_ = 0;
639   int64 num_rel_mismatches_ = 0;
640 
641   // A Literal containing which elements did not match in the expected and
642   // actual literals. mismatches_ contains PREDs and is of the same sizes as
643   // the comparison literals.
644   Literal mismatches_;
645 
646   // The number of mismatches to report in the output, sorted by relative error
647   // magnitude.
648   static constexpr int64_t kTopRelativeErrorCount = 5;
649 
650   // The set of mismatches with the largest relative error. The size of this set
651   // is bounded by kTopRelativeErrorCount.
652   std::multiset<Mismatch> top_rel_mismatches_;
653 
654   // Actual values are bucketed by absolute value. kAbsValueBucketBounds is the
655   // bounds of these buckets. abs_value_buckets_ contains a pair for each
656   // bucket: the element count and failure count.
657   static constexpr std::array<float, 7> kAbsValueBucketBounds = {
658       0.0, 0.0001, 0.001, 0.01, 0.1, 1, std::numeric_limits<float>::infinity()};
659   std::vector<std::pair<int64, int64>> abs_value_buckets_;
660 
  // Buckets for relative and absolute errors. The relative error buckets only
  // contain those elements which exceed the *absolute* error bound, and vice
  // versa. This makes it easy to see the effect of adjusting the relative (or
  // absolute) error bound on the success of the comparison. kErrorBucketBounds
  // are the lower bounds of the buckets in both vectors. The error buckets are
  // a cumulative distribution: an error value is counted in every bucket whose
  // lower bound it reaches, so it may appear in more than one bucket. For
  // example an error value of 0.003 is counted in the buckets bounded by
  // 0.0001 and 0.001, but not in those bounded by 0.01, 0.1, and 1.
669   static constexpr std::array<float, 5> kErrorBucketBounds = {0.0001, 0.001,
670                                                               0.01, 0.1, 1};
671   std::vector<int64> abs_error_buckets_;
672   std::vector<int64> rel_error_buckets_;
673 };
674 
675 template <typename NativeT>
676 constexpr std::array<float, 7> NearComparator<NativeT>::kAbsValueBucketBounds;
677 template <typename NativeT>
678 constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
679 
EqualHelper(const LiteralSlice & expected,const LiteralSlice & actual,const ShapeIndex & shape_index,const MiscompareCallback & miscompare_callback)680 Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual,
681                    const ShapeIndex& shape_index,
682                    const MiscompareCallback& miscompare_callback) {
683   TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
684 
685   Status result;
686   if (expected.shape().IsTuple()) {
687     ShapeIndex next_index = shape_index;
688     for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
689       next_index.push_back(i);
690       Status tuple_result =
691           EqualHelper(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}),
692                       next_index, miscompare_callback);
693       if (miscompare_callback) {
694         result.Update(tuple_result);
695       } else {
696         TF_RETURN_IF_ERROR(tuple_result);
697       }
698       next_index.pop_back();
699     }
700   } else {
701     std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
702     auto index = absl::MakeSpan(multi_index);
703 
704     Shape unequal_shape = ShapeUtil::MakeShape(PrimitiveType::PRED,
705                                                expected.shape().dimensions());
706     Literal miscompared(unequal_shape);
707     Literal* miscompared_ptr =
708         (miscompare_callback == nullptr ? nullptr : &miscompared);
709 
710     switch (expected.shape().element_type()) {
711       case PRED:
712         result = Equal<bool>(expected, actual, index, 0, miscompared_ptr);
713         break;
714       case S8:
715         result = Equal<int8>(expected, actual, index, 0, miscompared_ptr);
716         break;
717       case S16:
718         result = Equal<int16>(expected, actual, index, 0, miscompared_ptr);
719         break;
720       case S32:
721         result = Equal<int32>(expected, actual, index, 0, miscompared_ptr);
722         break;
723       case S64:
724         result = Equal<int64>(expected, actual, index, 0, miscompared_ptr);
725         break;
726       case U8:
727         result = Equal<uint8>(expected, actual, index, 0, miscompared_ptr);
728         break;
729       case U16:
730         result = Equal<uint16>(expected, actual, index, 0, miscompared_ptr);
731         break;
732       case U32:
733         result = Equal<uint32>(expected, actual, index, 0, miscompared_ptr);
734         break;
735       case U64:
736         result = Equal<uint64>(expected, actual, index, 0, miscompared_ptr);
737         break;
738       case BF16:
739         result = Equal<bfloat16>(expected, actual, index, 0, miscompared_ptr);
740         break;
741       case F16:
742         result = Equal<half>(expected, actual, index, 0, miscompared_ptr);
743         break;
744       case F32:
745         result = Equal<float>(expected, actual, index, 0, miscompared_ptr);
746         break;
747       case F64:
748         result = Equal<double>(expected, actual, index, 0, miscompared_ptr);
749         break;
750       case C64:
751         result = Equal<complex64>(expected, actual, index, 0, miscompared_ptr);
752         break;
753       case C128:
754         result = Equal<complex128>(expected, actual, index, 0, miscompared_ptr);
755         break;
756       case TOKEN:
757         // Tokens have no on-device representation and are trivially equal.
758         return Status::OK();
759       default:
760         LOG(FATAL) << "Unsupported primitive type: "
761                    << PrimitiveType_Name(expected.shape().element_type());
762     }
763 
764     if (!result.ok() && miscompare_callback) {
765       miscompare_callback(expected, actual, LiteralSlice(miscompared),
766                           shape_index);
767     }
768   }
769 
770   return result;
771 }
772 
773 // Helper function for comparing two literals for nearness. Handles tuple-shapes
774 // via recursion. shape_index is the ShapeIndex of expected (or actual)
775 // currently being compared.
NearHelper(const LiteralSlice & expected,const LiteralSlice & actual,const ShapeIndex & shape_index,const ErrorSpec & error,absl::optional<bool> detailed_message,const MiscompareCallback & miscompare_callback)776 Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
777                   const ShapeIndex& shape_index, const ErrorSpec& error,
778                   absl::optional<bool> detailed_message,
779                   const MiscompareCallback& miscompare_callback) {
780   TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
781 
782   if (expected.shape().IsTuple()) {
783     Status return_status;
784     for (int64_t i = 0; i < ShapeUtil::TupleElementCount(expected.shape());
785          ++i) {
786       const auto expected_element = LiteralSlice(expected, {i});
787       const auto actual_element = LiteralSlice(actual, {i});
788       ShapeIndex element_index = shape_index;
789       element_index.push_back(i);
790       Status element_result =
791           NearHelper(expected_element, actual_element, element_index, error,
792                      detailed_message, miscompare_callback);
793       if (!element_result.ok()) {
794         element_result = InvalidArgument("Array at shape index %s, %s",
795                                          element_index.ToString(),
796                                          element_result.error_message());
797         if (return_status.ok()) {
798           return_status = element_result;
799         } else {
800           return_status =
801               AppendStatus(return_status, element_result.error_message());
802         }
803       }
804     }
805     if (!return_status.ok() && shape_index.empty()) {
806       // Emit a top-level error message containing the top-level shape in case
807       // of mismatch.
808       int64_t total_elements = RecursiveElementCount(actual.shape());
809       return_status =
810           InvalidArgument("\nMismatches in shape %s (%d elements):\n%s",
811                           ShapeUtil::HumanString(actual.shape()),
812                           total_elements, return_status.error_message());
813     }
814     return return_status;
815   }
816 
817   if (ShapeUtil::ElementIsFloating(expected.shape()) ||
818       ShapeUtil::ElementIsComplex(expected.shape())) {
819     bool use_detailed_message = detailed_message.value_or(
820         ShapeUtil::ElementsIn(expected.shape()) >= 64);
821     switch (expected.shape().element_type()) {
822       case BF16:
823         return NearComparator<bfloat16>::Compare(expected, actual, shape_index,
824                                                  error, use_detailed_message,
825                                                  miscompare_callback);
826         break;
827       case F16:
828         return NearComparator<half>::Compare(expected, actual, shape_index,
829                                              error, use_detailed_message,
830                                              miscompare_callback);
831         break;
832       case F32:
833         return NearComparator<float>::Compare(expected, actual, shape_index,
834                                               error, use_detailed_message,
835                                               miscompare_callback);
836         break;
837       case F64:
838         return NearComparator<double>::Compare(expected, actual, shape_index,
839                                                error, use_detailed_message,
840                                                miscompare_callback);
841         break;
842       case C64:
843         return NearComparator<complex64>::Compare(expected, actual, shape_index,
844                                                   error, use_detailed_message,
845                                                   miscompare_callback);
846         break;
847       case C128:
848         return NearComparator<complex128>::Compare(
849             expected, actual, shape_index, error, use_detailed_message,
850             miscompare_callback);
851         break;
852       default:
853         LOG(FATAL) << "Unsupported primitive type in near comparator: "
854                    << PrimitiveType_Name(expected.shape().element_type())
855                    << ". Must be floating-point type.";
856     }
857   }
858 
859   // Non-floating point, non-tuple literal.
860   return EqualHelper(expected, actual, shape_index, miscompare_callback);
861 }
862 
863 }  // namespace
864 
EqualShapes(const Shape & expected,const Shape & actual)865 Status EqualShapes(const Shape& expected, const Shape& actual) {
866   if (expected.element_type() != actual.element_type()) {
867     return InvalidArgument("element type mismatch, want: %s got %s",
868                            ShapeUtil::HumanString(expected),
869                            ShapeUtil::HumanString(actual));
870   }
871   if (expected.IsTuple()) {
872     if (ShapeUtil::TupleElementCount(expected) !=
873         ShapeUtil::TupleElementCount(actual)) {
874       return InvalidArgument(
875           "want tuple element count: %d got tuple element count: %d",
876           ShapeUtil::TupleElementCount(expected),
877           ShapeUtil::TupleElementCount(actual));
878     }
879     for (int i = 0; i < expected.tuple_shapes_size(); ++i) {
880       Status result =
881           EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i));
882       if (!result.ok()) {
883         return AppendStatus(result, StrCat("mismatch in tuple index", i));
884       }
885     }
886   } else if (expected.IsArray()) {
887     if (expected.rank() != actual.rank()) {
888       return InvalidArgument("want rank of %s got rank of %s",
889                              ShapeUtil::HumanString(expected),
890                              ShapeUtil::HumanString(actual));
891     }
892     if (expected.element_type() != actual.element_type()) {
893       return InvalidArgument("mismatch in primitive type %s vs %s",
894                              PrimitiveType_Name(expected.element_type()),
895                              PrimitiveType_Name(actual.element_type()));
896     }
897     if (expected.dimensions_size() != actual.dimensions_size()) {
898       return InvalidArgument("want dimensions_size %d got dimensions_size %d",
899                              expected.dimensions_size(),
900                              actual.dimensions_size());
901     }
902     for (int i = 0; i < expected.dimensions_size(); ++i) {
903       if (expected.dimensions(i) != actual.dimensions(i)) {
904         return InvalidArgument(
905             "mismatch in dimension #%d expected: %s actual: %s", i,
906             ShapeUtil::HumanString(expected), ShapeUtil::HumanString(actual));
907       }
908     }
909   }
910   // Non-array, non-tuple shapes are trivially equivalent.
911   return Status::OK();
912 }
913 
914 namespace {
915 
916 // If result is an error, extend the error message with the expected and actual
917 // literals.
EmitLiteralsInErrorMessage(const Status & result,const LiteralSlice & expected,const LiteralSlice & actual)918 Status EmitLiteralsInErrorMessage(const Status& result,
919                                   const LiteralSlice& expected,
920                                   const LiteralSlice& actual) {
921   if (result.ok()) {
922     return result;
923   }
924   return InvalidArgument("%s\n\nExpected literal:\n%s\n\nActual literal:\n%s",
925                          result.error_message(), ToStringTruncated(expected),
926                          ToStringTruncated(actual));
927 }
928 
929 }  // namespace
930 
Equal(const LiteralSlice & expected,const LiteralSlice & actual)931 Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
932   VLOG(1) << "expected:";
933   XLA_VLOG_LINES(1, expected.ToString());
934   VLOG(1) << "actual:";
935   XLA_VLOG_LINES(1, actual.ToString());
936   Status result = EqualHelper(expected, actual, {}, nullptr);
937   return EmitLiteralsInErrorMessage(result, expected, actual);
938 }
939 
Near(const LiteralSlice & expected,const LiteralSlice & actual,const ErrorSpec & error,absl::optional<bool> detailed_message,const MiscompareCallback & miscompare_callback)940 Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
941             const ErrorSpec& error, absl::optional<bool> detailed_message,
942             const MiscompareCallback& miscompare_callback) {
943   VLOG(1) << "Expected literal:";
944   XLA_VLOG_LINES(1, expected.ToString());
945   VLOG(1) << "Actual literal:";
946   XLA_VLOG_LINES(1, actual.ToString());
947   Status result = NearHelper(expected, actual, /*shape_index=*/{}, error,
948                              detailed_message, miscompare_callback);
949   return EmitLiteralsInErrorMessage(result, expected, actual);
950 }
951 
ToStringTruncated(const LiteralSlice & literal)952 string ToStringTruncated(const LiteralSlice& literal) {
953   return RecursiveElementCount(literal.shape()) < 1000
954              ? literal.ToString()
955              : "[TRUNCATED, Literal with more than 1000 values]";
956 }
957 
958 }  // namespace literal_comparison
959 }  // namespace xla
960