/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_

#include <array>
#include <cmath>
#include <iterator>

#include "tensorflow/compiler/xla/bit_cast.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"

namespace xla {
namespace exhaustive_op_test {

struct ErrorSpec {
  float abs_err;
  float rel_err;

  // If true, will consider -0 not near to +0 and vice versa.  Note that
  // +epsilon may still be considered close to -0, depending on the error
  // spec; this only covers the case when both `expected` and `actual` are
  // equal to 0.
  bool strict_signed_zeros = false;

  ErrorSpec(float a, float r) : abs_err(a), rel_err(r) {}
};
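
// Example (illustrative): a spec that tolerates a 0.1% absolute/relative
// error but treats a -0 result as a mismatch for an expected +0, and vice
// versa.
//
//   ErrorSpec spec(/*abs_err=*/0.001f, /*rel_err=*/0.001f);
//   spec.strict_signed_zeros = true;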

// Representations of the reference function passed in by the user.
template <typename NativeRefT, size_t K>
struct EvaluateOpWrapper {};
template <typename NativeRefT>
struct EvaluateOpWrapper<NativeRefT, 1> {
  using type = NativeRefT (*)(NativeRefT);
};
template <typename NativeRefT>
struct EvaluateOpWrapper<NativeRefT, 2> {
  using type = NativeRefT (*)(NativeRefT, NativeRefT);
};

// Representations of the enqueue function passed in by the user.
template <typename XlaInputs, size_t K>
struct EnqueueOpWrapper {};
template <typename XlaInputs>
struct EnqueueOpWrapper<XlaInputs, 1> {
  using type = std::function<XlaOp(XlaOp)>;
  static XlaOp BuildFromInputs(XlaInputs inputs, type ty) {
    return ty(inputs[0]);
  }
};
template <typename XlaInputs>
struct EnqueueOpWrapper<XlaInputs, 2> {
  using type = std::function<XlaOp(XlaOp, XlaOp)>;
  static XlaOp BuildFromInputs(XlaInputs inputs, type ty) {
    return ty(inputs[0], inputs[1]);
  }
};
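
// Example (illustrative): wrapping a unary op builder and applying it to a
// previously built input array `inputs` (a std::array<XlaOp, 1>):
//
//   EnqueueOpWrapper<std::array<XlaOp, 1>, 1>::type op =
//       [](XlaOp x) { return Exp(x); };
//   XlaOp root =
//       EnqueueOpWrapper<std::array<XlaOp, 1>, 1>::BuildFromInputs(inputs, op);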

// Representations of the ErrorSpecGen function passed in by the user.
template <PrimitiveType T, size_t K>
struct ErrorSpecGenWrapper {};
template <PrimitiveType T>
struct ErrorSpecGenWrapper<T, 1> {
  using NativeT = typename primitive_util::PrimitiveTypeToNative<T>::type;
  using type = ErrorSpec (*)(NativeT);
};
template <PrimitiveType T>
struct ErrorSpecGenWrapper<T, 2> {
  using NativeT = typename primitive_util::PrimitiveTypeToNative<T>::type;
  using type = ErrorSpec (*)(NativeT, NativeT);
};

template <PrimitiveType T, size_t N>
typename ErrorSpecGenWrapper<T, N>::type GetDefaultSpecGenerator();

// T: The primitive type being tested.
// N: The number of operands that the function being tested takes.
template <PrimitiveType T, size_t N>
class ExhaustiveOpTestBase : public ClientLibraryTestBase {
 public:
  // Definitions depending on the primitive type T.

  static constexpr bool kIsComplex = (T == C128 || T == C64);

  // The primitive type used to compute the reference output.
  struct RefT {
    static constexpr PrimitiveType value = (T == F16 || T == BF16) ? F32 : T;
  };

  // The primitive type of the component of T. If T is not complex, then
  // ComponentT = T.
  struct ComponentT {
    static constexpr PrimitiveType value =
        !kIsComplex ? T
                    : T == C128 ? F64 : T == C64 ? F32 : PRIMITIVE_TYPE_INVALID;
  };

  // Same as ComponentT, but for the RefT primitive type.
  struct ComponentRefT {
    static constexpr PrimitiveType value =
        !kIsComplex ? RefT::value
                    : RefT::value == C128
                          ? F64
                          : RefT::value == C64 ? F32 : PRIMITIVE_TYPE_INVALID;
  };

  // The primitive type of an unsigned integer that can be bitcasted to and
  // from ComponentT.
  struct ComponentIntegralT {
    static constexpr PrimitiveType value =
        (T == C128 || T == F64)
            ? U64
            : (T == C64 || T == F32)
                  ? U32
                  : (T == F16 || T == BF16) ? U16 : PRIMITIVE_TYPE_INVALID;
  };

  // Native types that correspond to the primitive types above.
  using NativeT = typename primitive_util::PrimitiveTypeToNative<T>::type;
  using NativeRefT =
      typename primitive_util::PrimitiveTypeToNative<RefT::value>::type;
  using ComponentNativeT =
      typename primitive_util::PrimitiveTypeToNative<ComponentT::value>::type;
  using ComponentNativeRefT = typename primitive_util::PrimitiveTypeToNative<
      ComponentRefT::value>::type;
  using ComponentIntegralNativeT =
      typename primitive_util::PrimitiveTypeToNative<
          ComponentIntegralT::value>::type;

  using InputLiterals = std::array<Literal, N>;

 private:
  // N spans corresponding to the list of literal data values.
  using NativeInputsList = std::array<absl::Span<const NativeT>, N>;

  // N data items representing a single input to the function under test.
  using NativeInputs = std::array<NativeT, N>;

  // N data items representing a single input to an interpreter backend
  // function.
  using NativeRefInputs = std::array<NativeRefT, N>;

  // N XlaOps representing a single input to an XLA computation.
  using XlaInputs = std::array<XlaOp, N>;

 public:
  using ErrorSpecGen = typename ErrorSpecGenWrapper<T, N>::type;
  using EvaluateOp = typename EvaluateOpWrapper<NativeRefT, N>::type;
  using EnqueueOp = typename EnqueueOpWrapper<XlaInputs, N>::type;

  explicit ExhaustiveOpTestBase()
      : ty_(T), platform_(client_->platform()->Name()) {
    SetFastMathDisabled(true);

    // Run all HLO passes.  In particular, constant folding is disabled by
    // default for tests, but we need to run it in order to tickle some bugs.
    mutable_debug_options()->clear_xla_disable_hlo_passes();
  }

  void Run(EnqueueOp enqueue_op, EvaluateOp evaluate_op) {
    Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator<T, N>());
  }

  // A helper for implementing the Run method for exhaustive op tests. It
  // constructs the HLO module, compiles and runs it, and checks the result.
  //
  // We use a function pointer for evaluate_op for performance, because it is
  // called each time an output element is compared, inside a loop in
  // ExpectNear.
  void Run(EnqueueOp enqueue_op, EvaluateOp evaluate_op,
           ErrorSpecGen error_spec_gen) {
    InputLiterals input_literals = CreateInputLiterals();
    FillInput(&input_literals);

    XlaBuilder builder(TestName());
    XlaInputs xla_inputs;
    for (int i = 0; i < N; ++i) {
      xla_inputs[i] =
          Parameter(&builder, i, input_literals[i].shape(), "input");
    }
    EnqueueOpWrapper<XlaInputs, N>::BuildFromInputs(xla_inputs, enqueue_op);

    TF_ASSERT_OK_AND_ASSIGN(XlaComputation comp, builder.Build());
    TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
                            RunComputationHelper(comp, input_literals));
    ExpectNear(input_literals, result_literal, evaluate_op, error_spec_gen);
  }
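
  // Example (illustrative sketch, assuming a hypothetical fixture that
  // implements GetInputSize() and FillInput()):
  //
  //   class ExhaustiveExpTest : public ExhaustiveUnaryTest<F32> { ... };
  //
  //   TEST_F(ExhaustiveExpTest, Exp) {
  //     Run([](XlaOp x) { return Exp(x); },
  //         +[](float x) { return std::exp(x); });
  //   }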

  StatusOr<Literal> RunComputationHelper(const XlaComputation& comp,
                                         const Literal& literal) {
    return RunComputation(comp, {&literal});
  }

  StatusOr<Literal> RunComputationHelper(
      const XlaComputation& comp, const std::array<Literal, N>& literals) {
    std::array<const Literal*, N> lit_ptrs;
    for (int i = 0; i < N; ++i) {
      lit_ptrs[i] = &literals[i];
    }
    return RunComputation(comp, lit_ptrs);
  }

  // We essentially reimplement LiteralTestUtil::Near here because
  //  a) this streamlined implementation is much faster,
  //  b) we can print out better error messages (namely, we can print out
  //     which floating-point input value failed, while LiteralTestUtil::Near
  //     can only print out the input index that failed), and
  //  c) we need special handling of certain inputs.  For example, we say that
  //     a denormal input has multiple correct outputs (namely, f(x) and f(0))
  //     and just needs to be close to one of them.
  void ExpectNear(const InputLiterals& input_literals,
                  const Literal& result_literal, EvaluateOp evaluate_op,
                  ErrorSpecGen error_spec_gen);

  // Builds and runs the computation using the LocalClient API, rather than the
  // plain Client API, which is used by ClientLibraryTestBase.  This is because
  // the plain Client API does more memcpys to/from Literals, and that's slow
  // given that we're touching a lot of data here.
  StatusOr<Literal> RunComputation(
      const XlaComputation& computation,
      absl::Span<const Literal* const> input_literals) {
    // Copy debug options from ClientLibraryTestBase.  In particular, we're
    // interested in disabling constant folding.
    ExecutableBuildOptions build_opts;
    *build_opts.mutable_debug_options() = *mutable_debug_options();

    std::vector<ScopedShapedBuffer> input_buffers;
    absl::c_transform(input_literals, std::back_inserter(input_buffers),
                      [&](const Literal* input_literal) {
                        return client_
                            ->LiteralToShapedBuffer(*input_literal,
                                                    /*device_ordinal=*/0)
                            .ConsumeValueOrDie();
                      });
    std::vector<const Shape*> input_shapes;
    absl::c_transform(input_buffers, std::back_inserter(input_shapes),
                      [&](const ScopedShapedBuffer& buffer) {
                        return &buffer.on_device_shape();
                      });

    TF_ASSIGN_OR_RETURN(
        auto executables,
        client_->Compile(computation, input_shapes, build_opts));

    std::vector<const ShapedBuffer*> input_buffer_pointers;
    absl::c_transform(
        input_buffers, std::back_inserter(input_buffer_pointers),
        [&](const ScopedShapedBuffer& buffer) { return &buffer; });

    ExecutableRunOptions run_opts;
    run_opts.set_allocator(client_->backend().memory_allocator());
    run_opts.set_intra_op_thread_pool(
        client_->backend().eigen_intra_op_thread_pool_device());
    TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
                        executables[0]->Run(input_buffer_pointers, run_opts));

    TF_ASSIGN_OR_RETURN(Literal result_literal,
                        client_->ShapedBufferToLiteral(result));
    return std::move(result_literal);
  }

  const string& Platform() { return platform_; }

  // Returns the number of elements in each input literal.
  virtual int64 GetInputSize() = 0;

  // Fills the literals with values to test for.
  virtual void FillInput(InputLiterals* literals) = 0;

  // Replaces infinities with the largest finite value to help compute errors.
  static ComponentNativeRefT ReplaceInfWithMax(ComponentNativeRefT value) {
    if (std::isinf(value)) {
      return std::copysign(std::numeric_limits<ComponentNativeRefT>::max(),
                           value);
    }
    return value;
  }

  // Returns true if both components are 0, but their sign bits differ.
  static bool CheckSignedZeroError(ComponentNativeRefT expected,
                                   ComponentNativeRefT actual) {
    return expected == 0 && actual == 0 &&
           std::signbit(expected) != std::signbit(actual);
  }

  // Sets the components to 0 if both are NaNs.
  static void RemoveCorrespondingNaNs(ComponentNativeRefT* expected,
                                      ComponentNativeRefT* actual) {
    if (std::isnan(*expected) && std::isnan(*actual)) {
      *expected = 0;
      *actual = 0;
    }
  }

  // Implementations of the functions above for complex inputs.

  static std::complex<ComponentNativeRefT> ReplaceInfWithMax(
      std::complex<ComponentNativeRefT> value) {
    value.real(ReplaceInfWithMax(value.real()));
    value.imag(ReplaceInfWithMax(value.imag()));
    return value;
  }

  static bool CheckSignedZeroError(std::complex<ComponentNativeRefT> expected,
                                   std::complex<ComponentNativeRefT> actual) {
    return CheckSignedZeroError(expected.real(), actual.real()) ||
           CheckSignedZeroError(expected.imag(), actual.imag());
  }

  static void RemoveCorrespondingNaNs(
      std::complex<ComponentNativeRefT>* expected,
      std::complex<ComponentNativeRefT>* actual) {
    ComponentNativeRefT expected_real = expected->real();
    ComponentNativeRefT expected_imag = expected->imag();
    ComponentNativeRefT actual_real = actual->real();
    ComponentNativeRefT actual_imag = actual->imag();
    RemoveCorrespondingNaNs(&expected_real, &actual_real);
    RemoveCorrespondingNaNs(&expected_imag, &actual_imag);
    expected->real(expected_real);
    expected->imag(expected_imag);
    actual->real(actual_real);
    actual->imag(actual_imag);
  }

  // Returns a list of inputs that should be tested for closeness given some
  // original input values.
  //
  // For denormal component inputs, we accept answers that are close to any of:
  //
  //   - evaluate_op(input),
  //   - evaluate_op(+/-0), where the sign of 0 is equal to the sign of
  //     `input`,
  //   - evaluate_op(+/-min_normal_float), where the sign of
  //     min_normal_float matches `input`, or
  //   - if relaxed_denormal_signs_, evaluate_op(-/+0), where the sign of
  //     0 is the opposite of `input`.
  //
  // (In particular, the XLA:CPU implementation of log flushes positive
  // denormals to min-normal-float.  This seems kind of reasonable if our
  // goal is to avoid infinities because they cause nans?)
  std::vector<ComponentNativeRefT> GetTestValuesWithSubnormalSubstitutions(
      ComponentNativeRefT value) {
    std::vector<ComponentNativeRefT> test_values;
    if (std::fpclassify(value) == FP_SUBNORMAL) {
      test_values.reserve(relaxed_denormal_signs_ ? 3 : 2);
      test_values.push_back(std::copysign(0, value));
      test_values.push_back(std::copysign(
          std::numeric_limits<ComponentNativeRefT>::min(), value));
      if (relaxed_denormal_signs_) {
        test_values.push_back(std::copysign(0, -value));
      }
    } else {
      test_values.push_back(value);
    }
    return test_values;
  }
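
  // Example (illustrative): if `value` is a positive float denormal and
  // relaxed_denormal_signs_ is true, the function above returns
  // {+0.0f, std::numeric_limits<float>::min(), -0.0f}.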

  // Similarly, for complex numbers we only need to substitute the components
  // that are subnormal. We find the subnormal testing values for each
  // component, then take the Cartesian product of the two sets of component
  // values.
  std::vector<std::complex<ComponentNativeRefT>>
  GetTestValuesWithSubnormalSubstitutions(
      std::complex<ComponentNativeRefT> value) {
    using complex = std::complex<ComponentNativeRefT>;

    auto real_values = GetTestValuesWithSubnormalSubstitutions(value.real());
    auto imag_values = GetTestValuesWithSubnormalSubstitutions(value.imag());

    std::vector<complex> test_values;
    test_values.reserve(real_values.size() * imag_values.size());
    for (auto real : real_values) {
      for (auto imag : imag_values) {
        test_values.push_back(complex(real, imag));
      }
    }

    return test_values;
  }

  // The test values for an XLA function with N operands are the Cartesian
  // product of the test values for each of the N operands.
  std::vector<std::array<NativeRefT, N>>
  GetTestValuesWithSubnormalSubstitutions(
      const std::array<NativeRefT, N>& value) {
    std::vector<std::array<NativeRefT, N>> test_values;

    std::array<std::vector<NativeRefT>, N> component_test_values;
    int total = 1;
    for (int i = 0; i < N; ++i) {
      component_test_values[i] =
          GetTestValuesWithSubnormalSubstitutions(value[i]);
      if (!component_test_values[i].empty()) {
        total *= component_test_values[i].size();
      }
    }

    // If total == 1, then value has no subnormal components, so we can just
    // return a vector with value in it.
    if (total == 1) {
      test_values.push_back(value);
      return test_values;
    }

    test_values.reserve(total);

    // Perform a Cartesian product of the vectors in component_test_values.
    // We can calculate this by uniquely mapping each integer i in
    // [0, total - 1] to a list of component indices. The function that maps
    // the integer i to the index of component j is
    //    component_index(j) = (i / NumValues(0, j-1)) % NumValues(j, j)
    // where NumValues(x, y) is the number of values in the Cartesian product
    // of component_test_values[x], component_test_values[x+1], ...,
    // component_test_values[y].
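    //
    // For example (illustrative), with N = 2 and component value counts
    // {2, 3}, i = 4 maps to
    //   component_index(0) = (4 / 1) % 2 = 0,
    //   component_index(1) = (4 / 2) % 3 = 2,
    // selecting (component_test_values[0][0], component_test_values[1][2]).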
    for (int i = 0; i < total; ++i) {
      int accumulated_num_values = 1;
      std::array<NativeRefT, N> test_value;
      for (int j = 0; j < N; ++j) {
        int num_indices = component_test_values[j].size();
        int component_index = (i / accumulated_num_values) % num_indices;
        test_value[j] = component_test_values[j][component_index];
        accumulated_num_values *= num_indices;
      }
      test_values.push_back(std::move(test_value));
    }
    return test_values;
  }

  InputLiterals CreateInputLiterals() {
    InputLiterals literals;
    for (int i = 0; i < N; ++i) {
      literals[i] = LiteralUtil::CreateFromDimensions(T, {GetInputSize()});
    }
    return std::move(literals);
  }

  // Determines if two output values are sufficiently close to each other based
  // on an error spec.
  bool IsClose(NativeRefT expected, NativeRefT actual, ErrorSpec spec) {
    // When two corresponding values are both NaN, they can be considered
    // equal, so they are simply set to 0.
    RemoveCorrespondingNaNs(&expected, &actual);

    if (spec.strict_signed_zeros) {
      if (CheckSignedZeroError(expected, actual)) {
        return false;
      }
    }

    // Replace Inf with Max when calculating absolute or relative errors. This
    // allows the test to pass when one of the values is close to Inf and the
    // specified absolute or relative errors are not zero.
    double abs_err =
        std::abs(ReplaceInfWithMax(expected) - ReplaceInfWithMax(actual));
    double rel_err = abs_err / std::abs(ReplaceInfWithMax(expected));

    return abs_err <= spec.abs_err || rel_err <= spec.rel_err;
  }
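
  // Worked example (illustrative): with expected = 1000.0f and
  // actual = 1001.0f, abs_err is 1.0 and rel_err is 0.001, so the pair
  // passes ErrorSpec(0.0001f, 0.001f) (the relative bound is met) but fails
  // ErrorSpec(0.0001f, 0.0001f).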

  // Converts part or all of the bits in a uint64 to the value of the floating
  // point data type being tested.
  //
  // When exhaustively testing an operation on data type T, we always use an
  // integral type I with the same number of bits as T to enumerate the input
  // bit patterns for T. Each bit pattern is zero extended and stored as a
  // uint64. This function converts such a bit pattern stored as a uint64 to
  // the input value for T.
  static ComponentNativeT ConvertValue(uint64 bits) {
    using I = ComponentIntegralNativeT;
    I used_bits = static_cast<I>(bits);
    return BitCast<ComponentNativeT>(used_bits);
  }

  ComponentNativeT ConvertAndReplaceKnownIncorrectValueWith(
      uint64 bits, int replacement_value = 0) {
    if (known_incorrect_fn_ && known_incorrect_fn_(bits)) {
      return static_cast<ComponentNativeT>(replacement_value);
    }
    return ConvertValue(bits);
  }

 protected:
  // The primitive type being tested.
  const PrimitiveType ty_;

  // The platform under test.
  const string platform_;

  // Testing will ignore inputs for which known_incorrect_fn_ returns true. The
  // argument to the function is the raw bits for the data being tested, zero
  // extended to 64 bits if the data type is less than 64 bits.
  std::function<bool(int64_t)> known_incorrect_fn_;

  // If true, allows denormals to be flushed to non-sign-preserving 0.
  //
  // For example, normally we'd expect sqrt(-denormal) to be either nan (sqrt
  // of a negative number) or -0 (flush the denormal to sign-preserving zero,
  // then sqrt(-0)).  But with this as true, we'll also accept +0 (flush the
  // denormal to non-sign-preserving 0, then sqrt(+0)).
  //
  // XLA:GPU preserves denormal signs, but other backends don't.
  bool relaxed_denormal_signs_ = platform_ != "CUDA";
};

// Represents a set of 64 bit chunks by storing the starting bit chunk, the
// last bit chunk, and the spacing between two adjacent bit chunks, without
// actually storing all the bit chunks being generated. The bit chunk iterator
// is provided to retrieve all the bit chunks.
//
// This data structure is used to generate the bit representations to test
// operations that require more than 64 bits of input data. In this case,
// truly exhaustive testing is not possible and we want to test a value every
// n values, where n == spacing_.
//
// Currently, the iterator of BitChunks adds `spacing_` to a bit chunk to
// compute the next bit chunk. We can change this to use values generated
// by a random number generator that can achieve the average spacing
// statistically, if we find this necessary.
class BitChunks {
 public:
  class iterator
      : public std::iterator<std::input_iterator_tag,  // iterator_category
                             uint64,                   // value_type
                             uint64,                   // difference_type
                             const uint64*,            // pointer
                             uint64                    // reference
                             > {
   public:
    iterator() {}

    explicit iterator(const BitChunks* bit_chunks)
        : bit_chunks_(bit_chunks), next_bit_chunk_(bit_chunks->start_) {}

    iterator& operator++() {
      Next();
      return *this;
    }

    iterator operator++(int) {
      iterator retval = *this;
      Next();
      return retval;
    }

    bool operator==(iterator other) const {
      return bit_chunks_ == other.bit_chunks_ &&
             next_bit_chunk_ == other.next_bit_chunk_;
    }

    bool operator!=(iterator other) const { return !(*this == other); }

    iterator MoveToEnd() {
      MoveNextBitChunkToOnePassEnd();
      return *this;
    }

    reference operator*() const {
      CHECK(*this != this->bit_chunks_->end());
      return next_bit_chunk_;
    }

    const BitChunks* GetBitChunks() const { return bit_chunks_; }

    void Reset() { next_bit_chunk_ = bit_chunks_->start_; }

    void Next() {
      CHECK(*this != this->bit_chunks_->end());
      if (next_bit_chunk_ == bit_chunks_->end_) {
        MoveNextBitChunkToOnePassEnd();
      } else {
        next_bit_chunk_ += bit_chunks_->spacing_;
        if (next_bit_chunk_ > bit_chunks_->end_) {
          next_bit_chunk_ = bit_chunks_->end_;
        }
      }
    }

    std::string ToString() const {
      return absl::StrFormat("0x%08x", next_bit_chunk_);
    }

   private:
    // Moves next_bit_chunk_ to one past bit_chunks_->end_, to mark that the
    // iterator has reached the end. When spacing_ is not one, or if we later
    // change Next() to use a random value instead of spacing_, normalizing
    // the end-of-iteration representation this way simplifies the check for
    // whether the iterator has ended.
    void MoveNextBitChunkToOnePassEnd() {
      next_bit_chunk_ = bit_chunks_->end_ + 1;
    }

    const BitChunks* bit_chunks_;
    uint64 next_bit_chunk_;
  };

  iterator begin() const { return iterator(this); }
  iterator end() const {
    iterator end(this);
    return end.MoveToEnd();
  }

  explicit BitChunks(uint64 start = 0, uint64 end = 0, uint64 spacing = 1)
      : start_(start), end_(end), spacing_(spacing) {
    CHECK_GE(end_, start_);
    CHECK_NE(spacing, 0) << ToString();
  }

  int64 GetTotalBitChunks() const {
    if (start_ == end_) {
      return 1;
    }

    return 1 + (end_ - start_ + spacing_ - 1) / spacing_;
  }

  std::string ToString() const {
    return absl::StrFormat("(0x%08x, 0x%08x, 0x%08x)", start_, end_, spacing_);
  }

  uint64 start_;
  uint64 end_;
  uint64 spacing_;
};
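
// Example (illustrative):
//
//   BitChunks chunks(/*start=*/0x0, /*end=*/0xa, /*spacing=*/4);
//   // Iterating yields 0x0, 0x4, 0x8, 0xa (the last step is clamped to
//   // end_), so chunks.GetTotalBitChunks() == 4.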

inline string StringifyNum(BitChunks c) { return c.ToString(); }

inline string StringifyNum(BitChunks::iterator c) { return c.ToString(); }

template <typename T>
void AppendStringifyNum(std::string* s, T x) {
  absl::StrAppend(s, StringifyNum(x));
}

// Represents a set of floating point values through the possible values for
// the three components: mantissa, exponent, and sign. Also implements an
// iterator for retrieving all the represented floating point values.
class FpValues {
 public:
  static constexpr uint kTotalBitChunks = 3;

  class iterator
      : public std::iterator<std::input_iterator_tag,  // iterator_category
                             uint64,                   // value_type
                             uint64,                   // difference_type
                             const uint64*,            // pointer
                             uint64                    // reference
                             > {
   public:
    explicit iterator(const FpValues* fp_values) : fp_values_(fp_values) {
      for (int i = 0; i < FpValues::kTotalBitChunks; ++i) {
        iters_[i] = BitChunks::iterator(&fp_values->GetBitChunks(i));
      }
    }

    iterator& operator++() {
      Next();
      return *this;
    }

    iterator operator++(int) {
      iterator retval = *this;
      Next();
      return retval;
    }

    bool operator==(iterator other) const {
      for (int i = 0; i < FpValues::kTotalBitChunks; ++i) {
        if (iters_[i] != other.GetBitChunksIter(i)) {
          return false;
        }
      }
      return true;
    }

    bool operator!=(iterator other) const { return !(*this == other); }

    iterator MoveToEnd() {
      for (int i = 0; i < FpValues::kTotalBitChunks; ++i) {
        iters_[i].MoveToEnd();
      }
      return *this;
    }

    uint64 operator*() const {
      uint64 value = 0;
      for (int i = 0; i < FpValues::kTotalBitChunks; ++i) {
        value = value | (*iters_[i]) << fp_values_->offsets_[i];
      }
      return value;
    }

    const BitChunks::iterator& GetBitChunksIter(int i) { return iters_[i]; }

    std::string ToString() const {
      return absl::StrJoin(iters_, ",",
                           AppendStringifyNum<BitChunks::iterator>);
    }

   private:
    // Moves the iterator for the ith BitChunks to the next value, and
    // returns true if the new state is not the end of the iterator.
    bool Next(int i = 0) {
      iters_[i].Next();
      if (iters_[i] == iters_[i].GetBitChunks()->end()) {
        if (i == FpValues::kTotalBitChunks - 1) {
          return false;
        }
        if (Next(i + 1)) {
          iters_[i].Reset();
          return true;
        }
        return false;
      }
      return true;
    }

    std::array<BitChunks::iterator, FpValues::kTotalBitChunks> iters_;
    const FpValues* fp_values_;
  };

  FpValues() : bit_chunks_(), offsets_() {}
  FpValues(absl::Span<const BitChunks> chunks, absl::Span<const int> offsets) {
    CHECK_EQ(chunks.size(), offsets.size() - 1);
    CHECK_EQ(chunks.size(), kTotalBitChunks);
    std::copy_n(chunks.begin(), kTotalBitChunks, bit_chunks_.begin());
    std::copy_n(offsets.begin(), kTotalBitChunks, offsets_.begin());

    // The last value in `offsets` is the total number of bits.
    offsets_[kTotalBitChunks] = offsets[kTotalBitChunks];
    // Validate the input values.
    for (int i = 0; i < kTotalBitChunks; ++i) {
      int total_bits = offsets[i + 1] - offsets[i];
      if (total_bits < 64) {
        uint64 bound = 1ull << total_bits;
        CHECK_LT(chunks[i].start_, bound);
        CHECK_LT(chunks[i].end_, bound);
      } else {
        CHECK_EQ(total_bits, 64);
      }
    }
  }

  iterator begin() const { return iterator(this); }

  iterator end() const {
    iterator end(this);
    return end.MoveToEnd();
  }

  int64 GetTotalNumValues() const {
    int64_t total = 1;
    absl::c_for_each(bit_chunks_, [&](const BitChunks& chunks) {
      total *= chunks.GetTotalBitChunks();
    });
    return total;
  }

  const BitChunks& GetBitChunks(int i) const { return bit_chunks_[i]; }

  std::string ToString() const {
    return absl::StrCat(
        "[", absl::StrJoin(bit_chunks_, ",", AppendStringifyNum<BitChunks>),
        "]");
  }

  std::array<BitChunks, kTotalBitChunks> bit_chunks_;
  std::array<int, kTotalBitChunks + 1> offsets_;
};

template <typename T, typename std::enable_if<
                          std::is_same<T, float>::value ||
                          std::is_same<T, double>::value>::type* = nullptr>
int GetMantissaTotalBits() {
  return std::numeric_limits<T>::digits - 1;
}

template <typename T>
int GetFpTotalBits() {
  return sizeof(T) * 8;
}

template <typename T>
int GetExponentTotalBits() {
  return GetFpTotalBits<T>() - GetMantissaTotalBits<T>() - 1;
}

template <typename T>
uint64 GetAllOneMantissa() {
  return (1ull << GetMantissaTotalBits<T>()) - 1ull;
}

template <typename T>
uint64 GetAllOneExponent() {
  return (1ull << GetExponentTotalBits<T>()) - 1ull;
}

template <typename T, typename std::enable_if<
                          std::is_same<T, float>::value ||
                          std::is_same<T, double>::value>::type* = nullptr>
FpValues GetFpValues(BitChunks mantissa, BitChunks exponent, BitChunks sign) {
  int total_bits = GetFpTotalBits<T>();
  return FpValues({mantissa, exponent, sign},
                  {0, GetMantissaTotalBits<T>(), total_bits - 1, total_bits});
}
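
// Illustrative note: for T = float the offsets above are {0, 23, 31, 32}, so
// each iterated value assembles as
//
//   bits = mantissa | (exponent << 23) | (sign << 31)
//
// which a test can then bitcast to the float under test (e.g. via
// ConvertValue above).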

template <typename T>
FpValues GetZeros() {
  return GetFpValues<T>(BitChunks(0, 0, 1), BitChunks(0, 0, 1),
                        BitChunks(0, 1, 1));
}

template <typename T>
FpValues GetSubnormals(int approx_num_values) {
  int mantissa = GetMantissaTotalBits<T>();
  uint64 mantissa_spacing = (1ull << mantissa) / (approx_num_values * 2);
  return GetFpValues<T>(
      BitChunks(0x1, GetAllOneMantissa<T>(), mantissa_spacing),
      BitChunks(0, 0, 1), BitChunks(0, 1, 1));
}

template <typename T>
FpValues GetInfinites() {
  uint64 all_one_exp = GetAllOneExponent<T>();
  return GetFpValues<T>(BitChunks(0, 0, 1),
                        BitChunks(all_one_exp, all_one_exp, 1),
                        BitChunks(0, 1, 1));
}

template <typename T>
FpValues GetNans(int approx_num_values) {
  int mantissa = GetMantissaTotalBits<T>();
  uint64 mantissa_spacing = (1ull << mantissa) / (approx_num_values * 2);
  uint64 all_one_exp = GetAllOneExponent<T>();
  return GetFpValues<T>(
      BitChunks(0x1, GetAllOneMantissa<T>(), mantissa_spacing),
      BitChunks(all_one_exp, all_one_exp, 1), BitChunks(0, 1, 1));
}

template <typename T>
FpValues GetNormals(int approx_num_values) {
  float component_total = std::sqrt(static_cast<float>(approx_num_values));
  return GetFpValues<T>(
      BitChunks(0x1, GetAllOneMantissa<T>(),
                (1ull << (GetMantissaTotalBits<T>() + 1)) / component_total),
      BitChunks(0x1, GetAllOneExponent<T>() - 1,
                (1ull << (GetExponentTotalBits<T>() + 1)) / component_total),
      BitChunks(0, 1, 1));
}

// Returns a vector of FpValues that together represent about
// `approx_num_values` floating point values of type `T`, with each FpValues
// representing about `num_values_per_group` floating point values.
template <typename T>
std::vector<FpValues> GetFpValuesWithExponents(uint64 first_exponent,
                                               uint64 exponent_spacing,
                                               uint64 num_exponents,
                                               uint64 approx_num_values,
                                               uint64 num_values_per_group) {
  const uint64 num_signs = 2;
  uint64 approx_num_mantissa = approx_num_values / (num_exponents * num_signs);
  uint64 num_mantissa_per_group =
      num_values_per_group / (num_exponents * num_signs);
  CHECK_GT(approx_num_mantissa, 0);
  CHECK_GT(num_mantissa_per_group, 0);

  CHECK_LT(first_exponent + num_exponents - 1ull, GetAllOneExponent<T>());
  int mantissa = GetMantissaTotalBits<T>();
  uint64 mantissa_spacing = (1ull << mantissa) / approx_num_mantissa;

  std::vector<FpValues> result;
  for (uint64 group_start = 0; group_start < GetAllOneMantissa<T>();
       group_start += mantissa_spacing * num_mantissa_per_group) {
    uint64 group_end =
        group_start + (num_mantissa_per_group - 1) * mantissa_spacing;
    if (group_end > GetAllOneMantissa<T>()) {
      group_end = GetAllOneMantissa<T>();
    }
    result.push_back(GetFpValues<T>(
        BitChunks(group_start, group_end, mantissa_spacing),
        BitChunks(first_exponent, first_exponent + num_exponents - 1, 1),
        BitChunks(0, 1, 1)));
  }
  return result;
}

// Returns a vector of FpValues that together represent about
// `approx_num_values` "very large" floating point values and
// `approx_num_values` "very small" floating point values of type `T`, with
// each FpValues representing about `num_values_per_group` floating point
// values. Because we use FpValues as a parameter for parameterized testing,
// the number of floating point values represented by each FpValues affects
// the input size for each sub-test and hence the peak memory usage of the
// test.
template <typename T>
std::vector<FpValues> GetFpValuesForMagnitudeExtremeNormals(
    uint64 approx_num_values = 40000, uint64 num_values_per_group = 4000) {
  std::vector<FpValues> large =
      GetFpValuesWithExponents<T>(GetAllOneExponent<T>() - 5, 1, 5,
                                  approx_num_values / 2, num_values_per_group);
  std::vector<FpValues> small = GetFpValuesWithExponents<T>(
      1, 1, 5, approx_num_values / 2, num_values_per_group);
  large.insert(large.end(), small.begin(), small.end());
  return large;
}

template <typename T>
std::vector<FpValues> CreateFpValuesForBoundaryTest() {
  return {GetZeros<T>(), GetSubnormals<T>(1000), GetInfinites<T>(),
          GetNans<T>(1000)};
}

inline std::vector<std::pair<int64, int64>> CreateExhaustiveF32Ranges() {
  // We break up the 2^32-element space into small'ish chunks to keep peak
  // memory usage low.
  std::vector<std::pair<int64, int64>> result;
  const int64_t step = 1 << 25;
  for (int64_t i = 0; i < (int64_t{1} << 32); i += step) {
    result.push_back({i, i + step});
  }
  return result;
}

template <PrimitiveType T, size_t N>
inline ErrorSpec DefaultSpecGenerator(
    typename ExhaustiveOpTestBase<T, N>::NativeT) {
  LOG(FATAL) << "Unhandled Type";
}

template <PrimitiveType T, size_t N>
inline ErrorSpec DefaultSpecGenerator(
    typename ExhaustiveOpTestBase<T, N>::NativeT,
    typename ExhaustiveOpTestBase<T, N>::NativeT) {
  LOG(FATAL) << "Unhandled Type";
}

template <>
inline ErrorSpec DefaultSpecGenerator<C128, 1>(complex128) {
  return ErrorSpec{0.0001, 0.0001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<C64, 1>(complex64) {
  return ErrorSpec{0.0001, 0.0001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<F64, 1>(double) {
  return ErrorSpec{0.0001, 0.0001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<F32, 1>(float) {
  return ErrorSpec{0.0001, 0.0001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<F16, 1>(Eigen::half) {
  return ErrorSpec{0.001, 0.001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<BF16, 1>(bfloat16) {
  return ErrorSpec{0.002, 0.02};
}

template <>
inline ErrorSpec DefaultSpecGenerator<F64, 2>(double, double) {
  return ErrorSpec{0.001, 0.001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<F32, 2>(float, float) {
  return ErrorSpec{0.001, 0.001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<F16, 2>(Eigen::half, Eigen::half) {
  return ErrorSpec{0.001, 0.001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<BF16, 2>(bfloat16, bfloat16) {
  return ErrorSpec{0.002, 0.02};
}

template <PrimitiveType T, size_t N>
typename ErrorSpecGenWrapper<T, N>::type GetDefaultSpecGenerator() {
  return DefaultSpecGenerator<T, N>;
}

template <typename T, typename std::enable_if<
                          std::is_same<T, float>::value ||
                          std::is_same<T, double>::value>::type* = nullptr>
T ReferenceMax(T x, T y) {
  // We need to propagate NAN here because std::max may not propagate NAN.
  if (std::fpclassify(x) == FP_NAN) {
    return x;
  }
  if (std::fpclassify(y) == FP_NAN) {
    return y;
  }

  return std::max<T>(x, y);
}

template <typename T, typename std::enable_if<
                          std::is_same<T, float>::value ||
                          std::is_same<T, double>::value>::type* = nullptr>
T ReferenceMin(T x, T y) {
  // We need to propagate NAN here because std::min may not propagate NAN.
  if (std::fpclassify(x) == FP_NAN) {
    return x;
  }
  if (std::fpclassify(y) == FP_NAN) {
    return y;
  }

  return std::min<T>(x, y);
}

// Returns a wrapper of the given build method, which builds an HLO operation
// with an empty broadcast dimension.
inline std::function<XlaOp(XlaOp, XlaOp)> AddEmptyBroadcastDimension(
    std::function<XlaOp(XlaOp, XlaOp, absl::Span<const int64>)> build_method) {
  // Capture build_method by value: the returned wrapper outlives this call.
  return [build_method](XlaOp src0, XlaOp src1) -> XlaOp {
    return build_method(src0, src1, {});
  };
}
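
// Example (illustrative): adapting a broadcasting binary op builder, such as
// Add(XlaOp, XlaOp, absl::Span<const int64>), for a binary exhaustive test:
//
//   auto add = AddEmptyBroadcastDimension(Add);
//   // add(x, y) enqueues Add(x, y, /*broadcast_dimensions=*/{}).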

template <PrimitiveType T>
class ExhaustiveUnaryTest : public ExhaustiveOpTestBase<T, 1> {
 public:
  using typename ExhaustiveOpTestBase<T, 1>::ErrorSpecGen;
  static ErrorSpecGen GetDefaultSpecGenerator() {
    return exhaustive_op_test::GetDefaultSpecGenerator<T, 1>();
  }
};

template <PrimitiveType T>
using ExhaustiveBinaryTest = ExhaustiveOpTestBase<T, 2>;

}  // namespace exhaustive_op_test
}  // namespace xla
#endif  // TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_