/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <array>
#include <cmath>
#include <functional>
#include <limits>

#include "absl/base/casts.h"
#include "absl/strings/str_format.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"

namespace xla {
namespace {

using Eigen::half;

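// Evaluates a polynomial via Horner's method.  `coeffs` is ordered from the
// coefficient of the highest-degree term down to the constant term.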
template <typename T, size_t N>
T EvaluatePolynomial(T x, const std::array<T, N>& coeffs) {
  T result = 0;
  for (T c : coeffs) {
    result = result * x + c;
  }
  return result;
}

// There's no std::erfinv, so we have to implement it ourselves.  This follows
// Wichura 1998, https://www.jstor.org/stable/2347330, which, notably, is a
// different implementation from that in math.cc.
float HostErfInv(float x) {
  std::array<double, 8> kPolyA = {
      8.8709406962545514830200e2, 1.1819493347062294404278e4,
      2.3782041382114385731252e4, 1.6235862515167575384252e4,
      4.8548868893843886794648e3, 6.9706266534389598238465e2,
      4.7072688112383978012285e1, 1.1975323115670912564578e0,
  };
  std::array<double, 8> kPolyB = {
      5.2264952788528545610e3, 2.8729085735721942674e4, 3.9307895800092710610e4,
      2.1213794301586595867e4, 5.3941960214247511077e3, 6.8718700749205790830e2,
      4.2313330701600911252e1, 1.0000000000000000000e0,
  };
  std::array<double, 8> kPolyC = {
      7.74545014278341407640e-4, 2.27238449892691845833e-2,
      2.41780725177450611770e-1, 1.27045825245236838258e0,
      3.64784832476320460504e0,  5.76949722146069140550e0,
      4.63033784615654529590e0,  1.42343711074968357734e0,
  };
  std::array<double, 8> kPolyD = {
      1.4859850019840355905497876e-9, 7.7441459065157709165577218e-4,
      2.1494160384252876777097297e-2, 2.0945065210512749128288442e-1,
      9.7547832001787427186894837e-1, 2.3707661626024532365971225e0,
      2.9036514445419946173133295e0,  1.4142135623730950488016887e0,
  };
  std::array<double, 8> kPolyE = {
      2.01033439929228813265e-7, 2.71155556874348757815e-5,
      1.24266094738807843860e-3, 2.65321895265761230930e-2,
      2.96560571828504891230e-1, 1.78482653991729133580e0,
      5.46378491116411436990e0,  6.65790464350110377720e0,
  };
  std::array<double, 8> kPolyF = {
      2.891024605872965461538222e-15, 2.010321207683943062279931e-7,
      2.611088405080593625138020e-5,  1.112800997078859844711555e-3,
      2.103693768272068968719679e-2,  1.936480946950659106176712e-1,
      8.482908416595164588112026e-1,  1.414213562373095048801689e0,
  };

  if (std::abs(x) > 1 || std::isnan(x)) {
    return std::numeric_limits<float>::quiet_NaN();
  }
  if (std::abs(x) == 1) {
    return std::copysign(std::numeric_limits<float>::infinity(), x);
  }

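  // Following Wichura, use one rational approximation for the central region
  // |x| <= 0.85 and two more for the tails, split at r = 5.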
  float unsigned_result = [&] {
    float y = std::abs(x);
    if (y <= 0.85) {
      double r = 0.180625 - 0.25 * y * y;
      return (y * EvaluatePolynomial(r, kPolyA)) /
             EvaluatePolynomial(r, kPolyB);
    } else {
      double r = std::sqrt(std::log(2.0) - std::log1p(-y));
      if (r <= 5.0) {
        r -= 1.6;
        return EvaluatePolynomial(r, kPolyC) / EvaluatePolynomial(r, kPolyD);
      } else {
        r -= 5;
        return EvaluatePolynomial(r, kPolyE) / EvaluatePolynomial(r, kPolyF);
      }
    }
  }();
  return std::copysign(unsigned_result, x);
}

// Digamma implementation using a polynomial from Cephes.  Notably this is a
// different implementation from the one in math.cc.
float HostDigamma(float x) {
  // Euler-Mascheroni constant
  float kGamma = 0.57721566490153286061;
  float kPi = M_PI;

  std::array<float, 4> kPoly = {
      -4.16666666666666666667E-3,
      3.96825396825396825397E-3,
      -8.33333333333333333333E-3,
      8.33333333333333333333E-2,
  };

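  // For x <= 0 we use the reflection formula
  //   digamma(x) = digamma(1 - x) - pi / tan(pi * x),
  // which has poles at the non-positive integers.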
  float reflection = 0;
  if (x <= 0) {
    float floor = std::floor(x);
    if (x == floor) {
      return std::numeric_limits<float>::quiet_NaN();
    }
    // Compute reflection term, pi * cot(pi * x).
    reflection = x - floor;
    if (reflection == 0.5) {
      reflection = 0;
    } else {
      if (reflection > 0.5) {
        reflection = x - (floor + 1.0f);
      }
      reflection = kPi / std::tan(kPi * reflection);
    }
    x = 1 - x;
  }

  float result = 0;
  if (x <= 10 && x == std::floor(x)) {
    // Special case for integers <= 10.
    for (int i = 1; i < x; ++i) {
      result += 1.0f / i;
    }
    result -= kGamma;
  } else {
    float w = 0;
    for (; x < 10; ++x) {
      w += 1.0f / x;
    }
    if (x < 1e8) {
      float z = 1.0f / (x * x);
      result = z * EvaluatePolynomial(z, kPoly);
    }
    result = std::log(x) - 0.5f / x - result - w;
  }

  // Compute the final, reflected value.
  return result - reflection;
}

// For f32, f16, and bf16 we need 9, 5, and 4 decimal digits of precision,
// respectively, to guarantee that we print the full number.
//
// (The general formula is: given a floating-point number with S significand
// bits, the number of decimal digits needed to print it to full precision is
//
//   ceil(1 + S * log_10(2)) ~= ceil(1 + S * 0.30103).
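//
// For example, f32 has S = 24 significand bits (including the implicit leading
// bit), so it needs ceil(1 + 24 * 0.30103) = 9 digits; f16 (S = 11) needs 5,
// and bf16 (S = 8) needs 4.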
//
// See https://people.eecs.berkeley.edu/~wkahan/Math128/BinDecBin.pdf.)
string StringifyNum(float x) {
  return absl::StrFormat("%0.9g (0x%08x)", x, absl::bit_cast<uint32>(x));
}

string StringifyNum(half x) {
  return absl::StrFormat("%0.5g (0x%04x)", static_cast<float>(x),
                         absl::bit_cast<uint16>(x));
}

string StringifyNum(bfloat16 x) {
  return absl::StrFormat("%0.4g (0x%04x)", static_cast<float>(x),
                         absl::bit_cast<uint16>(x));
}

// Test parameter is a tuple containing
//   - primitive type under test,
//   - (begin, end) range under test, as zero-extended int64s bitcast to the
//     primitive type under test.
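//
// For example, (F32, {0, 1 << 25}) exercises the first 2^25 f32 bit patterns,
// i.e. the positive denormals and the smallest positive normals.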
class ExhaustiveOpTest
    : public ClientLibraryTestBase,
      public ::testing::WithParamInterface<
          std::tuple<PrimitiveType, std::pair<int64, int64>>> {
 public:
  ExhaustiveOpTest()
      : ty_(std::get<0>(GetParam())), platform_(client_->platform()->Name()) {}

  void Run(std::function<XlaOp(XlaOp)> enqueue_op,
           float (*evaluate_op)(float)) {
    SetFastMathDisabled(true);

    // Run all HLO passes.  In particular, constant folding is disabled by
    // default for tests, but we need to run it in order to tickle some bugs.
    mutable_debug_options()->clear_xla_disable_hlo_passes();

    PrimitiveType ty;
    std::tie(ty, std::ignore) = GetParam();

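    // Per-type default (abs, rel) error bounds.  Tests can tighten or loosen
    // these by setting abs_err_/rel_err_ before calling Run().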
    switch (ty) {
      case F32:
        SetDefaultErrSpec(0.0001, 0.0001);
        RunImpl<float, uint32>(enqueue_op, evaluate_op);
        break;
      case F16:
        SetDefaultErrSpec(0.001, 0.001);
        RunImpl<half, uint16>(enqueue_op, evaluate_op);
        break;
      case BF16:
        SetDefaultErrSpec(0.001, 0.01);
        RunImpl<bfloat16, uint16>(enqueue_op, evaluate_op);
        break;
      default:
        LOG(FATAL) << "Unhandled type.";
    }
  }

  void SetDefaultErrSpec(float abs_err, float rel_err) {
    if (!abs_err_.has_value()) {
      abs_err_ = abs_err;
    }
    if (!rel_err_.has_value()) {
      rel_err_ = rel_err;
    }
  }

  template <typename T, typename IntegralT>
  void RunImpl(std::function<XlaOp(XlaOp)> enqueue_op,
               float (*evaluate_op)(float)) {
    static_assert(
        sizeof(T) == sizeof(IntegralT),
        "IntegralT must be an unsigned integer type of the same width as T.");

    PrimitiveType ty;
    std::pair<int64, int64> test_range;
    std::tie(ty, test_range) = GetParam();
    int64 begin, end;
    std::tie(begin, end) = test_range;

    if (begin >= known_incorrect_begin_ && end <= known_incorrect_end_) {
      LOG(INFO) << absl::StreamFormat(
          "Skipping this shard, as the range under test, [%d, %d), falls "
          "entirely within the known-incorrect range [%d, %d).",
          begin, end, known_incorrect_begin_, known_incorrect_end_);
      return;
    }

    LOG(INFO) << "Checking range [" << begin << ", " << end << ")";

    int64 input_size = end - begin;
    Literal input_literal = LiteralUtil::CreateFromDimensions(ty, {input_size});
    absl::Span<T> input_arr = input_literal.data<T>();
    for (int64 i = 0; i < input_size; i++) {
      IntegralT input_val = i + begin;
      // If the operation is known to be buggy on a specific input, clamp that
      // input to 0 under the assumption that the op is at least correct on 0.
      if (input_val >= known_incorrect_begin_ &&
          input_val < known_incorrect_end_) {
        input_arr[i] = T{0};
      } else {
        input_arr[i] = absl::bit_cast<T>(input_val);
      }
    }

    TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
                            BuildAndRunComputation(enqueue_op, input_literal));
    ExpectNear<T>(input_literal, result_literal, evaluate_op);
  }

  StatusOr<Literal> BuildAndRunComputation(
      const std::function<XlaOp(XlaOp)>& enqueue_op,
      const Literal& input_literal) {
    XlaBuilder builder(TestName());
    auto input = Parameter(&builder, 0, input_literal.shape(), "input");
    enqueue_op(input);
    TF_ASSIGN_OR_RETURN(XlaComputation comp, builder.Build());

    // Build and run the computation using the LocalClient API, rather than the
    // plain Client API, which is used by ClientLibraryTestBase.  This is
    // because the plain Client API does more memcpys to/from Literals, and
    // that's slow given that we're touching a lot of data here.
    //
    // Copy debug options from ClientLibraryTestBase.  In particular, we're
    // interested in disabling constant folding.
    ExecutableBuildOptions build_opts;
    *build_opts.mutable_debug_options() = *mutable_debug_options();
    TF_ASSIGN_OR_RETURN(
        auto executable,
        client_->Compile(comp, {&input_literal.shape()}, build_opts));

    TF_ASSIGN_OR_RETURN(
        ScopedShapedBuffer input_data,
        client_->LiteralToShapedBuffer(input_literal, /*device_ordinal=*/0));

    ExecutableRunOptions run_opts;
    run_opts.set_allocator(client_->backend().memory_allocator());
    run_opts.set_intra_op_thread_pool(
        client_->backend().eigen_intra_op_thread_pool_device());
    TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
                        executable->Run({&input_data}, run_opts));

    TF_ASSIGN_OR_RETURN(Literal result_literal,
                        client_->ShapedBufferToLiteral(result));
    return std::move(result_literal);
  }

  template <typename T>
  bool IsClose(T expected, T actual) {
    float expected_f32 = static_cast<float>(expected);
    float actual_f32 = static_cast<float>(actual);
    float abs_err = std::abs(expected_f32 - actual_f32);
    float rel_err = abs_err / std::abs(expected_f32);
    if (strict_signed_zeros_ && actual == T{0} && expected == T{0}) {
      // Check sign of zero.
      return std::signbit(actual_f32) == std::signbit(expected_f32);
    }
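    // Note: when expected_f32 is 0, rel_err is inf (or NaN if actual is also
    // 0), so in practice only the absolute-error and special-value checks
    // below can succeed.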
    return abs_err < *abs_err_ || rel_err < *rel_err_ ||
           (std::isnan(expected_f32) && std::isnan(actual_f32)) ||
           (std::isinf(expected_f32) && std::isinf(actual_f32) &&
            (expected_f32 > 0) == (actual_f32 > 0));
  }

  template <typename T>
  void ExpectNear(const Literal& input_literal, const Literal& result_literal,
                  float (*evaluate_op)(float)) {
    // We essentially reimplement LiteralTestUtil::Near here because
    //  a) this streamlined implementation is much faster,
    //  b) we can print out better error messages (namely, we can print out
    //     which floating-point input value failed, while LiteralTestUtil::Near
    //     can only print out the input index that failed), and
    //  c) we need special handling of certain inputs.  For example, we say that
    //     a denormal input has multiple correct outputs (namely, f(x) and f(0))
    //     and just needs to be close to one of them.
    absl::Span<const T> input_arr = input_literal.data<T>();
    absl::Span<const T> result_arr = result_literal.data<T>();
    ASSERT_EQ(result_arr.size(), input_arr.size());
    int64 mismatches = 0;
    // Hoisting these out of the loop is a nice speedup on shards that have many
    // denormals.
    const T expected_at_pos_zero = static_cast<T>(evaluate_op(0));
    const T expected_at_neg_zero = static_cast<T>(evaluate_op(-0.0));
    for (int64 i = 0; i < input_arr.size(); ++i) {
      T input = input_arr[i];
      float input_f32 = static_cast<float>(input);
      T actual = result_arr[i];
      T expected = static_cast<T>(evaluate_op(input_f32));

      if (IsClose(expected, actual)) {
        continue;
      }

      // Easy case: If `input` is not denormal and !IsClose(expected, actual),
      // print an error.
      //
      // (This doesn't correctly detect f16 and bfloat16 denormals!  This seems
      // to be OK for now, but at some point we may need to implement fpclassify
      // for half and bfloat.)
      if (std::fpclassify(input_f32) != FP_SUBNORMAL) {
        PrintMismatch(&mismatches, [&] {
          return absl::StrFormat("Mismatch on %s. Expected %s, but got %s.",
                                 StringifyNum(input), StringifyNum(expected),
                                 StringifyNum(actual));
        });
        continue;
      }

      // Otherwise, `input` is denormal.  For denormal inputs, we accept answers
      // that are close to any of:
      //
      //   - evaluate_op(input),
      //   - evaluate_op(+/-0), where the sign of 0 is equal to the sign of
      //     `input`, or
      //   - if relaxed_denormal_signs_, evaluate_op(-/+0), where the sign of
      //     0 is the opposite of `input`.
      T sign_preserving_ftz_expected =
          std::signbit(input_f32) ? expected_at_neg_zero : expected_at_pos_zero;
      T sign_nonpreserving_ftz_expected =
          std::signbit(input_f32) ? expected_at_pos_zero : expected_at_neg_zero;
      if (IsClose(sign_preserving_ftz_expected, actual) ||
          (relaxed_denormal_signs_ &&
           IsClose(sign_nonpreserving_ftz_expected, actual))) {
        continue;
      }

      if (relaxed_denormal_signs_) {
        PrintMismatch(&mismatches, [&] {
          return absl::StrFormat(
              "Mismatch on denormal value %s.  Expected one of:\n"
              "  %10s (evaluated at full-precision value)\n"
              "  %10s (evaluated after flushing to sign-preserving zero)\n"
              "  %10s (evaluated after flushing to non-sign-preserving "
              "zero)\n"
              "but got %s.",
              StringifyNum(input), StringifyNum(expected),
              StringifyNum(sign_preserving_ftz_expected),
              StringifyNum(sign_nonpreserving_ftz_expected),
              StringifyNum(actual));
        });
      } else {
        PrintMismatch(&mismatches, [&] {
          return absl::StrFormat(
              "Mismatch on denormal value %s.  Expected one of:\n"
              "  %10s (evaluated at full-precision value)\n"
              "  %10s (evaluated after flushing to sign-preserving zero)\n"
              "but got %s.",
              StringifyNum(input), StringifyNum(expected),
              StringifyNum(sign_preserving_ftz_expected), StringifyNum(actual));
        });
      }
    }
    EXPECT_EQ(mismatches, 0);
  }

  template <typename ErrorGenerator>
  void PrintMismatch(int64* mismatches, const ErrorGenerator& err_generator) {
    // We send a few mismatches to gunit so they show up nicely in test logs.
    // Then we send more to LOG(ERROR).  The remainder we squelch unless we're
    // at vlog level 2.
    constexpr int64 kMaxMismatchesLoggedToGunit = 10;
    constexpr int64 kMaxMismatchesLoggedToErr = 1000;

    (*mismatches)++;
    if (*mismatches < kMaxMismatchesLoggedToGunit) {
      FAIL() << err_generator();
    } else if (*mismatches < kMaxMismatchesLoggedToErr || VLOG_IS_ON(2)) {
      LOG(ERROR) << err_generator();
    } else if (*mismatches == kMaxMismatchesLoggedToErr) {
      LOG(ERROR) << "Not printing any more mismatches; pass "
                    "--vmodule=exhaustive_f32__op_test=2 to see "
                    "all of them.";
    }
  }

  // The following members are set during construction so testcases can read
  // these values and use them e.g. to influence the values given to the mutable
  // members below.

  // The primitive type under test.
  const PrimitiveType ty_;

  // The platform under test.
  const string platform_;

  // Tests can set the following variables for control over execution.  This is
  // safe because each XLA_TEST_P instantiates a new instance of this class.

  // Testing will ignore the given range (encoded as bitwise representations of
  // the type under test zero-extended to int64).
  int64 known_incorrect_begin_ = 0;
  int64 known_incorrect_end_ = 0;

  // If unset, reasonable defaults will be used depending on the type under
  // test.
  absl::optional<float> abs_err_;
  absl::optional<float> rel_err_;

  // If true, will consider -0 not near to +0 and vice versa.  Note that
  // +epsilon may still be considered close to -0, depending on the error spec;
  // this only covers the case when both `expected` and `actual` are equal to 0.
  bool strict_signed_zeros_ = false;

  // If true, allows denormals to be flushed to non-sign-preserving 0.
  //
  // For example, normally we'd expect sqrt(-denormal) to be either nan (sqrt of
  // a negative number) or -inf (flush the denormal to sign-preserving zero,
  // then sqrt(-0)).  But with this as true, we'll also accept 0 (sqrt(0)).
  //
  // XLA:GPU preserves denormal signs, but other backends don't.
  bool relaxed_denormal_signs_ = platform_ != "CUDA";
};

XLA_TEST_P(ExhaustiveOpTest, Log) {
  if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) {
    abs_err_ = 0.001;
    rel_err_ = 0.001;
  }

  Run(Log, std::log);
}

XLA_TEST_P(ExhaustiveOpTest, Log1p) {
  if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) {
    abs_err_ = 0.001;
    rel_err_ = 0.001;
  }

  Run(Log1p, std::log1p);
}

XLA_TEST_P(ExhaustiveOpTest, Exp) {
  if (platform_ == "Host" && ty_ == F32) {
    // TODO(b/73142289): The vectorized Exp implementation gives results outside
    // our error spec in this range.
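    // (1107296256 is 0x42000000, the bit pattern of 32.0f.)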
    known_incorrect_begin_ = 1107296256 + 11583654;
    known_incorrect_end_ = 1107296256 + 11629080;
  } else if (platform_ == "Host" && ty_ == BF16) {
    // TODO(jlebar): Is this a rounding error?  Why doesn't it occur on XLA:GPU?
    //
    // Mismatch on 88.5 (0x42b1).
    //   Expected 2.72491739e+38 (0x7f4d), but got inf (0x7f80).
    known_incorrect_begin_ = 0x42b1;
    known_incorrect_end_ = 0x42b2;
  }

  Run(Exp, std::exp);
}

XLA_TEST_P(ExhaustiveOpTest, Expm1) {
  // Expm1 has the same erroneous behavior on CPU as Exp.
  if (platform_ == "Host" && ty_ == F32) {
    // TODO(b/73142289): The vectorized Exp implementation gives results outside
    // our error spec in this range.
    known_incorrect_begin_ = 1107296256 + 11583654;
    known_incorrect_end_ = 1107296256 + 11629080;
  } else if (platform_ == "Host" && ty_ == BF16) {
    // TODO(jlebar): Is this a rounding error?  Why doesn't it occur on XLA:GPU?
    //
    // Mismatch on 88.5 (0x42b1).
    //   Expected 2.72491739e+38 (0x7f4d), but got inf (0x7f80).
    known_incorrect_begin_ = 0x42b1;
    known_incorrect_end_ = 0x42b2;
  }

  Run(Expm1, std::expm1);
}

// It feels a little overkill to exhaustively test sqrt and pow(x, 0.5), but
// this *did* find a bug, namely that some backends were assuming sqrt(x) ==
// pow(x, 0.5), which is not true for x == -inf.
XLA_TEST_P(ExhaustiveOpTest, PowOneHalf) {
  Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); },
      +[](float x) { return std::pow(x, 0.5f); });
}

XLA_TEST_P(ExhaustiveOpTest, Rsqrt) {
  Run(Rsqrt, +[](float x) { return 1 / std::sqrt(x); });
}

XLA_TEST_P(ExhaustiveOpTest, Sqrt) {
  if (platform_ == "Host" || platform_ == "CUDA") {
    strict_signed_zeros_ = true;
  }

  Run(Sqrt, std::sqrt);
}

// TODO(jlebar): Add remaining trig functions.  Don't forget Atan2!
// TODO(jlebar): Test trig functions over complex inputs.
XLA_TEST_P(ExhaustiveOpTest, Tanh) { Run(Tanh, std::tanh); }

XLA_TEST_P(ExhaustiveOpTest, Erf) { Run(Erf, std::erf); }
XLA_TEST_P(ExhaustiveOpTest, Erfc) { Run(Erfc, std::erfc); }
XLA_TEST_P(ExhaustiveOpTest, ErfInv) { Run(ErfInv, HostErfInv); }
XLA_TEST_P(ExhaustiveOpTest, Digamma) {
  if (platform_ != "Host" && platform_ != "CUDA") {
    // TODO(b/123956399): This is a fairly high error, significantly higher than
    // we see on CPU/GPU.
    rel_err_ = 0.01;
    abs_err_ = 0.01;
  }

  if (platform_ == "CUDA") {
    // On GPU we get a wrong answer for the denormal inputs +/-2.93873588e-39
    // (0x00200000 and 0x80200000).  These should return -/+inf (at least
    // according to our reference implementation!) but XLA:GPU returns
    // -/+3.40282326e+38 (0xff7ffffe and 0x7f7ffffe).
    //
    // I deem this an acceptable result, as XLA:GPU flushes denormals, and as
    // the results we get here are very close to MAX_FLOAT.  We just hardcode
    // these results, as this is better than ignoring these inputs altogether.
    auto host_digamma_with_gpu_ftz_errors = +[](float x) {
      if (absl::bit_cast<uint32>(x) == 0x00200000 ||
          absl::bit_cast<uint32>(x) == 0x80200000) {
        return std::copysign(std::numeric_limits<float>::max(), -x);
      }
      return HostDigamma(x);
    };
    Run(Digamma, host_digamma_with_gpu_ftz_errors);
  } else {
    Run(Digamma, HostDigamma);
  }
}
XLA_TEST_P(ExhaustiveOpTest, Lgamma) {
  // Our implementation gets within 0.0001 rel error except for ~20 denormal
  // inputs on GPU.  Anyway 0.001 rel error should be good enough for lgamma.
  if (platform_ == "CUDA" && (ty_ == F32 || ty_ == F16)) {
    rel_err_ = 0.001;
  }
  if (platform_ != "Host" && platform_ != "CUDA") {
    // TODO(b/123956399): This is a fairly high error, significantly higher than
    // we see on CPU/GPU.
    rel_err_ = 0.01;
    abs_err_ = 0.01;

    // Overflows to inf for input 4.08500343e+36 (0x7c44af8e).
    if (ty_ == F32) {
      known_incorrect_begin_ = 0x7c44af8e;
      known_incorrect_end_ = 0x7c44af8e + 1;
    }
  }
  Run(Lgamma, std::lgamma);
}

XLA_TEST_P(ExhaustiveOpTest, Round) { Run(Round, std::round); }

std::vector<std::pair<int64, int64>> CreateExhaustiveF32Ranges() {
  // We break up the 2^32-element space into small'ish chunks to keep peak
  // memory usage low.
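  // (A step of 2^25 yields 2^32 / 2^25 = 128 shards.)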
  std::vector<std::pair<int64, int64>> result;
  const int64 step = 1 << 25;
  for (int64 i = 0; i < (int64{1} << 32); i += step) {
    result.push_back({i, i + step});
  }
  return result;
}

INSTANTIATE_TEST_SUITE_P(
    F32, ExhaustiveOpTest,
    ::testing::Combine(::testing::Values(F32),
                       ::testing::ValuesIn(CreateExhaustiveF32Ranges())));

#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
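// All 2^16 f16 bit patterns fit in a single shard.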
INSTANTIATE_TEST_SUITE_P(
    F16, ExhaustiveOpTest,
    ::testing::Combine(::testing::Values(F16),
                       ::testing::Values(std::make_pair(0, 1 << 16))));
#endif

#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16)
INSTANTIATE_TEST_SUITE_P(
    BF16, ExhaustiveOpTest,
    ::testing::Combine(::testing::Values(BF16),
                       ::testing::Values(std::make_pair(0, 1 << 16))));
#endif

}  // namespace
}  // namespace xla