/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This macro is required to make MSVC define math constants in math.h.
#define _USE_MATH_DEFINES
#include <math.h>

#include "tensorflow/compiler/xla/client/lib/math.h"

#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {

// Returns operation(operand), except that if `operand`'s element type is one
// of the types in `upcast_types`, it is first converted to F32, and the result
// is converted back down to the original type.
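//
// For example, Erfc below evaluates its polynomial approximation via
//   DoWithUpcastToF32(x, {BF16}, [](XlaOp x) { return ...; })
// so that a BF16 operand is computed with F32 intermediates.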
static XlaOp DoWithUpcastToF32(XlaOp operand,
                               absl::Span<const PrimitiveType> upcast_types,
                               const std::function<XlaOp(XlaOp)>& operation) {
  auto& b = *operand.builder();
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
    PrimitiveType elem_ty = shape.element_type();
    bool needs_upcast = absl::c_linear_search(upcast_types, elem_ty);

    if (needs_upcast) {
      operand = ConvertElementType(operand, F32);
    }
    XlaOp result = operation(operand);
    if (needs_upcast) {
      result = ConvertElementType(result, elem_ty);
    }
    return result;
  });
}

// TODO(jlebar): Use this function in more places in this file to restrict the
// domain of other functions.
static Status EnsureOperandIsRealFp(absl::string_view op_name, XlaOp operand) {
  auto& b = *operand.builder();
  TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
  auto elem_ty = shape.element_type();
  if (!primitive_util::IsFloatingPointType(elem_ty)) {
    return InvalidArgument(
        "Operands to %s must be real-valued floating-point, but got %s",
        op_name, PrimitiveType_Name(elem_ty));
  }
  return Status::OK();
}

XlaOp IsPosInf(XlaOp operand) {
  auto& b = *operand.builder();
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsPosInf", operand));
    TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
    // Note that this is only correct for floating-point types. If we wanted it
    // to be correct for all types, we'd need to Gt(MaxFiniteValue).
    return Eq(operand, MaxValue(&b, shape.element_type()));
  });
}

XlaOp IsNegInf(XlaOp operand) {
  auto& b = *operand.builder();
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNegInf", operand));
    TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
    // Note that this is only correct for floating-point types. If we wanted it
    // to be correct for all types, we'd need to Lt(MinFiniteValue).
    return Eq(operand, MinValue(&b, shape.element_type()));
  });
}

XlaOp IsInf(XlaOp operand) {
  auto& b = *operand.builder();
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsInf", operand));
    return IsPosInf(Abs(operand));
  });
}

XlaOp IsNan(XlaOp operand) {
  auto& b = *operand.builder();
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNan", operand));
    return Ne(operand, operand);
  });
}

XlaOp IsNegZero(XlaOp operand) {
  auto& b = *operand.builder();
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNegZero", operand));
    TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));

    // The bitwise representation of -0 in bfloat16 and IEEE 754 is 0x80...0
    // (sign bit on, all other bits off).
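    // For example, an F32 -0.0f bitcasts to 0x80000000.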
    switch (shape.element_type()) {
      case F64:
        return Eq(BitcastConvertType(operand, U64),
                  ConstantR0WithType(&b, U64, uint64{1} << 63));
      case F32:
        return Eq(BitcastConvertType(operand, U32),
                  ConstantR0WithType(&b, U32, uint32{1} << 31));
      case F16:
      case BF16:
        // Not all XLA backends handle U16 well, so we convert to F32/U32.
        // TODO(jlebar): It would be nice if we could stay in (B)F16/U16 for
        // backends that *do* support it.
        return Eq(BitcastConvertType(ConvertElementType(operand, F32), U32),
                  ConstantR0WithType(&b, U32, uint32{1} << 31));
      default:
        LOG(FATAL) << "Expected real fp type.";
    }
  });
}

XlaOp Square(XlaOp operand) { return operand * operand; }

XlaOp Reciprocal(XlaOp operand) { return ScalarLike(operand, 1.0) / operand; }

// Evaluate the polynomial given coefficients and `x`.
// N.B. Coefficients should be supplied in decreasing order.
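// For example, coefficients {2, 3, 4} evaluate the polynomial 2*x^2 + 3*x + 4
// via Horner's scheme.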
XlaOp EvaluatePolynomial(XlaOp x, absl::Span<const float> coefficients) {
  XlaOp poly = ScalarLike(x, 0.0);
  for (float c : coefficients) {
    poly = poly * x + ScalarLike(x, c);
  }
  return poly;
}

// Computes an approximation of the error function complement (1 - erf(x)).
//
// Precondition: abs(x) >= 1. Otherwise, use ErfImpl.
//
// This follows Cephes's f32 implementation of erfc, and so it may have errors
// for double precision.
//
// See also these alternate implementations of erf and erfc:
//
//   https://stackoverflow.com/questions/35148198
//   https://stackoverflow.com/questions/35966695
//
static XlaOp ErfcImpl(XlaOp x) {
  // Coefficients for erfc(f32), from Cephes.
  //
  // erfc(x) = exp(-x^2) P(1/x), 1 < x < 2
  static std::array<float, 9> kErfcPCoefficient{
      +2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1,
      -5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1,
      +3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1,
  };
  // erfc(x) = exp(-x^2) 1/x P(1/x^2), 2 < x < 14
  static std::array<float, 8> kErfcRCoefficient{
      -1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0,
      +2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1,
      -2.820767439740514E-1, +5.641895067754075E-1,
  };

  XlaOp abs_x = Abs(x);
  XlaOp z = Exp(-x * x);
  XlaOp q = ScalarLike(x, 1) / abs_x;
  XlaOp y = q * q;
  XlaOp p = Select(Lt(abs_x, ScalarLike(x, 2.0)),
                   EvaluatePolynomial(y, kErfcPCoefficient),
                   EvaluatePolynomial(y, kErfcRCoefficient));
  y = z * q * p;
  return Select(Lt(x, ScalarLike(x, 0)), ScalarLike(x, 2.0) - y, y);
}

// Compute a polynomial approximation of the error function.
//
// Precondition: abs(x) <= 1. Otherwise, use ErfcImpl.
//
// This follows Cephes's f32 implementation of erf, so it may have errors for
// double precision.
static XlaOp ErfImpl(XlaOp x) {
  // Coefficients for erf(f32), from Cephes.
  //
  // erf(x) = x P(x^2), 0 < x < 1
  static std::array<float, 7> kErfTCoefficient{
      +7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3,
      -2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1,
      +1.128379165726710E+0,
  };

  return x * EvaluatePolynomial(x * x, kErfTCoefficient);
}

XlaOp Erfc(XlaOp x) {
  auto& b = *x.builder();
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erfc", x));

    // erfc(x) =
    //   erfc_impl(x)           if |x| > 1
    //   1 - erf_impl(x)        otherwise
    //
    // Erf(c)Impl don't have enough precision when run with bf16 intermediates
    // (not surprising!), so upcast to f32 in this case.
    return DoWithUpcastToF32(x, {BF16}, [](XlaOp x) {
      return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl(x),
                    ScalarLike(x, 1) - ErfImpl(x));
    });
  });
}

XlaOp Erf(XlaOp x) {
  auto& b = *x.builder();
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erf", x));
    // erf(x) =
    //   erf_impl(x)            if |x| < 1
    //   1 - erfc_impl(x)       otherwise
    //
    // Erf(c)Impl don't have enough precision when run with bf16 intermediates
    // (not surprising!), so upcast to f32 in this case.
    return DoWithUpcastToF32(x, {BF16}, [](XlaOp x) {
      return Select(Lt(Abs(x), ScalarLike(x, 1)), ErfImpl(x),
                    ScalarLike(x, 1) - ErfcImpl(x));
    });
  });
}

// Approximation for the inverse error function from
//   Giles, M., "Approximating the erfinv function".
// The approximation has the form:
//   w = -log((1 - x) * (1 + x))
//   if ( w < 5 ) {
//     w = w - 2.5
//     p = sum_{i=1}^n lq[i]*w^i
//   } else {
//     w = sqrt(w) - 3
//     p = sum_{i=1}^n gq[i]*w^i
//   }
//   return p*x
XlaOp ErfInv(XlaOp x) {
  constexpr int kDegree = 9;
  constexpr std::array<float, 9> w_less_than_5_constants = {
      2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
      -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
      -0.00417768164f, 0.246640727f, 1.50140941f};
  constexpr std::array<float, 9> w_greater_than_5_constants = {
      -0.000200214257f, 0.000100950558f, 0.00134934322f,
      -0.00367342844f, 0.00573950773f, -0.0076224613f,
      0.00943887047f, 1.00167406f, 2.83297682f};

  auto one = ScalarLike(x, 1.0);
  auto w = -Log((one - x) * (one + x));

  auto lt = Lt(w, ScalarLike(x, 5.0));
  auto coefficient = [&](int i) {
    return Select(lt, FullLike(x, w_less_than_5_constants[i]),
                  FullLike(x, w_greater_than_5_constants[i]));
  };
  w = Select(lt, w - ScalarLike(x, 2.5), Sqrt(w) - ScalarLike(x, 3.0));
  auto p = coefficient(0);
  for (int i = 1; i < kDegree; ++i) {
    p = coefficient(i) + p * w;
  }

  // Result modulo edge cases.
  XlaOp result = p * x;

  // Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is
  // indeterminate, and can give nan or -/+inf.)
  auto& b = *x.builder();
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, b.GetShape(x));
    return Select(Eq(Abs(x), ScalarLike(x, 1)),
                  x * MaxValue(&b, shape.element_type()), result);
  });
}

namespace {
// Coefficients for the Lanczos approximation of the gamma function. The
// coefficients are uniquely determined by the choice of g and n (kLanczosGamma
// and kLanczosCoefficients.size() + 1). The coefficients below correspond to
// [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and [7,
// 9] seemed to be the least sensitive to the quality of the log function. In
// particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
// for a particularly inaccurate log function.
static constexpr double kLanczosGamma = 7;  // aka g
static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
static constexpr std::array<double, 8> kLanczosCoefficients = {
    676.520368121885098567009190444019, -1259.13921672240287047156078755283,
    771.3234287776530788486528258894, -176.61502916214059906584551354,
    12.507343278686904814458936853, -0.13857109526572011689554707,
    9.984369578019570859563e-6, 1.50563273514931155834e-7};
}  // namespace

// Compute the Lgamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
// lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z)
//                 + log(A(z))
// t(z) = z + kLanczosGamma + 1/2
// A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k))
XlaOp Lgamma(XlaOp input) {
  auto do_it = [](XlaOp input) {
    XlaOp one_half = ScalarLike(input, 0.5);
    XlaOp one = ScalarLike(input, 1);

    XlaOp pi = ScalarLike(input, M_PI);
    XlaOp log_pi = ScalarLike(input, std::log(M_PI));
    XlaOp log_sqrt_two_pi =
        ScalarLike(input, (std::log(2) + std::log(M_PI)) / 2);

    XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5);
    XlaOp log_lanczos_gamma_plus_one_half =
        ScalarLike(input, std::log(kLanczosGamma + 0.5));

    XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff);

    // If the input is less than 0.5 use Euler's reflection formula:
    //   gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
    XlaOp need_to_reflect = Lt(input, one_half);
    XlaOp z = Select(need_to_reflect, -input, input - one);

    XlaOp x = base_lanczos_coeff;
    for (int i = 0; i < kLanczosCoefficients.size(); ++i) {
      XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]);
      XlaOp index = ScalarLike(input, i);
      x = x + lanczos_coefficient / (z + index + one);
    }

    // To improve accuracy on platforms with less-precise log implementations,
    // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p
    // on the device.
    // log(t) = log(kLanczosGamma + 0.5 + z)
    //        = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5))
    XlaOp t = lanczos_gamma_plus_one_half + z;
    XlaOp log_t = log_lanczos_gamma_plus_one_half +
                  Log1p(z / lanczos_gamma_plus_one_half);

    // Compute the final result (modulo reflection). t(z) may be large, and we
    // need to be careful not to overflow to infinity in the first term of
    //
    //   (z + 1/2) * log(t(z)) - t(z).
    //
    // Therefore we compute this as
    //
    //   (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
    //
    XlaOp log_y = log_sqrt_two_pi + (z + one_half - t / log_t) * log_t + Log(x);

    // Compute the reflected value, used when x < 0.5:
    //
    //   lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
    //
    // (The abs is because lgamma is the log of the absolute value of the gamma
    // function.)
    //
    // We have to be careful when computing the final term above. gamma(x) goes
    // to +/-inf at every integer x < 0, and this is controlled by the
    // sin(pi * x) term. The slope is large, so precision is particularly
    // important.
    //
    // Because abs(sin(pi * x)) has period 1, we can equivalently use
    // abs(sin(pi * frac(x))), where frac(x) is the fractional part of x. This
    // is more numerically accurate: It doesn't overflow to inf like pi * x
    // can, and if x is an integer, it evaluates to 0 exactly, which is
    // significant because we then take the log of this value, and log(0) is
    // inf.
    //
    // We don't have a frac(x) primitive in XLA and computing it is tricky, but
    // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for
    // our purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
    //
    // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is
    // close to 1. To remedy this, we can use the fact that sin(pi * x) in the
    // domain [0, 1] is symmetric across the line Y=0.5.
    //
    XlaOp abs_input = Abs(input);
    XlaOp abs_frac_input = abs_input - Floor(abs_input);
    // Convert values of abs_frac_input > 0.5 to (1 - abs_frac_input) to
    // improve precision of pi * abs_frac_input for values of abs_frac_input
    // close to 1.
    XlaOp reduced_frac_input =
        Select(Gt(abs_frac_input, ScalarLike(abs_frac_input, 0.5)),
               ScalarLike(abs_frac_input, 1) - abs_frac_input, abs_frac_input);
    XlaOp reflection_denom = Log(Sin(pi * reduced_frac_input));

    // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf,
    // then it "wins" and the result is +/-inf.
    XlaOp reflection =
        Select(IsFinite(reflection_denom), log_pi - reflection_denom - log_y,
               -reflection_denom);
    XlaOp result = Select(need_to_reflect, reflection, log_y);

    // lgamma(+/-inf) = +inf.
    XlaOp inf_bcast = FullLike(input, std::numeric_limits<float>::infinity());
    return Select(IsInf(input), inf_bcast, result);
  };

  auto& b = *input.builder();
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Lgamma", input));
    // F16 and BF16 don't provide sufficient precision for intermediate results
    // here (although it's better than you might expect!), so do the
    // computations in F32.
    return DoWithUpcastToF32(input, {BF16, F16}, do_it);
  });
}

// Compute the Digamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
// digamma(z + 1) = log(t(z)) + A'(z) / A(z) - kLanczosGamma / t(z)
// t(z) = z + kLanczosGamma + 1/2
// A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k))
// A'(z) = sigma(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
XlaOp Digamma(XlaOp input) {
  auto do_it = [](XlaOp input) {
    XlaOp zero = ScalarLike(input, 0);
    XlaOp one_half = ScalarLike(input, 0.5);
    XlaOp one = ScalarLike(input, 1);

    XlaOp pi = ScalarLike(input, M_PI);

    XlaOp lanczos_gamma = ScalarLike(input, kLanczosGamma);
    XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5);
    XlaOp log_lanczos_gamma_plus_one_half =
        ScalarLike(input, std::log(kLanczosGamma + 0.5));

    XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff);

    // If the input is less than 0.5 use Euler's reflection formula:
    //   digamma(x) = digamma(1 - x) - pi * cot(pi * x)
    XlaOp need_to_reflect = Lt(input, one_half);
    XlaOp z = Select(need_to_reflect, -input, input - one);

    XlaOp num = zero;
    XlaOp denom = base_lanczos_coeff;
    for (int i = 0; i < kLanczosCoefficients.size(); ++i) {
      XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]);
      XlaOp index = ScalarLike(input, i);
      num = num - lanczos_coefficient / ((z + index + one) * (z + index + one));
      denom = denom + lanczos_coefficient / (z + index + one);
    }

    // To improve accuracy on platforms with less-precise log implementations,
    // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p
    // on the device.
    // log(t) = log(kLanczosGamma + 0.5 + z)
    //        = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5))
    XlaOp t = lanczos_gamma_plus_one_half + z;
    XlaOp log_t = log_lanczos_gamma_plus_one_half +
                  Log1p(z / lanczos_gamma_plus_one_half);

    XlaOp y = log_t + num / denom - lanczos_gamma / t;

    // We need to be careful how we compute cot(pi * input) below: For
    // near-integral values of `input`, pi * input can lose precision.
    //
    // Input is already known to be less than 0.5 (otherwise we don't have to
    // reflect). We shift values smaller than -0.5 into the range [-.5, .5] to
    // increase precision of pi * input and the resulting cotangent.
    XlaOp reduced_input = input + Abs(Floor(input + ScalarLike(input, 0.5)));
    XlaOp reflection =
        y - pi * Cos(pi * reduced_input) / Sin(pi * reduced_input);
    XlaOp real_result = Select(need_to_reflect, reflection, y);

    // Digamma has poles at negative integers and zero; return nan for those.
    return Select(And(Le(input, zero), Eq(input, Floor(input))),
                  FullLike(input, std::numeric_limits<float>::quiet_NaN()),
                  real_result);
  };

  auto& b = *input.builder();
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Digamma", input));
    return DoWithUpcastToF32(input, {BF16, F16}, do_it);
  });
}

// Implements Banker's rounding: numbers that are equidistant between two
// integers are rounded towards even.
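// For example, 0.5 rounds to 0, 1.5 rounds to 2, and 2.5 rounds to 2.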
XlaOp RoundToEven(XlaOp x) {
  auto& b = *x.builder();
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Reject non-real non-fp inputs (What does it even mean to round a complex
    // number? Do you round each component equally? In that case, you should
    // just ask for that explicitly.)
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("RoundToEven", x));

    auto half = ScalarLike(x, 0.5);
    auto one = ScalarLike(x, 1.0);
    auto two = ScalarLike(x, 2.0);

    auto round_val = Floor(x);
    auto fraction = x - round_val;
    auto nearest_even_int = round_val - two * Floor(half * x);
    auto is_odd = Eq(nearest_even_int, one);
    return Select(Or(Gt(fraction, half), And(Eq(fraction, half), is_odd)),
                  round_val + one, round_val);
  });
}

// Trigonometric functions.

// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x))    if x != -1
//           pi                                   if x == -1
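// (The atan form follows from the half-angle identity
//   tan(theta / 2) = sin(theta) / (1 + cos(theta))
// with theta = acos(x).)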
XlaOp Acos(XlaOp x) {
  return Select(Ne(x, FullLike(x, -1)),
                ScalarLike(x, 2.0) * Atan2(Sqrt(ScalarLike(x, 1.0) - x * x),
                                           ScalarLike(x, 1.0) + x),
                FullLike(x, M_PI));
}

// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
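// (Likewise from the half-angle identity with theta = asin(x).)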
XlaOp Asin(XlaOp x) {
  return ScalarLike(x, 2.0) *
         Atan2(x, ScalarLike(x, 1.0) + Sqrt(ScalarLike(x, 1.0) - x * x));
}

XlaOp Atan(XlaOp x) { return Atan2(x, ScalarLike(x, 1.0)); }

XlaOp Tan(XlaOp x) { return Sin(x) / Cos(x); }

// Hyperbolic trigonometric functions.

// acosh(x) = log(x + sqrt(x^2 - 1))
//          = log(x + sqrt((x+1)*(x-1)))
XlaOp Acosh(XlaOp x) {
  return Log(x + Sqrt((x + ScalarLike(x, 1.0)) * (x - ScalarLike(x, 1.0))));
}

// asinh(x) = log(x + sqrt(x^2 + 1))
XlaOp Asinh(XlaOp x) { return Log(x + Sqrt(x * x + ScalarLike(x, 1.0))); }

// atanh(x) = 0.5 * log((1 + x) / (1 - x))
XlaOp Atanh(XlaOp x) {
  return Log((ScalarLike(x, 1.0) + x) / (ScalarLike(x, 1.0) - x)) *
         ScalarLike(x, 0.5);
}

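// cosh(x) = (e^x + e^-x) / 2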
XlaOp Cosh(XlaOp x) { return (Exp(x) + Exp(-x)) * ScalarLike(x, 0.5); }

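// sinh(x) = (e^x - e^-x) / 2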
XlaOp Sinh(XlaOp x) { return (Exp(x) - Exp(-x)) * ScalarLike(x, 0.5); }

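// Returns Conj(x) if `conjugate` is true and `x` has a complex element type;
// otherwise returns `x` unchanged.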
XlaOp MaybeConjugate(XlaOp x, bool conjugate) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    auto perform_conj =
        primitive_util::IsComplexType(shape.element_type()) && conjugate;
    return perform_conj ? Conj(x) : x;
  });
}

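// Computes the next representable floating-point value of `from` in the
// direction of `to` (in the spirit of std::nextafter), by bitcasting to an
// integer type and stepping the magnitude by one ulp.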
XlaOp NextAfter(XlaOp from, XlaOp to) {
  auto builder = from.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(from));
    int bitwidth = primitive_util::BitWidth(shape.element_type());
    auto int_type = primitive_util::UnsignedIntegralTypeForBitWidth(bitwidth);
    auto from_as_int = BitcastConvertType(from, int_type);
    auto to_as_int = BitcastConvertType(to, int_type);

    // The result is NaN if either "from" or "to" are NaN.
    auto from_is_nan = Ne(from, from);
    auto to_is_nan = Ne(to, to);
    auto nan_input = Or(from_is_nan, to_is_nan);
    auto result_for_nan =
        Broadcast(ScalarLike(from, std::numeric_limits<double>::quiet_NaN()),
                  shape.dimensions());
    result_for_nan = BitcastConvertType(result_for_nan, int_type);

    // The sign bit is the MSB.
    const int64 sign_mask = int64{1} << (bitwidth - 1);
    // Discard the sign bit to make the result non-negative.
    auto from_abs = And(from_as_int, ScalarLike(from_as_int, ~sign_mask));
    auto to_abs = And(to_as_int, ScalarLike(to_as_int, ~sign_mask));

    // When both "from" and "to" are equal, the result is "to".
    // N.B. It would not make a difference if we chose the result to be "from".
    auto from_and_to_are_equal = Eq(from_as_int, to_as_int);
    auto result_for_equal = to_as_int;

593 // When both "from" and "to" are both 0, the result is "to". This ensures we
594 // get a zero signed like "to".
    auto from_is_zero = Eq(from_abs, ZerosLike(from_abs));
    auto to_is_zero = Eq(to_abs, ZerosLike(to_abs));
    auto result_for_both_zero = to_as_int;

    auto from_sign = And(from_as_int, ScalarLike(from_as_int, sign_mask));
    auto to_sign = And(to_as_int, ScalarLike(to_as_int, sign_mask));

    // If from == 0 && to != 0, we need to return the smallest subnormal number
    // signed like "to".
    auto result_for_from_zero_to_non_zero =
        Or(to_sign, ScalarLike(from_as_int, 1));

    // If the sign of "from" and "to" disagree:
    // - we need to make the magnitude of "from" smaller so that it is closer
    //   to zero.
    //
    // Otherwise the signs agree:
    // - "from" with a magnitude larger than "to" means we need to make the
    //   magnitude smaller.
    // - "from" with a magnitude smaller than "to" means we need to make the
    //   magnitude larger.
    // - "from" with the same magnitude and sign as "to" has already been
    //   handled.
    auto signs_disagree = Ne(from_sign, to_sign);
    auto from_magnitude_larger_than_to = Gt(from_abs, to_abs);
    auto result_has_smaller_magnitude =
        Or(from_magnitude_larger_than_to, signs_disagree);
    auto magnitude_adjustment =
        Select(result_has_smaller_magnitude,
               Broadcast(ScalarLike(from_as_int, -1), shape.dimensions()),
               Broadcast(ScalarLike(from_as_int, 1), shape.dimensions()));
    auto result = Add(from_as_int, magnitude_adjustment);
    // Handle from == ±0.
    result = Select(from_is_zero,
                    Select(to_is_zero, result_for_both_zero,
                           result_for_from_zero_to_non_zero),
                    result);
    // Handle from == to.
    result = Select(from_and_to_are_equal, result_for_equal, result);
    // Handle isnan(from) || isnan(to).
    result = Select(nan_input, result_for_nan, result);

    // Cast back to the original type.
    return BitcastConvertType(result, shape.element_type());
  });
}

}  // namespace xla