/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_

#include <array>
#include <cmath>
#include <iterator>

#include "tensorflow/compiler/xla/bit_cast.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"

namespace xla {
namespace exhaustive_op_test {

struct ErrorSpec {
  float abs_err;
  float rel_err;

  // If true, will consider -0 not near to +0 and vice versa. Note that
  // +epsilon may still be considered close to -0, depending on the error
  // spec; this only covers the case when both `expected` and `actual` are
  // equal to 0.
  bool strict_signed_zeros = false;

  ErrorSpec(float a, float r) : abs_err(a), rel_err(r) {}
};

// Representations of the reference function passed in by the user.
template <typename NativeRefT, size_t K>
struct EvaluateOpWrapper {};
template <typename NativeRefT>
struct EvaluateOpWrapper<NativeRefT, 1> {
  using type = NativeRefT (*)(NativeRefT);
};
template <typename NativeRefT>
struct EvaluateOpWrapper<NativeRefT, 2> {
  using type = NativeRefT (*)(NativeRefT, NativeRefT);
};

// Representations of the XLA op builder function passed in by the user.
template <typename XlaInputs, size_t K>
struct EnqueueOpWrapper {};
template <typename XlaInputs>
struct EnqueueOpWrapper<XlaInputs, 1> {
  using type = std::function<XlaOp(XlaOp)>;
  static XlaOp BuildFromInputs(XlaInputs inputs, type ty) {
    return ty(inputs[0]);
  }
};
template <typename XlaInputs>
struct EnqueueOpWrapper<XlaInputs, 2> {
  using type = std::function<XlaOp(XlaOp, XlaOp)>;
  static XlaOp BuildFromInputs(XlaInputs inputs, type ty) {
    return ty(inputs[0], inputs[1]);
  }
};

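// As a sketch of how the two wrappers line up for a unary op (hypothetical
// usage, not part of this header): the reference function must be a plain
// function pointer, while the XLA builder function may be any callable.
//
//   EvaluateOpWrapper<float, 1>::type evaluate_op = std::log;   // reference
//   EnqueueOpWrapper<std::array<XlaOp, 1>, 1>::type enqueue_op = Log;  // XLA
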
// Representations of the ErrorSpecGen function passed in by the user.
template <PrimitiveType T, size_t K>
struct ErrorSpecGenWrapper {};
template <PrimitiveType T>
struct ErrorSpecGenWrapper<T, 1> {
  using NativeT = typename primitive_util::PrimitiveTypeToNative<T>::type;
  using type = ErrorSpec (*)(NativeT);
};
template <PrimitiveType T>
struct ErrorSpecGenWrapper<T, 2> {
  using NativeT = typename primitive_util::PrimitiveTypeToNative<T>::type;
  using type = ErrorSpec (*)(NativeT, NativeT);
};

template <PrimitiveType T, size_t N>
typename ErrorSpecGenWrapper<T, N>::type GetDefaultSpecGenerator();

// T: The primitive type being tested.
// N: The number of operands that the function being tested takes.
template <PrimitiveType T, size_t N>
class ExhaustiveOpTestBase : public ClientLibraryTestBase {
 public:
  // Definitions depending on the primitive type T.

  static constexpr bool kIsComplex = (T == C128 || T == C64);

  // The primitive type used to compute the reference output.
  struct RefT {
    static constexpr PrimitiveType value = (T == F16 || T == BF16) ? F32 : T;
  };

  // The primitive type of the component of T. If T is not complex, then
  // ComponentT = T.
  struct ComponentT {
    static constexpr PrimitiveType value =
        !kIsComplex ? T
                    : T == C128 ? F64 : T == C64 ? F32 : PRIMITIVE_TYPE_INVALID;
  };

  // Same as ComponentT, but for the RefT primitive type.
  struct ComponentRefT {
    static constexpr PrimitiveType value =
        !kIsComplex ? RefT::value
                    : RefT::value == C128
                          ? F64
                          : RefT::value == C64 ? F32 : PRIMITIVE_TYPE_INVALID;
  };

  // The primitive type of an unsigned integer that can be bitcast to and from
  // ComponentT.
  struct ComponentIntegralT {
    static constexpr PrimitiveType value =
        (T == C128 || T == F64)
            ? U64
            : (T == C64 || T == F32)
                  ? U32
                  : (T == F16 || T == BF16) ? U16 : PRIMITIVE_TYPE_INVALID;
  };

  // Native types that correspond to the primitive types above.
  using NativeT = typename primitive_util::PrimitiveTypeToNative<T>::type;
  using NativeRefT =
      typename primitive_util::PrimitiveTypeToNative<RefT::value>::type;
  using ComponentNativeT =
      typename primitive_util::PrimitiveTypeToNative<ComponentT::value>::type;
  using ComponentNativeRefT = typename primitive_util::PrimitiveTypeToNative<
      ComponentRefT::value>::type;
  using ComponentIntegralNativeT =
      typename primitive_util::PrimitiveTypeToNative<
          ComponentIntegralT::value>::type;

  using InputLiterals = std::array<Literal, N>;

 private:
  // N spans corresponding to the list of literal data values.
  using NativeInputsList = std::array<absl::Span<const NativeT>, N>;

  // N data items representing a single input to an XLA function.
  using NativeInputs = std::array<NativeT, N>;

  // N data items representing a single input to an interpreter backend
  // function.
  using NativeRefInputs = std::array<NativeRefT, N>;

  // N XlaOps representing the parameters of the XLA function under test.
  using XlaInputs = std::array<XlaOp, N>;

 public:
  using ErrorSpecGen = typename ErrorSpecGenWrapper<T, N>::type;
  using EvaluateOp = typename EvaluateOpWrapper<NativeRefT, N>::type;
  using EnqueueOp = typename EnqueueOpWrapper<XlaInputs, N>::type;

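  // As a concrete reading of the type machinery above (a worked example, not
  // used by the code): for T = C64 we get NativeT = complex64,
  // NativeRefT = complex64, ComponentNativeT = float, and
  // ComponentIntegralNativeT = uint32; for T = F16 we get NativeT =
  // Eigen::half, NativeRefT = float (the reference is computed at higher
  // precision), and ComponentIntegralNativeT = uint16.
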
  explicit ExhaustiveOpTestBase()
      : ty_(T), platform_(client_->platform()->Name()) {
    SetFastMathDisabled(true);

    // Run all HLO passes. In particular, constant folding is disabled by
    // default for tests, but we need to run it in order to tickle some bugs.
    mutable_debug_options()->clear_xla_disable_hlo_passes();
  }

  void Run(EnqueueOp enqueue_op, EvaluateOp evaluate_op) {
    Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator<T, N>());
  }

  // A helper for implementing the Run method for exhaustive op tests. It
  // constructs the HLO module, compiles and runs the module, and checks the
  // result.
  //
  // We use a function pointer for evaluate_op for performance, because it is
  // called once per output element inside a loop in ExpectNear.
  void Run(EnqueueOp enqueue_op, EvaluateOp evaluate_op,
           ErrorSpecGen error_spec_gen) {
    InputLiterals input_literals = CreateInputLiterals();
    FillInput(&input_literals);

    XlaBuilder builder(TestName());
    XlaInputs xla_inputs;
    for (int i = 0; i < N; ++i) {
      xla_inputs[i] =
          Parameter(&builder, i, input_literals[i].shape(), "input");
    }
    EnqueueOpWrapper<XlaInputs, N>::BuildFromInputs(xla_inputs, enqueue_op);

    TF_ASSERT_OK_AND_ASSIGN(XlaComputation comp, builder.Build());
    TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
                            RunComputationHelper(comp, input_literals));
    ExpectNear(input_literals, result_literal, evaluate_op, error_spec_gen);
  }

  StatusOr<Literal> RunComputationHelper(const XlaComputation& comp,
                                         const Literal& literal) {
    return RunComputation(comp, {&literal});
  }

  StatusOr<Literal> RunComputationHelper(
      const XlaComputation& comp, const std::array<Literal, N>& literals) {
    std::array<const Literal*, N> lit_ptrs;
    for (int i = 0; i < N; ++i) {
      lit_ptrs[i] = &literals[i];
    }
    return RunComputation(comp, lit_ptrs);
  }

  // We essentially reimplement LiteralTestUtil::Near here because
  // a) this streamlined implementation is much faster;
  // b) we can print better error messages (namely, we can print which
  //    floating-point input failed, while LiteralTestUtil::Near can only
  //    print the input index that failed); and
  // c) we need special handling of certain inputs. For example, we say that
  //    a denormal input has multiple correct outputs (namely, f(x) and f(0))
  //    and just needs to be close to one of them.
  void ExpectNear(const InputLiterals& input_literals,
                  const Literal& result_literal, EvaluateOp evaluate_op,
                  ErrorSpecGen error_spec_gen);

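  // A typical test body then looks something like the following (a
  // hypothetical sketch; `ExhaustiveF32UnaryTest` stands for a subclass that
  // implements GetInputSize/FillInput over all f32 bit patterns):
  //
  //   XLA_TEST_P(ExhaustiveF32UnaryTest, Log) { Run(Log, std::log); }
  //
  // where `Log` builds the XLA op and `std::log` is the host reference.
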
  // Builds and runs the computation using the LocalClient API rather than the
  // plain Client API used by ClientLibraryTestBase. This is because the plain
  // Client API does more memcpys to/from Literals, and that's slow given that
  // we're touching a lot of data here.
  StatusOr<Literal> RunComputation(
      const XlaComputation& computation,
      absl::Span<const Literal* const> input_literals) {
    // Copy debug options from ClientLibraryTestBase. In particular, we're
    // interested in disabling constant folding.
    ExecutableBuildOptions build_opts;
    *build_opts.mutable_debug_options() = *mutable_debug_options();

    std::vector<ScopedShapedBuffer> input_buffers;
    absl::c_transform(input_literals, std::back_inserter(input_buffers),
                      [&](const Literal* input_literal) {
                        return client_
                            ->LiteralToShapedBuffer(*input_literal,
                                                    /*device_ordinal=*/0)
                            .ConsumeValueOrDie();
                      });
    std::vector<const Shape*> input_shapes;
    absl::c_transform(input_buffers, std::back_inserter(input_shapes),
                      [&](const ScopedShapedBuffer& buffer) {
                        return &buffer.on_device_shape();
                      });

    TF_ASSIGN_OR_RETURN(
        auto executables,
        client_->Compile(computation, input_shapes, build_opts));

    std::vector<const ShapedBuffer*> input_buffer_pointers;
    absl::c_transform(
        input_buffers, std::back_inserter(input_buffer_pointers),
        [&](const ScopedShapedBuffer& buffer) { return &buffer; });

    ExecutableRunOptions run_opts;
    run_opts.set_allocator(client_->backend().memory_allocator());
    run_opts.set_intra_op_thread_pool(
        client_->backend().eigen_intra_op_thread_pool_device());
    TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
                        executables[0]->Run(input_buffer_pointers, run_opts));

    TF_ASSIGN_OR_RETURN(Literal result_literal,
                        client_->ShapedBufferToLiteral(result));
    return std::move(result_literal);
  }

  const string& Platform() { return platform_; }

  // Returns the number of elements in each input literal.
  virtual int64 GetInputSize() = 0;

  // Fills the literals with values to test for.
  virtual void FillInput(InputLiterals* literals) = 0;

  // Replaces infinities with the largest finite value of the same sign, to
  // help compute absolute and relative errors.
  static ComponentNativeRefT ReplaceInfWithMax(ComponentNativeRefT value) {
    if (std::isinf(value)) {
      return std::copysign(std::numeric_limits<ComponentNativeRefT>::max(),
                           value);
    }
    return value;
  }

  // Returns true if both components are 0, but their sign bits differ.
  static bool CheckSignedZeroError(ComponentNativeRefT expected,
                                   ComponentNativeRefT actual) {
    return expected == 0 && actual == 0 &&
           std::signbit(expected) != std::signbit(actual);
  }

  // Sets both components to 0 if both are NaNs.
  static void RemoveCorrespondingNaNs(ComponentNativeRefT* expected,
                                      ComponentNativeRefT* actual) {
    if (std::isnan(*expected) && std::isnan(*actual)) {
      *expected = 0;
      *actual = 0;
    }
  }

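  // A few worked values for the helpers above (illustrative only):
  //
  //   CheckSignedZeroError(+0.0f, -0.0f);  // true: both zero, signs differ.
  //   CheckSignedZeroError(+0.0f, +0.0f);  // false.
  //   ReplaceInfWithMax(-INFINITY);  // -std::numeric_limits<float>::max()
  //                                  // for float inputs.
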
  // Overloads of the functions above for complex inputs.

  static std::complex<ComponentNativeRefT> ReplaceInfWithMax(
      std::complex<ComponentNativeRefT> value) {
    value.real(ReplaceInfWithMax(value.real()));
    value.imag(ReplaceInfWithMax(value.imag()));
    return value;
  }

  static bool CheckSignedZeroError(std::complex<ComponentNativeRefT> expected,
                                   std::complex<ComponentNativeRefT> actual) {
    return CheckSignedZeroError(expected.real(), actual.real()) ||
           CheckSignedZeroError(expected.imag(), actual.imag());
  }

  static void RemoveCorrespondingNaNs(
      std::complex<ComponentNativeRefT>* expected,
      std::complex<ComponentNativeRefT>* actual) {
    ComponentNativeRefT expected_real = expected->real();
    ComponentNativeRefT expected_imag = expected->imag();
    ComponentNativeRefT actual_real = actual->real();
    ComponentNativeRefT actual_imag = actual->imag();
    RemoveCorrespondingNaNs(&expected_real, &actual_real);
    RemoveCorrespondingNaNs(&expected_imag, &actual_imag);
    expected->real(expected_real);
    expected->imag(expected_imag);
    actual->real(actual_real);
    actual->imag(actual_imag);
  }

  // Returns a list of inputs that should be tested for closeness given some
  // original input values.
  //
  // For denormal component inputs, we accept answers that are close to any of:
  //
  //   - evaluate_op(input),
  //   - evaluate_op(+/-0), where the sign of 0 matches the sign of `input`,
  //   - evaluate_op(+/-min_normal_float), where the sign of min_normal_float
  //     matches the sign of `input`, and
  //   - if relaxed_denormal_signs_, evaluate_op(-/+0), where the sign of 0 is
  //     the opposite of `input`'s.
  //
  // (In particular, the XLA:CPU implementation of log flushes positive
  // denormals to min-normal-float. This seems kind of reasonable if our
  // goal is to avoid infinities because they cause NaNs?)
  std::vector<ComponentNativeRefT> GetTestValuesWithSubnormalSubstitutions(
      ComponentNativeRefT value) {
    std::vector<ComponentNativeRefT> test_values;
    if (std::fpclassify(value) == FP_SUBNORMAL) {
      test_values.reserve(relaxed_denormal_signs_ ? 3 : 2);
      test_values.push_back(std::copysign(0, value));
      test_values.push_back(std::copysign(
          std::numeric_limits<ComponentNativeRefT>::min(), value));
      if (relaxed_denormal_signs_) {
        test_values.push_back(std::copysign(0, -value));
      }
    } else {
      test_values.push_back(value);
    }
    return test_values;
  }

  // For complex numbers, we only need to substitute the components that are
  // subnormal. We find the subnormal testing values for each component, then
  // take the Cartesian product of the two sets of component values.
  std::vector<std::complex<ComponentNativeRefT>>
  GetTestValuesWithSubnormalSubstitutions(
      std::complex<ComponentNativeRefT> value) {
    using complex = std::complex<ComponentNativeRefT>;

    auto real_values = GetTestValuesWithSubnormalSubstitutions(value.real());
    auto imag_values = GetTestValuesWithSubnormalSubstitutions(value.imag());

    std::vector<complex> test_values;
    test_values.reserve(real_values.size() * imag_values.size());
    for (auto real : real_values) {
      for (auto imag : imag_values) {
        test_values.push_back(complex(real, imag));
      }
    }

    return test_values;
  }

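  // A worked case for the scalar overload above (assuming
  // relaxed_denormal_signs_ is true): for a positive subnormal float input,
  // the returned substitutes are, in order,
  //
  //   {+0.0f, std::numeric_limits<float>::min(), -0.0f}
  //
  // while any normal, zero, infinite, or NaN input is returned unchanged as
  // a singleton list.
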
  // The test values for an XLA function with N operands are the Cartesian
  // product of the test values for each of the N operands.
  std::vector<std::array<NativeRefT, N>>
  GetTestValuesWithSubnormalSubstitutions(
      const std::array<NativeRefT, N>& value) {
    std::vector<std::array<NativeRefT, N>> test_values;

    std::array<std::vector<NativeRefT>, N> component_test_values;
    int total = 1;
    for (int i = 0; i < N; ++i) {
      component_test_values[i] =
          GetTestValuesWithSubnormalSubstitutions(value[i]);
      if (!component_test_values[i].empty()) {
        total *= component_test_values[i].size();
      }
    }

    // If total == 1, then value has no subnormal components, so we can just
    // return a vector with value in it.
    if (total == 1) {
      test_values.push_back(value);
      return test_values;
    }

    test_values.reserve(total);

    // Perform a Cartesian product of the vectors in component_test_values.
    // We can calculate this by uniquely mapping each integer from 0 to
    // (total - 1) to a list of component indices. The function that maps an
    // integer i to the index of component j is:
    //   component_index(j) = (i / NumValues(0, j-1)) % NumValues(j, j)
    // where NumValues(x, y) is the number of values in the Cartesian product
    // of component_test_values[x], component_test_values[x+1], ...,
    // component_test_values[y].
    for (int i = 0; i < total; ++i) {
      int accumulated_num_values = 1;
      std::array<NativeRefT, N> test_value;
      for (int j = 0; j < N; ++j) {
        int num_indices = component_test_values[j].size();
        int component_index = (i / accumulated_num_values) % num_indices;
        test_value[j] = component_test_values[j][component_index];
        accumulated_num_values *= num_indices;
      }
      test_values.push_back(std::move(test_value));
    }
    return test_values;
  }

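  // A worked instance of the index mapping above: with N = 2 and component
  // value counts {3, 2} (so total = 6), the flat index i = 4 decodes to
  //
  //   j = 0: (4 / 1) % 3 = 1
  //   j = 1: (4 / 3) % 2 = 1
  //
  // i.e. the pair (component_test_values[0][1], component_test_values[1][1]);
  // iterating i = 0..5 visits every pair exactly once.
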
  InputLiterals CreateInputLiterals() {
    InputLiterals literals;
    for (int i = 0; i < N; ++i) {
      literals[i] = LiteralUtil::CreateFromDimensions(T, {GetInputSize()});
    }
    return std::move(literals);
  }

  // Determines if two output values are sufficiently close to each other
  // based on an error spec.
  bool IsClose(NativeRefT expected, NativeRefT actual, ErrorSpec spec) {
    // If both corresponding values are NaN, they can be considered to have
    // the same value, so both are simply set to 0.
    RemoveCorrespondingNaNs(&expected, &actual);

    if (spec.strict_signed_zeros) {
      if (CheckSignedZeroError(expected, actual)) {
        return false;
      }
    }

    // Replace Inf with Max when calculating absolute or relative errors. This
    // allows the test to pass when one of the values is close to Inf and the
    // specified absolute or relative errors are not zero.
    double abs_err =
        std::abs(ReplaceInfWithMax(expected) - ReplaceInfWithMax(actual));
    double rel_err = abs_err / std::abs(ReplaceInfWithMax(expected));

    return abs_err <= spec.abs_err || rel_err <= spec.rel_err;
  }

  // Converts part or all of the bits in a uint64 to the value of the floating
  // point data type being tested.
  //
  // When exhaustively testing an operation over a data type T, we always use
  // an integral type I with the same number of bits as T to enumerate the
  // input bit patterns for T. Each bit pattern is zero-extended and stored as
  // a uint64. This function converts such a bit pattern stored as a uint64 to
  // the corresponding input value of type T.
  static ComponentNativeT ConvertValue(uint64 bits) {
    using I = ComponentIntegralNativeT;
    I used_bits = static_cast<I>(bits);
    return BitCast<ComponentNativeT>(used_bits);
  }

  ComponentNativeT ConvertAndReplaceKnownIncorrectValueWith(
      uint64 bits, int replacement_value = 0) {
    if (known_incorrect_fn_ && known_incorrect_fn_(bits)) {
      return static_cast<ComponentNativeT>(replacement_value);
    }
    return ConvertValue(bits);
  }

 protected:
  // The primitive type being tested.
  const PrimitiveType ty_;

  // The platform under test.
  const string platform_;

  // Testing will ignore inputs for which known_incorrect_fn_ returns true.
  // The argument to the function is the raw bits of the data being tested,
  // zero-extended to 64 bits if the data type has fewer than 64 bits.
  std::function<bool(int64_t)> known_incorrect_fn_;

  // If true, allows denormals to be flushed to non-sign-preserving 0.
  //
  // For example, normally we'd expect sqrt(-denormal) to be either NaN (sqrt
  // of a negative number) or -inf (flush the denormal to sign-preserving
  // zero, then take sqrt(-0)). But with this set to true, we'll also accept
  // 0 (i.e. sqrt(0)).
  //
  // XLA:GPU preserves denormal signs, but other backends don't.
  bool relaxed_denormal_signs_ = platform_ != "CUDA";
};

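// To make the IsClose() rule above concrete (a worked example, not code from
// this file): with ErrorSpec{0.001, 0.001}, expected = 100.0f and
// actual = 100.05f give abs_err = 0.05 (larger than 0.001), but
// rel_err = 0.05 / 100.0 = 0.0005 <= 0.001, so the values are considered
// close; the check passes if either bound is satisfied.
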
// Represents a set of 64-bit chunks by storing the starting bit chunk, the
// ending bit chunk, and the spacing between two adjacent bit chunks, without
// materializing all the bit chunks being represented. A bit chunk iterator
// is provided to retrieve all the bit chunks.
//
// This data structure is used to generate the bit representations to test
// operations that require more than 64 bits of input data. In this case,
// truly exhaustive testing is not possible, so we test one value out of every
// n values, where n == spacing_.
//
// Currently, the iterator of BitChunks adds `spacing_` to a bit chunk to
// compute the next bit chunk. We can change this to use values generated by a
// random number generator that achieves the average spacing statistically, if
// we find this necessary.
class BitChunks {
 public:
  class iterator
      : public std::iterator<std::input_iterator_tag,  // iterator_category
                             uint64,                   // value_type
                             uint64,                   // difference_type
                             const uint64*,            // pointer
                             uint64                    // reference
                             > {
   public:
    iterator() {}

    explicit iterator(const BitChunks* bit_chunks)
        : bit_chunks_(bit_chunks), next_bit_chunk_(bit_chunks->start_) {}

    iterator& operator++() {
      Next();
      return *this;
    }

    iterator operator++(int) {
      iterator retval = *this;
      Next();
      return retval;
    }

    bool operator==(iterator other) const {
      return bit_chunks_ == other.bit_chunks_ &&
             next_bit_chunk_ == other.next_bit_chunk_;
    }

    bool operator!=(iterator other) const { return !(*this == other); }

    iterator MoveToEnd() {
      MoveNextBitChunkToOnePassEnd();
      return *this;
    }

    reference operator*() const {
      CHECK(*this != this->bit_chunks_->end());
      return next_bit_chunk_;
    }

    const BitChunks* GetBitChunks() const { return bit_chunks_; }

    void Reset() { next_bit_chunk_ = bit_chunks_->start_; }

    void Next() {
      CHECK(*this != this->bit_chunks_->end());
      if (next_bit_chunk_ == bit_chunks_->end_) {
        MoveNextBitChunkToOnePassEnd();
      } else {
        next_bit_chunk_ += bit_chunks_->spacing_;
        if (next_bit_chunk_ > bit_chunks_->end_) {
          next_bit_chunk_ = bit_chunks_->end_;
        }
      }
    }

    std::string ToString() const {
      return absl::StrFormat("0x%08x", next_bit_chunk_);
    }

   private:
    // Moves next_bit_chunk_ to one past bit_chunks_->end_, to mark that the
    // iterator has reached the end. When spacing_ is not one, or if we later
    // change Next() to use a random value instead of spacing_, normalizing
    // the end-of-iteration representation this way simplifies the check for
    // whether the iterator has ended.
    void MoveNextBitChunkToOnePassEnd() {
      next_bit_chunk_ = bit_chunks_->end_ + 1;
    }

    const BitChunks* bit_chunks_;
    uint64 next_bit_chunk_;
  };

  iterator begin() const { return iterator(this); }
  iterator end() const {
    iterator end(this);
    return end.MoveToEnd();
  }

  explicit BitChunks(uint64 start = 0, uint64 end = 0, uint64 spacing = 1)
      : start_(start), end_(end), spacing_(spacing) {
    CHECK_GE(end_, start_);
    CHECK_NE(spacing, 0) << ToString();
  }

  int64 GetTotalBitChunks() const {
    if (start_ == end_) {
      return 1;
    }

    return 1 + (end_ - start_ + spacing_ - 1) / spacing_;
  }

  std::string ToString() const {
    return absl::StrFormat("(0x%08x, 0x%08x, 0x%08x)", start_, end_, spacing_);
  }

  uint64 start_;
  uint64 end_;
  uint64 spacing_;
};

inline string StringifyNum(BitChunks c) { return c.ToString(); }

inline string StringifyNum(BitChunks::iterator c) { return c.ToString(); }

template <typename T>
void AppendStringifyNum(std::string* s, T x) {
  absl::StrAppend(s, StringifyNum(x));
}

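// For example (illustrative only), BitChunks(0x0, 0x10, 4) represents the
// five chunks 0x0, 0x4, 0x8, 0xC, 0x10, and can be traversed with the usual
// range-for:
//
//   for (uint64 chunk : BitChunks(0x0, 0x10, 4)) {
//     // chunk takes the values above; GetTotalBitChunks() == 5.
//   }
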
// Represents a set of floating point values through the possible values for
// the three components: mantissa, exponent, and sign. Also implements an
// iterator for retrieving all the represented floating point values.
class FpValues {
 public:
  static constexpr uint kTotalBitChunks = 3;

  class iterator
      : public std::iterator<std::input_iterator_tag,  // iterator_category
                             uint64,                   // value_type
                             uint64,                   // difference_type
                             const uint64*,            // pointer
                             uint64                    // reference
                             > {
   public:
    explicit iterator(const FpValues* fp_values) : fp_values_(fp_values) {
      for (int i = 0; i < FpValues::kTotalBitChunks; ++i) {
        iters_[i] = BitChunks::iterator(&fp_values->GetBitChunks(i));
      }
    }

    iterator& operator++() {
      Next();
      return *this;
    }

    iterator operator++(int) {
      iterator retval = *this;
      Next();
      return retval;
    }

    bool operator==(iterator other) const {
      for (int i = 0; i < FpValues::kTotalBitChunks; ++i) {
        if (iters_[i] != other.GetBitChunksIter(i)) {
          return false;
        }
      }
      return true;
    }

    bool operator!=(iterator other) const { return !(*this == other); }

    iterator MoveToEnd() {
      for (int i = 0; i < FpValues::kTotalBitChunks; ++i) {
        iters_[i].MoveToEnd();
      }
      return *this;
    }

    uint64 operator*() const {
      uint64 value = 0;
      for (int i = 0; i < FpValues::kTotalBitChunks; ++i) {
        value = value | (*iters_[i]) << fp_values_->offsets_[i];
      }
      return value;
    }

    const BitChunks::iterator& GetBitChunksIter(int i) { return iters_[i]; }

    std::string ToString() const {
      return absl::StrJoin(iters_, ",",
                           AppendStringifyNum<BitChunks::iterator>);
    }

   private:
    // Moves the iterator for the ith BitChunks to the next value, and
    // returns true if the new state is not the end of the iterator.
    bool Next(int i = 0) {
      iters_[i].Next();
      if (iters_[i] == iters_[i].GetBitChunks()->end()) {
        if (i == FpValues::kTotalBitChunks - 1) {
          return false;
        }
        if (Next(i + 1)) {
          iters_[i].Reset();
          return true;
        }
        return false;
      }
      return true;
    }

    std::array<BitChunks::iterator, FpValues::kTotalBitChunks> iters_;
    const FpValues* fp_values_;
  };

  FpValues() : bit_chunks_(), offsets_() {}
  FpValues(absl::Span<const BitChunks> chunks, absl::Span<const int> offsets) {
    CHECK_EQ(chunks.size(), offsets.size() - 1);
    CHECK_EQ(chunks.size(), kTotalBitChunks);
    std::copy_n(chunks.begin(), kTotalBitChunks, bit_chunks_.begin());
    std::copy_n(offsets.begin(), kTotalBitChunks, offsets_.begin());

    // The last value in `offsets` is the total number of bits.
    offsets_[kTotalBitChunks] = offsets[kTotalBitChunks];
    // Validate the input values.
    for (int i = 0; i < kTotalBitChunks; ++i) {
      int total_bits = offsets[i + 1] - offsets[i];
      if (total_bits < 64) {
        uint64 bound = 1ull << total_bits;
        CHECK_LT(chunks[i].start_, bound);
        CHECK_LT(chunks[i].end_, bound);
      } else {
        CHECK_EQ(total_bits, 64);
      }
    }
  }

  iterator begin() const { return iterator(this); }

  iterator end() const {
    iterator end(this);
    return end.MoveToEnd();
  }

  int64 GetTotalNumValues() const {
    int64_t total = 1;
    absl::c_for_each(bit_chunks_, [&](const BitChunks& chunks) {
      total *= chunks.GetTotalBitChunks();
    });
    return total;
  }

  const BitChunks& GetBitChunks(int i) const { return bit_chunks_[i]; }

  std::string ToString() const {
    return absl::StrCat(
        "[", absl::StrJoin(bit_chunks_, ",", AppendStringifyNum<BitChunks>),
        "]");
  }

  std::array<BitChunks, kTotalBitChunks> bit_chunks_;
  std::array<int, kTotalBitChunks + 1> offsets_;
};

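// Dereferencing an FpValues::iterator assembles a bit pattern as
//
//   value = (sign << offsets_[2]) | (exponent << offsets_[1]) | mantissa
//
// so, for example (illustrative only), for f32 with offsets {0, 23, 31, 32},
//
//   FpValues({BitChunks(0, 0, 1),   // mantissa: just 0
//             BitChunks(0, 0, 1),   // exponent: just 0
//             BitChunks(0, 1, 1)},  // sign: 0 and 1
//            {0, 23, 31, 32})
//
// iterates over the two bit patterns 0x00000000 and 0x80000000, i.e. +0.0f
// and -0.0f.
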
template <typename T, typename std::enable_if<
                          std::is_same<T, float>::value ||
                          std::is_same<T, double>::value>::type* = nullptr>
int GetMantissaTotalBits() {
  return std::numeric_limits<T>::digits - 1;
}

template <typename T>
int GetFpTotalBits() {
  return sizeof(T) * 8;
}

template <typename T>
int GetExponentTotalBits() {
  return GetFpTotalBits<T>() - GetMantissaTotalBits<T>() - 1;
}

template <typename T>
uint64 GetAllOneMantissa() {
  return (1ull << GetMantissaTotalBits<T>()) - 1ull;
}

template <typename T>
uint64 GetAllOneExponent() {
  return (1ull << GetExponentTotalBits<T>()) - 1ull;
}

template <typename T, typename std::enable_if<
                          std::is_same<T, float>::value ||
                          std::is_same<T, double>::value>::type* = nullptr>
FpValues GetFpValues(BitChunks mantissa, BitChunks exponent, BitChunks sign) {
  int total_bits = GetFpTotalBits<T>();
  return FpValues({mantissa, exponent, sign},
                  {0, GetMantissaTotalBits<T>(), total_bits - 1, total_bits});
}

template <typename T>
FpValues GetZeros() {
  return GetFpValues<T>(BitChunks(0, 0, 1), BitChunks(0, 0, 1),
                        BitChunks(0, 1, 1));
}

template <typename T>
FpValues GetSubnormals(int approx_num_values) {
  int mantissa = GetMantissaTotalBits<T>();
  uint64 mantissa_spacing = (1ull << mantissa) / (approx_num_values * 2);
  return GetFpValues<T>(
      BitChunks(0x1, GetAllOneMantissa<T>(), mantissa_spacing),
      BitChunks(0, 0, 1), BitChunks(0, 1, 1));
}

template <typename T>
FpValues GetInfinites() {
  uint64 all_one_exp = GetAllOneExponent<T>();
  return GetFpValues<T>(BitChunks(0, 0, 1),
                        BitChunks(all_one_exp, all_one_exp, 1),
                        BitChunks(0, 1, 1));
}

template <typename T>
FpValues GetNans(int approx_num_values) {
  int mantissa = GetMantissaTotalBits<T>();
  uint64 mantissa_spacing = (1ull << mantissa) / (approx_num_values * 2);
  uint64 all_one_exp = GetAllOneExponent<T>();
  return GetFpValues<T>(
      BitChunks(0x1, GetAllOneMantissa<T>(), mantissa_spacing),
      BitChunks(all_one_exp, all_one_exp, 1), BitChunks(0, 1, 1));
}

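// Concretely (a worked example for f32): GetInfinites<float>() has an all-one
// exponent (0xFF), a zero mantissa, and both signs, so it represents exactly
// the two bit patterns 0x7F800000 and 0xFF800000, i.e. +inf and -inf;
// GetZeros<float>() likewise represents 0x00000000 and 0x80000000.
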
template <typename T>
FpValues GetNormals(int approx_num_values) {
  float component_total = std::sqrt(static_cast<float>(approx_num_values));
  return GetFpValues<T>(
      BitChunks(0x1, GetAllOneMantissa<T>(),
                (1ull << (GetMantissaTotalBits<T>() + 1)) / component_total),
      BitChunks(0x1, GetAllOneExponent<T>() - 1,
                (1ull << (GetExponentTotalBits<T>() + 1)) / component_total),
      BitChunks(0, 1, 1));
}

// Returns a vector of FpValues that together represent about
// `approx_num_values` floating point values of type `T`, with each FpValues
// representing about `num_values_per_group` floating point values.
template <typename T>
std::vector<FpValues> GetFpValuesWithExponents(uint64 first_exponent,
                                               uint64 exponent_spacing,
                                               uint64 num_exponents,
                                               uint64 approx_num_values,
                                               uint64 num_values_per_group) {
  const uint64 num_signs = 2;
  uint64 approx_num_mantissa = approx_num_values / (num_exponents * num_signs);
  uint64 num_mantissa_per_group =
      num_values_per_group / (num_exponents * num_signs);
  CHECK_GT(approx_num_mantissa, 0);
  CHECK_GT(num_mantissa_per_group, 0);

  CHECK_LT(first_exponent + num_exponents - 1ull, GetAllOneExponent<T>());
  int mantissa = GetMantissaTotalBits<T>();
  uint64 mantissa_spacing = (1ull << mantissa) / approx_num_mantissa;

  std::vector<FpValues> result;
  for (uint64 group_start = 0; group_start < GetAllOneMantissa<T>();
       group_start += mantissa_spacing * num_mantissa_per_group) {
    uint64 group_end =
        group_start + (num_mantissa_per_group - 1) * mantissa_spacing;
    if (group_end > GetAllOneMantissa<T>()) {
      group_end = GetAllOneMantissa<T>();
    }
    result.push_back(GetFpValues<T>(
        BitChunks(group_start, group_end, mantissa_spacing),
        BitChunks(first_exponent, first_exponent + num_exponents - 1, 1),
        BitChunks(0, 1, 1)));
  }
  return result;
}

// Returns a vector of FpValues that together represent about
// `approx_num_values` "very large" floating point values and
// `approx_num_values` "very small" floating point values of type `T`, with
// each FpValues representing about `num_values_per_group` floating point
// values. Because we use FpValues as a parameter for parameterized testing,
// the number of floating point values represented by each FpValues affects
// the input size of each sub-test and hence the peak memory usage of the
// test.
template <typename T>
std::vector<FpValues> GetFpValuesForMagnitudeExtremeNormals(
    uint64 approx_num_values = 40000, uint64 num_values_per_group = 4000) {
  std::vector<FpValues> large =
      GetFpValuesWithExponents<T>(GetAllOneExponent<T>() - 5, 1, 5,
                                  approx_num_values / 2, num_values_per_group);
  std::vector<FpValues> small = GetFpValuesWithExponents<T>(
      1, 1, 5, approx_num_values / 2, num_values_per_group);
  large.insert(large.end(), small.begin(), small.end());
  return large;
}

template <typename T>
std::vector<FpValues> CreateFpValuesForBoundaryTest() {
  return {GetZeros<T>(), GetSubnormals<T>(1000), GetInfinites<T>(),
          GetNans<T>(1000)};
}

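// These generators are typically consumed by gtest's value-parameterized
// tests (a hypothetical sketch; the suite and test names are illustrative):
//
//   INSTANTIATE_TEST_SUITE_P(
//       SpecialValues, ExhaustiveF32UnaryTest,
//       ::testing::ValuesIn(CreateFpValuesForBoundaryTest<float>()));
//
// Each FpValues then becomes one sub-test, which bounds its input size.
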
inline std::vector<std::pair<int64, int64>> CreateExhaustiveF32Ranges() {
  // We break up the 2^32-element space into small'ish chunks to keep peak
  // memory usage low.
  std::vector<std::pair<int64, int64>> result;
  const int64_t step = 1 << 25;
  for (int64_t i = 0; i < (int64_t{1} << 32); i += step) {
    result.push_back({i, i + step});
  }
  return result;
}

template <PrimitiveType T, size_t N>
inline ErrorSpec DefaultSpecGenerator(
    typename ExhaustiveOpTestBase<T, N>::NativeT) {
  LOG(FATAL) << "Unhandled Type";
}

template <PrimitiveType T, size_t N>
inline ErrorSpec DefaultSpecGenerator(
    typename ExhaustiveOpTestBase<T, N>::NativeT,
    typename ExhaustiveOpTestBase<T, N>::NativeT) {
  LOG(FATAL) << "Unhandled Type";
}

template <>
inline ErrorSpec DefaultSpecGenerator<C128, 1>(complex128) {
  return ErrorSpec{0.0001, 0.0001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<C64, 1>(complex64) {
  return ErrorSpec{0.0001, 0.0001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<F64, 1>(double) {
  return ErrorSpec{0.0001, 0.0001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<F32, 1>(float) {
  return ErrorSpec{0.0001, 0.0001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<F16, 1>(Eigen::half) {
  return ErrorSpec{0.001, 0.001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<BF16, 1>(bfloat16) {
  return ErrorSpec{0.002, 0.02};
}

template <>
inline ErrorSpec DefaultSpecGenerator<F64, 2>(double, double) {
  return ErrorSpec{0.001, 0.001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<F32, 2>(float, float) {
  return ErrorSpec{0.001, 0.001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<F16, 2>(Eigen::half, Eigen::half) {
  return ErrorSpec{0.001, 0.001};
}

template <>
inline ErrorSpec DefaultSpecGenerator<BF16, 2>(bfloat16, bfloat16) {
  return ErrorSpec{0.002, 0.02};
}

template <PrimitiveType T, size_t N>
typename ErrorSpecGenWrapper<T, N>::type GetDefaultSpecGenerator() {
  return DefaultSpecGenerator<T, N>;
}

template <typename T, typename std::enable_if<
                          std::is_same<T, float>::value ||
                          std::is_same<T, double>::value>::type* = nullptr>
T ReferenceMax(T x, T y) {
  // We need to propagate NaN here ourselves, because std::max may not
  // propagate NaN.
  if (std::fpclassify(x) == FP_NAN) {
    return x;
  }
  if (std::fpclassify(y) == FP_NAN) {
    return y;
  }

  return std::max<T>(x, y);
}

template <typename T, typename std::enable_if<
                          std::is_same<T, float>::value ||
                          std::is_same<T, double>::value>::type* = nullptr>
T ReferenceMin(T x, T y) {
  // We need to propagate NaN here ourselves, because std::min may not
  // propagate NaN.
  if (std::fpclassify(x) == FP_NAN) {
    return x;
  }
  if (std::fpclassify(y) == FP_NAN) {
    return y;
  }

  return std::min<T>(x, y);
}

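// The NaN checks above matter because std::max/std::min return their first
// argument when the comparison is false, so NaN propagation depends on
// argument order. For instance (a worked example):
//
//   float nan = std::numeric_limits<float>::quiet_NaN();
//   std::max(nan, 1.0f);      // NaN: (nan < 1.0f) is false, so `nan` wins.
//   std::max(1.0f, nan);      // 1.0f: the NaN is silently dropped.
//   ReferenceMax(1.0f, nan);  // NaN, regardless of argument order.
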
// Returns a wrapper of the given build method, which builds an HLO operation
// with an empty broadcast dimension.
inline std::function<XlaOp(XlaOp, XlaOp)> AddEmptyBroadcastDimension(
    std::function<XlaOp(XlaOp, XlaOp, absl::Span<const int64>)> build_method) {
  // Capture build_method by value: the parameter goes out of scope when this
  // function returns, so a by-reference capture would dangle.
  return [build_method](XlaOp src0, XlaOp src1) -> XlaOp {
    return build_method(src0, src1, {});
  };
}

template <PrimitiveType T>
class ExhaustiveUnaryTest : public ExhaustiveOpTestBase<T, 1> {
 public:
  using typename ExhaustiveOpTestBase<T, 1>::ErrorSpecGen;
  static ErrorSpecGen GetDefaultSpecGenerator() {
    return exhaustive_op_test::GetDefaultSpecGenerator<T, 1>();
  }
};

template <PrimitiveType T>
using ExhaustiveBinaryTest = ExhaustiveOpTestBase<T, 2>;

}  // namespace exhaustive_op_test
}  // namespace xla
#endif  // TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_