• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This macro is required to make MSVC defines math constants in math.h
17 #define _USE_MATH_DEFINES
18 #include <math.h>
19 
20 #include "tensorflow/compiler/xla/client/lib/math.h"
21 
22 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
23 #include "tensorflow/compiler/xla/client/lib/constants.h"
24 #include "tensorflow/compiler/xla/primitive_util.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 
28 namespace xla {
29 
30 // Returns operation(operand), except if `operand` is one of the types in
31 // upcast_types, in which case first converts it to F32, and then converts the
32 // result down to the original type.
DoWithUpcastToF32(XlaOp operand,absl::Span<const PrimitiveType> upcast_types,const std::function<XlaOp (XlaOp)> & operation)33 static XlaOp DoWithUpcastToF32(XlaOp operand,
34                                absl::Span<const PrimitiveType> upcast_types,
35                                const std::function<XlaOp(XlaOp)>& operation) {
36   auto& b = *operand.builder();
37   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
38     TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
39     PrimitiveType elem_ty = shape.element_type();
40     bool needs_upcast = absl::c_linear_search(upcast_types, elem_ty);
41 
42     if (needs_upcast) {
43       operand = ConvertElementType(operand, F32);
44     }
45     XlaOp result = operation(operand);
46     if (needs_upcast) {
47       result = ConvertElementType(result, elem_ty);
48     }
49     return result;
50   });
51 }
52 
53 // TODO(jlebar): Use this function in more places in this file to restrict the
54 // domain of other functions.
EnsureOperandIsRealFp(absl::string_view op_name,XlaOp operand)55 static Status EnsureOperandIsRealFp(absl::string_view op_name, XlaOp operand) {
56   auto& b = *operand.builder();
57   TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
58   auto elem_ty = shape.element_type();
59   if (!primitive_util::IsFloatingPointType(elem_ty)) {
60     return InvalidArgument(
61         "Operands to %s must be real-valued floating-point, but got %s",
62         op_name, PrimitiveType_Name(elem_ty));
63   }
64   return Status::OK();
65 }
66 
IsPosInf(XlaOp operand)67 XlaOp IsPosInf(XlaOp operand) {
68   auto& b = *operand.builder();
69   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
70     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsPosInf", operand));
71     TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
72     // Note that this is only correct for floating-point types.  If we wanted it
73     // to be correct for all types, we'd need to Gt(MaxFiniteValue).
74     return Eq(operand, MaxValue(&b, shape.element_type()));
75   });
76 }
77 
IsNegInf(XlaOp operand)78 XlaOp IsNegInf(XlaOp operand) {
79   auto& b = *operand.builder();
80   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
81     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNegInf", operand));
82     TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
83     // Note that this is only correct for floating-point types.  If we wanted it
84     // to be correct for all types, we'd need to Lt(MinFiniteValue).
85     return Eq(operand, MinValue(&b, shape.element_type()));
86   });
87 }
88 
IsInf(XlaOp operand)89 XlaOp IsInf(XlaOp operand) {
90   auto& b = *operand.builder();
91   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
92     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsInf", operand));
93     return IsPosInf(Abs(operand));
94   });
95 }
96 
IsNan(XlaOp operand)97 XlaOp IsNan(XlaOp operand) {
98   auto& b = *operand.builder();
99   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
100     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNan", operand));
101     return Ne(operand, operand);
102   });
103 }
104 
IsNegZero(XlaOp operand)105 XlaOp IsNegZero(XlaOp operand) {
106   auto& b = *operand.builder();
107   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
108     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNegZero", operand));
109     TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
110 
111     // The bitwise representation of -0 in bfloat16 and IEEE 754 is 0x80...0
112     // (sign bit on, all other bits off).
113     switch (shape.element_type()) {
114       case F64:
115         return Eq(BitcastConvertType(operand, U64),
116                   ConstantR0WithType(&b, U64, uint64{1} << 63));
117       case F32:
118         return Eq(BitcastConvertType(operand, U32),
119                   ConstantR0WithType(&b, U32, uint32{1} << 31));
120       case F16:
121       case BF16:
122         // Not all XLA backends handle U16 well, so we convert to F32/U32.
123         // TODO(jlebar): It would be nice if we could stay in (B)F16/U16 for
124         // backends that *do* support it.
125         return Eq(BitcastConvertType(ConvertElementType(operand, F32), U32),
126                   ConstantR0WithType(&b, U32, uint32{1} << 31));
127       default:
128         LOG(FATAL) << "Expected real fp type.";
129     }
130   });
131 }
132 
Square(XlaOp operand)133 XlaOp Square(XlaOp operand) { return operand * operand; }
134 
Reciprocal(XlaOp operand)135 XlaOp Reciprocal(XlaOp operand) { return ScalarLike(operand, 1.0) / operand; }
136 
137 // Evaluate the polynomial given coefficients and `x`.
138 // N.B. Coefficients should be supplied in decreasing order.
EvaluatePolynomial(XlaOp x,absl::Span<const float> coefficients)139 XlaOp EvaluatePolynomial(XlaOp x, absl::Span<const float> coefficients) {
140   XlaOp poly = ScalarLike(x, 0.0);
141   for (float c : coefficients) {
142     poly = poly * x + ScalarLike(x, c);
143   }
144   return poly;
145 }
146 
147 // Computes an approximation of the error function complement (1 - erf(x)).
148 //
149 // Precondition: abs(x) >= 1.  Otherwise, use ErfImpl.
150 //
151 // This follows Cephes's f32 implementation of erfc, and so it may have errors
152 // for double precision.
153 //
154 // See also these alternate implementations of erf and erfc:
155 //
156 //   https://stackoverflow.com/questions/35148198
157 //   https://stackoverflow.com/questions/35966695
158 //
ErfcImpl(XlaOp x)159 static XlaOp ErfcImpl(XlaOp x) {
160   // Coefficients for erfc(f32), from Cephes.
161   //
162   // erfc(x) = exp(-x^2) P(1/x), 1 < x < 2
163   static std::array<float, 9> kErfcPCoefficient{
164       +2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1,
165       -5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1,
166       +3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1,
167   };
168   // erfc(x) = exp(-x^2) 1/x P(1/x^2), 2 < x < 14
169   static std::array<float, 8> kErfcRCoefficient{
170       -1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0,
171       +2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1,
172       -2.820767439740514E-1, +5.641895067754075E-1,
173   };
174 
175   XlaOp abs_x = Abs(x);
176   XlaOp z = Exp(-x * x);
177   XlaOp q = ScalarLike(x, 1) / abs_x;
178   XlaOp y = q * q;
179   XlaOp p = Select(Lt(abs_x, ScalarLike(x, 2.0)),
180                    EvaluatePolynomial(y, kErfcPCoefficient),
181                    EvaluatePolynomial(y, kErfcRCoefficient));
182   y = z * q * p;
183   return Select(Lt(x, ScalarLike(x, 0)), ScalarLike(x, 2.0) - y, y);
184 }
185 
186 // Compute a polynomial approximation of the error function.
187 //
188 // Precondition: abs(x) <= 1.  Otherwise, use ErfcImpl.
189 //
190 // This follows Cephes's f32 implementation of erf, so it may have errors for
191 // double precision.
ErfImpl(XlaOp x)192 static XlaOp ErfImpl(XlaOp x) {
193   // Coefficients for by erf(f32), from Cephes.
194   //
195   // erf(x) = x P(x^2), 0 < x < 1
196   static std::array<float, 7> kErfTCoefficient{
197       +7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3,
198       -2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1,
199       +1.128379165726710E+0,
200   };
201 
202   return x * EvaluatePolynomial(x * x, kErfTCoefficient);
203 }
204 
Erfc(XlaOp x)205 XlaOp Erfc(XlaOp x) {
206   auto& b = *x.builder();
207   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
208     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erfc", x));
209 
210     // erfc(x) =
211     //   erfc_impl(x)           if x > 1
212     //   1 - erf_impl(x)        otherwise
213     //
214     // Erf(c)Impl don't have enough precision when run with bf16 intermediates
215     // (not surprising!), so upcast to f32 in this case.
216     return DoWithUpcastToF32(x, {BF16}, [](XlaOp x) {
217       return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl(x),
218                     ScalarLike(x, 1) - ErfImpl(x));
219     });
220   });
221 }
222 
Erf(XlaOp x)223 XlaOp Erf(XlaOp x) {
224   auto& b = *x.builder();
225   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
226     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erf", x));
227     // erf(x) =
228     //   erf_impl(x)            if x < 1
229     //   1 - erfc_impl(x)       otherwise
230     //
231     // Erf(c)Impl don't have enough precision when run with bf16 intermediates
232     // (not surprising!), so upcast to f32 in this case.
233     return DoWithUpcastToF32(x, {BF16}, [](XlaOp x) {
234       return Select(Lt(Abs(x), ScalarLike(x, 1)), ErfImpl(x),
235                     ScalarLike(x, 1) - ErfcImpl(x));
236     });
237   });
238 }
239 
240 // Approximation for the inverse error function from
241 //   Giles, M., "Approximating the erfinv function".
242 // The approximation has the form:
243 //   w = -log((1 - x) * (1 + x))
244 //   if ( w < 5 ) {
245 //     w = w - 2.5
246 //     p = sum_{i=1}^n lq[i]*w^i
247 //   } else {
248 //     w = sqrt(w) - 3
249 //     p = sum_{i=1}^n gq[i]*w^i
250 //   }
251 //   return p*x
ErfInv(XlaOp x)252 XlaOp ErfInv(XlaOp x) {
253   constexpr int kDegree = 9;
254   constexpr std::array<float, 9> w_less_than_5_constants = {
255       2.81022636e-08f,  3.43273939e-07f, -3.5233877e-06f,
256       -4.39150654e-06f, 0.00021858087f,  -0.00125372503f,
257       -0.00417768164f,  0.246640727f,    1.50140941f};
258   constexpr std::array<float, 9> w_greater_than_5_constants = {
259       -0.000200214257f, 0.000100950558f, 0.00134934322f,
260       -0.00367342844f,  0.00573950773f,  -0.0076224613f,
261       0.00943887047f,   1.00167406f,     2.83297682f};
262 
263   auto one = ScalarLike(x, 1.0);
264   auto w = -Log((one - x) * (one + x));
265 
266   auto lt = Lt(w, ScalarLike(x, 5.0));
267   auto coefficient = [&](int i) {
268     return Select(lt, FullLike(x, w_less_than_5_constants[i]),
269                   FullLike(x, w_greater_than_5_constants[i]));
270   };
271   w = Select(lt, w - ScalarLike(x, 2.5), Sqrt(w) - ScalarLike(x, 3.0));
272   auto p = coefficient(0);
273   for (int i = 1; i < kDegree; ++i) {
274     p = coefficient(i) + p * w;
275   }
276 
277   // Result modulo edge cases.
278   XlaOp result = p * x;
279 
280   // Handle edge cases, namely erfinv(+/-1) = +/-inf.  (The above computation is
281   // indeterminate, and can give nan or -/+inf.)
282   auto& b = *x.builder();
283   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
284     TF_ASSIGN_OR_RETURN(Shape shape, b.GetShape(x));
285     return Select(Eq(Abs(x), ScalarLike(x, 1)),
286                   x * MaxValue(&b, shape.element_type()), result);
287   });
288 }
289 
namespace {
// Parameters of the Lanczos approximation to the gamma function, shared by
// Lgamma and Digamma.
//
// The coefficients are uniquely determined by the choice of g (kLanczosGamma)
// and n (kLanczosCoefficients.size() + 1). The values below correspond to
// [g, n] = [7, 9]. The candidates [5, 7], [7, 9], [9, 10] and [607/128.0, 15]
// were evaluated, and [7, 9] seemed to be the least sensitive to the quality
// of the log function. In particular, [5, 7] is the only choice where
// -1.5e-5 <= lgamma(2) <= 1.5e-5 for a particularly inaccurate log function.
//
// (Namespace-scope constexpr inside an anonymous namespace already has
// internal linkage, so no `static` is needed.)
constexpr double kLanczosGamma = 7;  // aka g
constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
constexpr std::array<double, 8> kLanczosCoefficients = {
    676.520368121885098567009190444019, -1259.13921672240287047156078755283,
    771.3234287776530788486528258894,   -176.61502916214059906584551354,
    12.507343278686904814458936853,     -0.13857109526572011689554707,
    9.984369578019570859563e-6,         1.50563273514931155834e-7};
}  // namespace
306 
307 // Compute the Lgamma function using Lanczos' approximation from "A Precision
308 // Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
309 // series B. Vol. 1:
310 // lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + A(z)
311 // t(z) = z + kLanczosGamma + 1/2
312 // A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k))
Lgamma(XlaOp input)313 XlaOp Lgamma(XlaOp input) {
314   auto do_it = [](XlaOp input) {
315     XlaOp one_half = ScalarLike(input, 0.5);
316     XlaOp one = ScalarLike(input, 1);
317 
318     XlaOp pi = ScalarLike(input, M_PI);
319     XlaOp log_pi = ScalarLike(input, std::log(M_PI));
320     XlaOp log_sqrt_two_pi =
321         ScalarLike(input, (std::log(2) + std::log(M_PI)) / 2);
322 
323     XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5);
324     XlaOp log_lanczos_gamma_plus_one_half =
325         ScalarLike(input, std::log(kLanczosGamma + 0.5));
326 
327     XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff);
328 
329     // If the input is less than 0.5 use Euler's reflection formula:
330     // gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
331     XlaOp need_to_reflect = Lt(input, one_half);
332     XlaOp z = Select(need_to_reflect, -input, input - one);
333 
334     XlaOp x = base_lanczos_coeff;
335     for (int i = 0; i < kLanczosCoefficients.size(); ++i) {
336       XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]);
337       XlaOp index = ScalarLike(input, i);
338       x = x + lanczos_coefficient / (z + index + one);
339     }
340 
341     // To improve accuracy on platforms with less-precise log implementations,
342     // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on
343     // the device.
344     // log(t) = log(kLanczosGamma + 0.5 + z)
345     //        = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5))
346     XlaOp t = lanczos_gamma_plus_one_half + z;
347     XlaOp log_t = log_lanczos_gamma_plus_one_half +
348                   Log1p(z / lanczos_gamma_plus_one_half);
349 
350     // Compute the final result (modulo reflection).  t(z) may be large, and we
351     // need to be careful not to overflow to infinity in the first term of
352     //
353     //   (z + 1/2) * log(t(z)) - t(z).
354     //
355     // Therefore we compute this as
356     //
357     //   (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
358     //
359     XlaOp log_y = log_sqrt_two_pi + (z + one_half - t / log_t) * log_t + Log(x);
360 
361     // Compute the reflected value, used when x < 0.5:
362     //
363     //   lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
364     //
365     // (The abs is because lgamma is the log of the absolute value of the gamma
366     // function.)
367     //
368     // We have to be careful when computing the final term above. gamma(x) goes
369     // to +/-inf at every integer x < 0, and this is controlled by the
370     // sin(pi * x) term.  The slope is large, so precision is particularly
371     // important.
372     //
373     // Because abs(sin(pi * x)) has period 1, we can equivalently use
374     // abs(sin(pi * frac(x))), where frac(x) is the fractional part of x.  This
375     // is more numerically accurate: It doesn't overflow to inf like pi * x can,
376     // and if x is an integer, it evaluates to 0 exactly, which is significant
377     // because we then take the log of this value, and log(0) is inf.
378     //
379     // We don't have a frac(x) primitive in XLA and computing it is tricky, but
380     // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for
381     // our purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
382     //
383     // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
384     // to 1.  To remedy this, we can use the fact that sin(pi * x) in the domain
385     // [0, 1] is symmetric across the line Y=0.5.
386     //
387     XlaOp abs_input = Abs(input);
388     XlaOp abs_frac_input = abs_input - Floor(abs_input);
389     // Convert values of abs_frac_input > 0.5 to (1 - frac_input) to improve
390     // precision of pi * abs_frac_input for values of abs_frac_input close to 1.
391     XlaOp reduced_frac_input =
392         Select(Gt(abs_frac_input, ScalarLike(abs_frac_input, 0.5)),
393                ScalarLike(abs_frac_input, 1) - abs_frac_input, abs_frac_input);
394     XlaOp reflection_denom = Log(Sin(pi * reduced_frac_input));
395 
396     // Avoid computing -inf - inf, which is nan.  If reflection_denom is +/-inf,
397     // then it "wins" and the result is +/-inf.
398     XlaOp reflection =
399         Select(IsFinite(reflection_denom), log_pi - reflection_denom - log_y,
400                -reflection_denom);
401     XlaOp result = Select(need_to_reflect, reflection, log_y);
402 
403     // lgamma(+/-inf) = +inf.
404     XlaOp inf_bcast = FullLike(input, std::numeric_limits<float>::infinity());
405     return Select(IsInf(input), inf_bcast, result);
406   };
407 
408   auto& b = *input.builder();
409   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
410     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Lgamma", input));
411     // F16 and BF16 don't provide sufficient precision for intermediate results
412     // here (although it's better than you might expect!), so do the
413     // computations in F32.
414     return DoWithUpcastToF32(input, {BF16, F16}, do_it);
415   });
416 }
417 
418 // Compute the Digamma function using Lanczos' approximation from "A Precision
419 // Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
420 // series B. Vol. 1:
421 // digamma(z + 1) = log(t(z)) + A'(z) / A(z) - kLanczosGamma / t(z)
422 // t(z) = z + kLanczosGamma + 1/2
423 // A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k))
424 // A'(z) = sigma(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
Digamma(XlaOp input)425 XlaOp Digamma(XlaOp input) {
426   auto do_it = [](XlaOp input) {
427     XlaOp zero = ScalarLike(input, 0);
428     XlaOp one_half = ScalarLike(input, 0.5);
429     XlaOp one = ScalarLike(input, 1);
430 
431     XlaOp pi = ScalarLike(input, M_PI);
432 
433     XlaOp lanczos_gamma = ScalarLike(input, kLanczosGamma);
434     XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5);
435     XlaOp log_lanczos_gamma_plus_one_half =
436         ScalarLike(input, std::log(kLanczosGamma + 0.5));
437 
438     XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff);
439 
440     // If the input is less than 0.5 use Euler's reflection formula:
441     // digamma(x) = digamma(1 - x) - pi * cot(pi * x)
442     XlaOp need_to_reflect = Lt(input, one_half);
443     XlaOp z = Select(need_to_reflect, -input, input - one);
444 
445     XlaOp num = zero;
446     XlaOp denom = base_lanczos_coeff;
447     for (int i = 0; i < kLanczosCoefficients.size(); ++i) {
448       XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]);
449       XlaOp index = ScalarLike(input, i);
450       num = num - lanczos_coefficient / ((z + index + one) * (z + index + one));
451       denom = denom + lanczos_coefficient / (z + index + one);
452     }
453 
454     // To improve accuracy on platforms with less-precise log implementations,
455     // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on
456     // the device.
457     // log(t) = log(kLanczosGamma + 0.5 + z)
458     //        = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5))
459     XlaOp t = lanczos_gamma_plus_one_half + z;
460     XlaOp log_t = log_lanczos_gamma_plus_one_half +
461                   Log1p(z / lanczos_gamma_plus_one_half);
462 
463     XlaOp y = log_t + num / denom - lanczos_gamma / t;
464 
465     // We need to be careful how we compute cot(pi * input) below: For
466     // near-integral values of `input`, pi * input can lose precision.
467     //
468     // Input is already known to be less than 0.5 (otherwise we don't have to
469     // reflect).  We shift values smaller than -0.5 into the range [-.5, .5] to
470     // increase precision of pi * input and the resulting cotangent.
471     XlaOp reduced_input = input + Abs(Floor(input + ScalarLike(input, 0.5)));
472     XlaOp reflection =
473         y - pi * Cos(pi * reduced_input) / Sin(pi * reduced_input);
474     XlaOp real_result = Select(need_to_reflect, reflection, y);
475 
476     // Digamma has poles at negative integers and zero; return nan for those.
477     return Select(And(Le(input, zero), Eq(input, Floor(input))),
478                   FullLike(input, std::numeric_limits<float>::quiet_NaN()),
479                   real_result);
480   };
481 
482   auto& b = *input.builder();
483   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
484     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Digamma", input));
485     return DoWithUpcastToF32(input, {BF16, F16}, do_it);
486   });
487 }
488 
489 // Implements Banker's rounding: numbers that are equidistant between two
490 // integers are rounded towards even.
RoundToEven(XlaOp x)491 XlaOp RoundToEven(XlaOp x) {
492   auto& b = *x.builder();
493   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
494     // Reject non-real non-fp inputs (What does it even mean to round a complex
495     // number?  Do you round each component equally?  In that case, you should
496     // just ask for that explicitly.)
497     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("RoundToEven", x));
498 
499     auto half = ScalarLike(x, 0.5);
500     auto one = ScalarLike(x, 1.0);
501     auto two = ScalarLike(x, 2.0);
502 
503     auto round_val = Floor(x);
504     auto fraction = x - round_val;
505     auto nearest_even_int = round_val - two * Floor(half * x);
506     auto is_odd = Eq(nearest_even_int, one);
507     return Select(Or(Gt(fraction, half), And(Eq(fraction, half), is_odd)),
508                   round_val + one, round_val);
509   });
510 }
511 
512 // Trigonometric functions.
513 
514 // acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1
515 //           pi                                if x == -1
Acos(XlaOp x)516 XlaOp Acos(XlaOp x) {
517   return Select(Ne(x, FullLike(x, -1)),
518                 ScalarLike(x, 2.0) * Atan2(Sqrt(ScalarLike(x, 1.0) - x * x),
519                                            ScalarLike(x, 1.0) + x),
520                 FullLike(x, M_PI));
521 }
522 
523 // asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
Asin(XlaOp x)524 XlaOp Asin(XlaOp x) {
525   return ScalarLike(x, 2.0) *
526          Atan2(x, ScalarLike(x, 1.0) + Sqrt(ScalarLike(x, 1.0) - x * x));
527 }
528 
Atan(XlaOp x)529 XlaOp Atan(XlaOp x) { return Atan2(x, ScalarLike(x, 1.0)); }
530 
Tan(XlaOp x)531 XlaOp Tan(XlaOp x) { return Sin(x) / Cos(x); }
532 
533 // Hyperbolic trigonometric functions.
534 
535 // acosh(x) = log(x + sqrt(x^2 - 1))
536 //          = log(x + sqrt((x+1)*(x-1)))
Acosh(XlaOp x)537 XlaOp Acosh(XlaOp x) {
538   return Log(x + Sqrt((x + ScalarLike(x, 1.0)) * (x - ScalarLike(x, 1.0))));
539 }
540 
541 // asinh(x) = log(x + sqrt(x^2 + 1))
Asinh(XlaOp x)542 XlaOp Asinh(XlaOp x) { return Log(x + Sqrt(x * x + ScalarLike(x, 1.0))); }
543 
544 // atanh(x) = 0.5 * log((1 + x) / (1 - x))
Atanh(XlaOp x)545 XlaOp Atanh(XlaOp x) {
546   return Log((ScalarLike(x, 1.0) + x) / (ScalarLike(x, 1.0) - x)) *
547          ScalarLike(x, 0.5);
548 }
549 
Cosh(XlaOp x)550 XlaOp Cosh(XlaOp x) { return (Exp(x) + Exp(-x)) * ScalarLike(x, 0.5); }
551 
Sinh(XlaOp x)552 XlaOp Sinh(XlaOp x) { return (Exp(x) - Exp(-x)) * ScalarLike(x, 0.5); }
553 
MaybeConjugate(XlaOp x,bool conjugate)554 XlaOp MaybeConjugate(XlaOp x, bool conjugate) {
555   XlaBuilder* builder = x.builder();
556   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
557     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
558     auto perform_conj =
559         primitive_util::IsComplexType(shape.element_type()) && conjugate;
560     return perform_conj ? Conj(x) : x;
561   });
562 }
563 
NextAfter(XlaOp from,XlaOp to)564 XlaOp NextAfter(XlaOp from, XlaOp to) {
565   auto builder = from.builder();
566   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
567     TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(from));
568     int bitwidth = primitive_util::BitWidth(shape.element_type());
569     auto int_type = primitive_util::UnsignedIntegralTypeForBitWidth(bitwidth);
570     auto from_as_int = BitcastConvertType(from, int_type);
571     auto to_as_int = BitcastConvertType(to, int_type);
572 
573     // The result is NaN if either "from" or "to" are NaN.
574     auto from_is_nan = Ne(from, from);
575     auto to_is_nan = Ne(to, to);
576     auto nan_input = Or(from_is_nan, to_is_nan);
577     auto result_for_nan =
578         Broadcast(ScalarLike(from, std::numeric_limits<double>::quiet_NaN()),
579                   shape.dimensions());
580     result_for_nan = BitcastConvertType(result_for_nan, int_type);
581 
582     // The sign bit is the MSB.
583     const int64 sign_mask = int64{1} << (bitwidth - 1);
584     // Discard the sign bit to make the result non-negative.
585     auto from_abs = And(from_as_int, ScalarLike(from_as_int, ~sign_mask));
586     auto to_abs = And(to_as_int, ScalarLike(to_as_int, ~sign_mask));
587 
588     // When both "from" and "to" are equal, the result is "to".
589     // N.B. It would not make a difference if we chose the result to be "from".
590     auto from_and_to_are_equal = Eq(from_as_int, to_as_int);
591     auto result_for_equal = to_as_int;
592 
593     // When both "from" and "to" are both 0, the result is "to". This ensures we
594     // get a zero signed like "to".
595     auto from_is_zero = Eq(from_abs, ZerosLike(from_abs));
596     auto to_is_zero = Eq(to_abs, ZerosLike(to_abs));
597     auto result_for_both_zero = to_as_int;
598 
599     auto from_sign = And(from_as_int, ScalarLike(from_as_int, sign_mask));
600     auto to_sign = And(to_as_int, ScalarLike(to_as_int, sign_mask));
601 
602     // If from == 0 && to != 0, we need to return the smallest subnormal number
603     // signed like "to".
604     auto result_for_from_zero_to_non_zero =
605         Or(to_sign, ScalarLike(from_as_int, 1));
606 
607     // If the sign of "from" and "to" disagree:
608     // - we need to make the magnitude of "from" smaller so that it is closer to
609     //   zero.
610     //
611     // Otherwise the signs agree:
612     // - "from" with a magnitude larger than "to" means we need to make the
613     //   magnitude smaller.
614     // - "from" with a magnitude smaller than "to" means we need to make the
615     //   magnitude larger.
616     // - "from" with the same magnitude and sign as "to" has already been
617     //   handled.
618     auto signs_disagree = Ne(from_sign, to_sign);
619     auto from_magnitude_larger_than_to = Gt(from_abs, to_abs);
620     auto result_has_smaller_magnitude =
621         Or(from_magnitude_larger_than_to, signs_disagree);
622     auto magnitude_adjustment =
623         Select(result_has_smaller_magnitude,
624                Broadcast(ScalarLike(from_as_int, -1), shape.dimensions()),
625                Broadcast(ScalarLike(from_as_int, 1), shape.dimensions()));
626     auto result = Add(from_as_int, magnitude_adjustment);
627     // Handle from == ±0.
628     result = Select(from_is_zero,
629                     Select(to_is_zero, result_for_both_zero,
630                            result_for_from_zero_to_non_zero),
631                     result);
632     // Handle from == to.
633     result = Select(from_and_to_are_equal, result_for_equal, result);
634     // Handle isnan(from) || isnan(to).
635     result = Select(nan_input, result_for_nan, result);
636 
637     // Cast back to the original type.
638     return BitcastConvertType(result, shape.element_type());
639   });
640 }
641 
642 }  // namespace xla
643