/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <array>
#include <cmath>

#include "absl/base/casts.h"
#include "absl/strings/str_format.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"

namespace xla {
namespace {

using Eigen::half;

template <typename T, size_t N>
T EvaluatePolynomial(T x, const std::array<T, N>& coeffs) {
  T result = 0;
  for (T c : coeffs) {
    result = result * x + c;
  }
  return result;
}

// There's no std::erfinv, so we have to implement it ourselves. This follows
// Wichura 1998, https://www.jstor.org/stable/2347330 which, notably, is a
// different implementation from that in math.cc.
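//
// The approximation is piecewise rational: for |x| <= 0.85 it evaluates a
// ratio of polynomials in r = 0.180625 - 0.25 * x^2 (kPolyA / kPolyB); for
// larger |x| it works in r = sqrt(-log((1 - |x|) / 2)), using kPolyC / kPolyD
// while r <= 5 and kPolyE / kPolyF for the far tail.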
float HostErfInv(float x) {
  std::array<double, 8> kPolyA = {
      8.8709406962545514830200e2, 1.1819493347062294404278e4,
      2.3782041382114385731252e4, 1.6235862515167575384252e4,
      4.8548868893843886794648e3, 6.9706266534389598238465e2,
      4.7072688112383978012285e1, 1.1975323115670912564578e0,
  };
  std::array<double, 8> kPolyB = {
      5.2264952788528545610e3, 2.8729085735721942674e4, 3.9307895800092710610e4,
      2.1213794301586595867e4, 5.3941960214247511077e3, 6.8718700749205790830e2,
      4.2313330701600911252e1, 1.0000000000000000000e0,
  };
  std::array<double, 8> kPolyC = {
      7.74545014278341407640e-4, 2.27238449892691845833e-2,
      2.41780725177450611770e-1, 1.27045825245236838258e0,
      3.64784832476320460504e0, 5.76949722146069140550e0,
      4.63033784615654529590e0, 1.42343711074968357734e0,
  };
  std::array<double, 8> kPolyD = {
      1.4859850019840355905497876e-9, 7.7441459065157709165577218e-4,
      2.1494160384252876777097297e-2, 2.0945065210512749128288442e-1,
      9.7547832001787427186894837e-1, 2.3707661626024532365971225e0,
      2.9036514445419946173133295e0, 1.4142135623730950488016887e0,
  };
  std::array<double, 8> kPolyE = {
      2.01033439929228813265e-7, 2.71155556874348757815e-5,
      1.24266094738807843860e-3, 2.65321895265761230930e-2,
      2.96560571828504891230e-1, 1.78482653991729133580e0,
      5.46378491116411436990e0, 6.65790464350110377720e0,
  };
  std::array<double, 8> kPolyF = {
      2.891024605872965461538222e-15, 2.010321207683943062279931e-7,
      2.611088405080593625138020e-5, 1.112800997078859844711555e-3,
      2.103693768272068968719679e-2, 1.936480946950659106176712e-1,
      8.482908416595164588112026e-1, 1.414213562373095048801689e0,
  };

  if (std::abs(x) > 1 || std::isnan(x)) {
    return std::numeric_limits<float>::quiet_NaN();
  }
  if (std::abs(x) == 1) {
    return std::copysign(std::numeric_limits<float>::infinity(), x);
  }

  float unsigned_result = [&] {
    float y = std::abs(x);
    if (y <= 0.85) {
      double r = 0.180625 - 0.25 * y * y;
      return (y * EvaluatePolynomial(r, kPolyA)) /
             EvaluatePolynomial(r, kPolyB);
    } else {
      double r = std::sqrt(std::log(2.0) - std::log1p(-y));
      if (r <= 5.0) {
        r -= 1.6;
        return EvaluatePolynomial(r, kPolyC) / EvaluatePolynomial(r, kPolyD);
      } else {
        r -= 5;
        return EvaluatePolynomial(r, kPolyE) / EvaluatePolynomial(r, kPolyF);
      }
    }
  }();
  return std::copysign(unsigned_result, x);
}

// Digamma implementation using a polynomial from Cephes. Notably this is a
// different implementation from the one in math.cc.
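//
// Evaluation proceeds in three steps: negative arguments are mapped to
// positive ones with the reflection formula
//   digamma(x) = digamma(1 - x) - pi / tan(pi * x),
// arguments below 10 are shifted up using the recurrence
//   digamma(x + 1) = digamma(x) + 1 / x,
// and the shifted argument is evaluated with the asymptotic expansion
//   digamma(x) ~= log(x) - 1 / (2x) - z * P(z),  z = 1 / x^2,
// where P is the polynomial kPoly below.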
float HostDigamma(float x) {
  // Euler-Mascheroni constant
  float kGamma = 0.57721566490153286061;
  float kPi = M_PI;

  std::array<float, 4> kPoly = {
      -4.16666666666666666667E-3,
      3.96825396825396825397E-3,
      -8.33333333333333333333E-3,
      8.33333333333333333333E-2,
  };

  float reflection = 0;
  if (x <= 0) {
    float floor = std::floor(x);
    if (x == floor) {
      return std::numeric_limits<float>::quiet_NaN();
    }
    // Compute reflection term, pi * cot(pi * x).
    reflection = x - floor;
    if (reflection == 0.5) {
      reflection = 0;
    } else {
      if (reflection > 0.5) {
        reflection = x - (floor + 1.0f);
      }
      reflection = kPi / std::tan(kPi * reflection);
    }
    x = 1 - x;
  }

  float result = 0;
  if (x <= 10 && x == std::floor(x)) {
    // Special case for integers <= 10.
    for (int i = 1; i < x; ++i) {
      result += 1.0f / i;
    }
    result -= kGamma;
  } else {
    float w = 0;
    for (; x < 10; ++x) {
      w += 1.0f / x;
    }
    if (x < 1e8) {
      float z = 1.0f / (x * x);
      result = z * EvaluatePolynomial(z, kPoly);
    }
    result = std::log(x) - 0.5f / x - result - w;
  }

  // Compute the final, reflected value.
  return result - reflection;
}

// For f32, f16, and bf16, we need 9, 5, and 4 decimal digits of precision to
// guarantee that we're printing the full number.
//
// (The general formula is, given a floating-point number with S significand
// bits, the number of decimal digits needed to print it to full precision is
//
//   ceil(1 + S * log_10(2)) ~= ceil(1 + S * 0.30103).
//
// See https://people.eecs.berkeley.edu/~wkahan/Math128/BinDecBin.pdf.)
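//
// For example, f32 has S = 24 significand bits (including the implicit bit),
// so it needs ceil(1 + 24 * 0.30103) = 9 digits; f16 has S = 11, needing 5;
// and bf16 has S = 8, needing 4.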
string StringifyNum(float x) {
  return absl::StrFormat("%0.9g (0x%08x)", x, absl::bit_cast<uint32>(x));
}

string StringifyNum(half x) {
  return absl::StrFormat("%0.5g (0x%04x)", static_cast<float>(x),
                         absl::bit_cast<uint16>(x));
}

string StringifyNum(bfloat16 x) {
  return absl::StrFormat("%0.4g (0x%04x)", static_cast<float>(x),
                         absl::bit_cast<uint16>(x));
}

// Test parameter is a tuple containing
//   - primitive type under test,
//   - (begin, end) range under test, as zero-extended int64s bitcast to the
//     primitive type under test.
class ExhaustiveOpTest
    : public ClientLibraryTestBase,
      public ::testing::WithParamInterface<
          std::tuple<PrimitiveType, std::pair<int64, int64>>> {
 public:
  ExhaustiveOpTest()
      : ty_(std::get<0>(GetParam())), platform_(client_->platform()->Name()) {}

  void Run(std::function<XlaOp(XlaOp)> enqueue_op,
           float (*evaluate_op)(float)) {
    SetFastMathDisabled(true);

    // Run all HLO passes. In particular, constant folding is disabled by
    // default for tests, but we need to run it in order to tickle some bugs.
    mutable_debug_options()->clear_xla_disable_hlo_passes();

    PrimitiveType ty;
    std::tie(ty, std::ignore) = GetParam();

    switch (ty) {
      case F32:
        SetDefaultErrSpec(0.0001, 0.0001);
        RunImpl<float, uint32>(enqueue_op, evaluate_op);
        break;
      case F16:
        SetDefaultErrSpec(0.001, 0.001);
        RunImpl<half, uint16>(enqueue_op, evaluate_op);
        break;
      case BF16:
        SetDefaultErrSpec(0.001, 0.01);
        RunImpl<bfloat16, uint16>(enqueue_op, evaluate_op);
        break;
      default:
        LOG(FATAL) << "Unhandled type.";
    }
  }

  void SetDefaultErrSpec(float abs_err, float rel_err) {
    if (!abs_err_.has_value()) {
      abs_err_ = abs_err;
    }
    if (!rel_err_.has_value()) {
      rel_err_ = rel_err;
    }
  }

  template <typename T, typename IntegralT>
  void RunImpl(std::function<XlaOp(XlaOp)> enqueue_op,
               float (*evaluate_op)(float)) {
    static_assert(
        sizeof(T) == sizeof(IntegralT),
        "IntegralT must be an unsigned integer type of the same width as T.");

    PrimitiveType ty;
    std::pair<int64, int64> test_range;
    std::tie(ty, test_range) = GetParam();
    int64 begin, end;
    std::tie(begin, end) = test_range;

    if (begin >= known_incorrect_begin_ && end <= known_incorrect_end_) {
      LOG(INFO) << absl::StreamFormat(
          "Skipping this shard, as the range under test, [%d, %d), falls "
          "entirely within the known-incorrect range [%d, %d).",
          begin, end, known_incorrect_begin_, known_incorrect_end_);
      return;
    }

    LOG(INFO) << "Checking range [" << begin << ", " << end << ")";

    int64 input_size = end - begin;
    Literal input_literal = LiteralUtil::CreateFromDimensions(ty, {input_size});
    absl::Span<T> input_arr = input_literal.data<T>();
    for (int64 i = 0; i < input_size; i++) {
      IntegralT input_val = i + begin;
      // If the operation is known to be buggy on a specific input, clamp that
      // input to 0 under the assumption that the op is at least correct on 0.
      if (input_val >= known_incorrect_begin_ &&
          input_val < known_incorrect_end_) {
        input_arr[i] = T{0};
      } else {
        input_arr[i] = absl::bit_cast<T>(input_val);
      }
    }

    TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
                            BuildAndRunComputation(enqueue_op, input_literal));
    ExpectNear<T>(input_literal, result_literal, evaluate_op);
  }

  StatusOr<Literal> BuildAndRunComputation(
      const std::function<XlaOp(XlaOp)>& enqueue_op,
      const Literal& input_literal) {
    XlaBuilder builder(TestName());
    auto input = Parameter(&builder, 0, input_literal.shape(), "input");
    enqueue_op(input);
    TF_ASSIGN_OR_RETURN(XlaComputation comp, builder.Build());

    // Build and run the computation using the LocalClient API, rather than the
    // plain Client API, which is used by ClientLibraryTestBase. This is
    // because the plain Client API does more memcpys to/from Literals, and
    // that's slow given that we're touching a lot of data here.
    //
    // Copy debug options from ClientLibraryTestBase. In particular, we're
    // interested in disabling constant folding.
    ExecutableBuildOptions build_opts;
    *build_opts.mutable_debug_options() = *mutable_debug_options();
    TF_ASSIGN_OR_RETURN(
        auto executable,
        client_->Compile(comp, {&input_literal.shape()}, build_opts));

    TF_ASSIGN_OR_RETURN(
        ScopedShapedBuffer input_data,
        client_->LiteralToShapedBuffer(input_literal, /*device_ordinal=*/0));

    ExecutableRunOptions run_opts;
    run_opts.set_allocator(client_->backend().memory_allocator());
    run_opts.set_intra_op_thread_pool(
        client_->backend().eigen_intra_op_thread_pool_device());
    TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
                        executable->Run({&input_data}, run_opts));

    TF_ASSIGN_OR_RETURN(Literal result_literal,
                        client_->ShapedBufferToLiteral(result));
    return std::move(result_literal);
  }

  template <typename T>
  bool IsClose(T expected, T actual) {
    float expected_f32 = static_cast<float>(expected);
    float actual_f32 = static_cast<float>(actual);
    float abs_err = std::abs(expected_f32 - actual_f32);
    float rel_err = abs_err / std::abs(expected_f32);
    if (strict_signed_zeros_ && actual == T{0} && expected == T{0}) {
      // Check sign of zero.
      return std::signbit(actual_f32) == std::signbit(expected_f32);
    }
    return abs_err < *abs_err_ || rel_err < *rel_err_ ||
           (std::isnan(expected_f32) && std::isnan(actual_f32)) ||
           (std::isinf(expected_f32) && std::isinf(actual_f32) &&
            (expected_f32 > 0) == (actual_f32 > 0));
  }

  template <typename T>
  void ExpectNear(const Literal& input_literal, const Literal& result_literal,
                  float (*evaluate_op)(float)) {
    // We essentially reimplement LiteralTestUtil::Near here because
    //  a) this streamlined implementation is much faster,
    //  b) we can print out better error messages (namely, we can print out
    //     which floating-point input value failed, while LiteralTestUtil::Near
    //     can only print out the input index that failed), and
    //  c) we need special handling of certain inputs. For example, we say that
    //     a denormal input has multiple correct outputs (namely, f(x) and
    //     f(0)) and just needs to be close to one of them.
    absl::Span<const T> input_arr = input_literal.data<T>();
    absl::Span<const T> result_arr = result_literal.data<T>();
    ASSERT_EQ(result_arr.size(), input_arr.size());
    int64 mismatches = 0;
    // Hoisting these out of the loop is a nice speedup on shards that have
    // many denormals.
    const T expected_at_pos_zero = static_cast<T>(evaluate_op(0));
    const T expected_at_neg_zero = static_cast<T>(evaluate_op(-0.0));
    for (int64 i = 0; i < input_arr.size(); ++i) {
      T input = input_arr[i];
      float input_f32 = static_cast<float>(input);
      T actual = result_arr[i];
      T expected = static_cast<T>(evaluate_op(input_f32));

      if (IsClose(expected, actual)) {
        continue;
      }

      // Easy case: If `input` is not denormal and !IsClose(expected, actual),
      // print an error.
      //
      // (This doesn't correctly detect f16 and bfloat16 denormals! This seems
      // to be OK for now, but at some point we may need to implement
      // fpclassify for half and bfloat.)
      if (std::fpclassify(input_f32) != FP_SUBNORMAL) {
        PrintMismatch(&mismatches, [&] {
          return absl::StrFormat("Mismatch on %s. Expected %s, but got %s.",
                                 StringifyNum(input), StringifyNum(expected),
                                 StringifyNum(actual));
        });
        continue;
      }

      // Otherwise, `input` is denormal. For denormal inputs, we accept answers
      // that are close to any of:
      //
      //  - evaluate_op(input),
      //  - evaluate_op(+/-0), where the sign of 0 is equal to the sign of
      //    `input`, and
      //  - if relaxed_denormal_signs_, evaluate_op(-/+0), where the sign of
      //    0 is the opposite of `input`'s.
      T sign_preserving_ftz_expected =
          std::signbit(input_f32) ? expected_at_neg_zero : expected_at_pos_zero;
      T sign_nonpreserving_ftz_expected =
          std::signbit(input_f32) ? expected_at_pos_zero : expected_at_neg_zero;
      if (IsClose(sign_preserving_ftz_expected, actual) ||
          (relaxed_denormal_signs_ &&
           IsClose(sign_nonpreserving_ftz_expected, actual))) {
        continue;
      }

      if (relaxed_denormal_signs_) {
        PrintMismatch(&mismatches, [&] {
          return absl::StrFormat(
              "Mismatch on denormal value %s. Expected one of:\n"
              " %10s (evaluated at full-precision value)\n"
              " %10s (evaluated after flushing to sign-preserving zero)\n"
              " %10s (evaluated after flushing to non-sign-preserving "
              "zero)\n"
              "but got %s.",
              StringifyNum(input), StringifyNum(expected),
              StringifyNum(sign_preserving_ftz_expected),
              StringifyNum(sign_nonpreserving_ftz_expected),
              StringifyNum(actual));
        });
      } else {
        PrintMismatch(&mismatches, [&] {
          return absl::StrFormat(
              "Mismatch on denormal value %s. Expected one of:\n"
              " %10s (evaluated at full-precision value)\n"
              " %10s (evaluated after flushing to sign-preserving zero)\n"
              "but got %s.",
              StringifyNum(input), StringifyNum(expected),
              StringifyNum(sign_preserving_ftz_expected), StringifyNum(actual));
        });
      }
    }
    EXPECT_EQ(mismatches, 0);
  }

  template <typename ErrorGenerator>
  void PrintMismatch(int64* mismatches, const ErrorGenerator& err_generator) {
    // We send a few mismatches to gunit so they show up nicely in test logs.
    // Then we send more to LOG(ERROR). The remainder we squelch unless we're
    // at vlog level 2.
    constexpr int64 kMaxMismatchesLoggedToGunit = 10;
    constexpr int64 kMaxMismatchesLoggedToErr = 1000;

    (*mismatches)++;
    if (*mismatches < kMaxMismatchesLoggedToGunit) {
      FAIL() << err_generator();
    } else if (*mismatches < kMaxMismatchesLoggedToErr || VLOG_IS_ON(2)) {
      LOG(ERROR) << err_generator();
    } else if (*mismatches == kMaxMismatchesLoggedToErr) {
      LOG(ERROR) << "Not printing any more mismatches; pass "
                    "--vmodule=exhaustive_f32__op_test=2 to see "
                    "all of them.";
    }
  }

  // The following members are set during construction so testcases can read
  // these values and use them e.g. to influence the values given to the
  // mutable members below.

  // The primitive type under test.
  const PrimitiveType ty_;

  // The platform under test.
  const string platform_;

  // Tests can set the following variables for control over execution. This is
  // safe because each XLA_TEST_P instantiates a new instance of this class.

  // Testing will ignore the given range (encoded as bitwise representations of
  // the type under test zero-extended to int64).
  int64 known_incorrect_begin_ = 0;
  int64 known_incorrect_end_ = 0;

  // If unset, reasonable defaults will be used depending on the type under
  // test.
  absl::optional<float> abs_err_;
  absl::optional<float> rel_err_;

  // If true, will consider -0 not near to +0 and vice versa. Note that
  // +epsilon may still be considered close to -0, depending on the error spec;
  // this only covers the case when both `expected` and `actual` are equal to 0.
  bool strict_signed_zeros_ = false;

  // If true, allows denormals to be flushed to non-sign-preserving 0.
  //
  // For example, normally we'd expect sqrt(-denormal) to be either nan (sqrt
  // of a negative number) or -inf (flush the denormal to sign-preserving zero,
  // then sqrt(-0)). But with this as true, we'll also accept 0 (sqrt(0)).
  //
  // XLA:GPU preserves denormal signs, but other backends don't.
  bool relaxed_denormal_signs_ = platform_ != "CUDA";
};

XLA_TEST_P(ExhaustiveOpTest, Log) {
  if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) {
    abs_err_ = 0.001;
    rel_err_ = 0.001;
  }

  Run(Log, std::log);
}

XLA_TEST_P(ExhaustiveOpTest, Log1p) {
  if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) {
    abs_err_ = 0.001;
    rel_err_ = 0.001;
  }

  Run(Log1p, std::log1p);
}

XLA_TEST_P(ExhaustiveOpTest, Exp) {
  if (platform_ == "Host" && ty_ == F32) {
    // TODO(b/73142289): The vectorized Exp implementation gives results
    // outside our error spec in this range.
    known_incorrect_begin_ = 1107296256 + 11583654;
    known_incorrect_end_ = 1107296256 + 11629080;
  } else if (platform_ == "Host" && ty_ == BF16) {
    // TODO(jlebar): Is this a rounding error? Why doesn't it occur on XLA:GPU?
    //
    // Mismatch on 88.5 (0x42b1).
    // Expected 2.72491739e+38 (0x7f4d), but got inf (0x7f80).
    known_incorrect_begin_ = 0x42b1;
    known_incorrect_end_ = 0x42b2;
  }

  Run(Exp, std::exp);
}

XLA_TEST_P(ExhaustiveOpTest, Expm1) {
  // Expm1 has the same erroneous behavior on CPU as Exp.
  if (platform_ == "Host" && ty_ == F32) {
    // TODO(b/73142289): The vectorized Exp implementation gives results
    // outside our error spec in this range.
    known_incorrect_begin_ = 1107296256 + 11583654;
    known_incorrect_end_ = 1107296256 + 11629080;
  } else if (platform_ == "Host" && ty_ == BF16) {
    // TODO(jlebar): Is this a rounding error? Why doesn't it occur on XLA:GPU?
    //
    // Mismatch on 88.5 (0x42b1).
    // Expected 2.72491739e+38 (0x7f4d), but got inf (0x7f80).
    known_incorrect_begin_ = 0x42b1;
    known_incorrect_end_ = 0x42b2;
  }

  Run(Expm1, std::expm1);
}

// It feels a little overkill to exhaustively test sqrt and pow(x, 0.5), but
// this *did* find a bug, namely that some backends were assuming sqrt(x) ==
// pow(x, 0.5), but this is not true for x == -inf.
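//
// (Per IEEE 754 / C99 semantics, pow(-inf, 0.5) is +inf, while sqrt(-inf) is
// NaN.)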
XLA_TEST_P(ExhaustiveOpTest, PowOneHalf) {
  Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); },
      +[](float x) { return std::pow(x, 0.5f); });
}

XLA_TEST_P(ExhaustiveOpTest, Rsqrt) {
  Run(
      Rsqrt, +[](float x) { return 1 / std::sqrt(x); });
}

XLA_TEST_P(ExhaustiveOpTest, Sqrt) {
  if (platform_ == "Host" || platform_ == "CUDA") {
    strict_signed_zeros_ = true;
  }

  Run(Sqrt, std::sqrt);
}

// TODO(jlebar): Add remaining trig functions. Don't forget Atan2!
// TODO(jlebar): Test trig functions over complex inputs.
XLA_TEST_P(ExhaustiveOpTest, Tanh) { Run(Tanh, std::tanh); }

XLA_TEST_P(ExhaustiveOpTest, Erf) { Run(Erf, std::erf); }
XLA_TEST_P(ExhaustiveOpTest, Erfc) { Run(Erfc, std::erfc); }
XLA_TEST_P(ExhaustiveOpTest, ErfInv) { Run(ErfInv, HostErfInv); }
XLA_TEST_P(ExhaustiveOpTest, Digamma) {
  if (platform_ != "Host" && platform_ != "CUDA") {
    // TODO(b/123956399): This is a fairly high error, significantly higher
    // than we see on CPU/GPU.
    rel_err_ = 0.01;
    abs_err_ = 0.01;
  }

  if (platform_ == "CUDA") {
    // On GPU we get a wrong answer for the denormal inputs +/-2.93873588e-39
    // (0x00200000 and 0x80200000). These should return -/+inf (at least
    // according to our reference implementation!) but XLA:GPU returns
    // -/+3.40282326e+38 (0xff7ffffe and 0x7f7ffffe).
    //
    // I deem this an acceptable result, as XLA:GPU flushes denormals, and as
    // the results we get here are very close to MAX_FLOAT. We just hardcode
    // these results, as this is better than ignoring these inputs altogether.
    auto host_digamma_with_gpu_ftz_errors = +[](float x) {
      if (absl::bit_cast<uint32>(x) == 0x00200000 ||
          absl::bit_cast<uint32>(x) == 0x80200000) {
        return std::copysign(std::numeric_limits<float>::max(), -x);
      }
      return HostDigamma(x);
    };
    Run(Digamma, host_digamma_with_gpu_ftz_errors);
  } else {
    Run(Digamma, HostDigamma);
  }
}
XLA_TEST_P(ExhaustiveOpTest, Lgamma) {
  // Our implementation gets within 0.0001 rel error except for ~20 denormal
  // inputs on GPU. Anyway 0.001 rel error should be good enough for lgamma.
  if (platform_ == "CUDA" && (ty_ == F32 || ty_ == F16)) {
    rel_err_ = 0.001;
  }
  if (platform_ != "Host" && platform_ != "CUDA") {
    // TODO(b/123956399): This is a fairly high error, significantly higher
    // than we see on CPU/GPU.
    rel_err_ = 0.01;
    abs_err_ = 0.01;

    // Overflows to inf for input 4.08500343e+36 (0x7c44af8e).
    if (ty_ == F32) {
      known_incorrect_begin_ = 0x7c44af8e;
      known_incorrect_end_ = 0x7c44af8e + 1;
    }
  }
  Run(Lgamma, std::lgamma);
}

XLA_TEST_P(ExhaustiveOpTest, Round) { Run(Round, std::round); }

std::vector<std::pair<int64, int64>> CreateExhaustiveF32Ranges() {
  // We break up the 2^32-element space into small'ish chunks to keep peak
  // memory usage low.
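  //
  // With step = 2^25, the 2^32 possible f32 bit patterns are split into
  // 2^32 / 2^25 = 128 shards of 33,554,432 values each.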
  std::vector<std::pair<int64, int64>> result;
  const int64 step = 1 << 25;
  for (int64 i = 0; i < (1l << 32); i += step) {
    result.push_back({i, i + step});
  }
  return result;
}

INSTANTIATE_TEST_SUITE_P(
    F32, ExhaustiveOpTest,
    ::testing::Combine(::testing::Values(F32),
                       ::testing::ValuesIn(CreateExhaustiveF32Ranges())));

#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
INSTANTIATE_TEST_SUITE_P(
    F16, ExhaustiveOpTest,
    ::testing::Combine(::testing::Values(F16),
                       ::testing::Values(std::make_pair(0, 1 << 16))));
#endif

#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16)
INSTANTIATE_TEST_SUITE_P(
    BF16, ExhaustiveOpTest,
    ::testing::Combine(::testing::Values(BF16),
                       ::testing::Values(std::make_pair(0, 1 << 16))));
#endif

}  // namespace
}  // namespace xla