• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/math.h"
17 
18 #include <cmath>
19 
20 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
21 #include "tensorflow/compiler/xla/client/lib/constants.h"
22 #include "tensorflow/compiler/xla/client/lib/loops.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/primitive_util.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 
28 namespace xla {
29 namespace {
30 
31 // Evaluate the polynomial given `x` and coefficients in decreasing order.
32 template <typename FP>
EvaluatePolynomial(XlaOp x,absl::Span<const FP> coefficients)33 XlaOp EvaluatePolynomial(XlaOp x, absl::Span<const FP> coefficients) {
34   static_assert(std::is_floating_point<FP>::value,
35                 "Template-argument 'FP' must be a floating-point type");
36   XlaOp poly = ScalarLike(x, 0.0);
37   for (FP c : coefficients) {
38     poly = poly * x + ScalarLike(x, c);
39   }
40   return poly;
41 }
42 
43 // Evaluate the chebyshev polynomial given `x` and coefficients in decreasing
44 // order.
45 template <typename FP>
EvaluateChebyshevPolynomial(XlaOp x,absl::Span<const FP> coefficients)46 XlaOp EvaluateChebyshevPolynomial(XlaOp x, absl::Span<const FP> coefficients) {
47   static_assert(std::is_floating_point<FP>::value,
48                 "Template-argument 'FP' must be a floating-point type");
49   XlaOp b0 = ScalarLike(x, 0.0);
50   XlaOp b1 = ScalarLike(x, 0.0);
51   XlaOp b2 = ScalarLike(x, 0.0);
52   for (FP c : coefficients) {
53     b2 = b1;
54     b1 = b0;
55     b0 = x * b1 - b2 + ScalarLike(x, c);
56   }
57   return ScalarLike(x, 0.5) * (b0 - b2);
58 }
59 
60 }  // namespace
61 
62 // Returns operation(operand), except if `operand` is one of the types in
63 // upcast_types, in which case first converts it to F32, and then converts the
64 // result down to the original type.
DoWithUpcastToF32(XlaOp operand,absl::Span<const PrimitiveType> upcast_types,const std::function<XlaOp (XlaOp)> & operation)65 static XlaOp DoWithUpcastToF32(XlaOp operand,
66                                absl::Span<const PrimitiveType> upcast_types,
67                                const std::function<XlaOp(XlaOp)>& operation) {
68   auto& b = *operand.builder();
69   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
70     TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
71     PrimitiveType elem_ty = shape.element_type();
72     bool needs_upcast = absl::c_linear_search(upcast_types, elem_ty);
73 
74     if (needs_upcast) {
75       operand = ConvertElementType(operand, F32);
76     }
77     XlaOp result = operation(operand);
78     if (needs_upcast) {
79       result = ConvertElementType(result, elem_ty);
80     }
81     return result;
82   });
83 }
84 
85 // TODO(jlebar): Use this function in more places in this file to restrict the
86 // domain of other functions.
EnsureOperandIsRealFp(absl::string_view op_name,XlaOp operand)87 static Status EnsureOperandIsRealFp(absl::string_view op_name, XlaOp operand) {
88   auto& b = *operand.builder();
89   TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
90   auto elem_ty = shape.element_type();
91   if (!primitive_util::IsFloatingPointType(elem_ty)) {
92     return InvalidArgument(
93         "Operands to %s must be real-valued floating-point, but got %s",
94         op_name, PrimitiveType_Name(elem_ty));
95   }
96   return Status::OK();
97 }
98 
IsPosInf(XlaOp operand)99 XlaOp IsPosInf(XlaOp operand) {
100   auto& b = *operand.builder();
101   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
102     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsPosInf", operand));
103     TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
104     // Note that this is only correct for floating-point types.  If we wanted it
105     // to be correct for all types, we'd need to Gt(MaxFiniteValue).
106     return Eq(operand, MaxValue(&b, shape.element_type()));
107   });
108 }
109 
IsNegInf(XlaOp operand)110 XlaOp IsNegInf(XlaOp operand) {
111   auto& b = *operand.builder();
112   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
113     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNegInf", operand));
114     TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
115     // Note that this is only correct for floating-point types.  If we wanted it
116     // to be correct for all types, we'd need to Lt(MinFiniteValue).
117     return Eq(operand, MinValue(&b, shape.element_type()));
118   });
119 }
120 
IsInf(XlaOp operand)121 XlaOp IsInf(XlaOp operand) {
122   auto& b = *operand.builder();
123   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
124     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsInf", operand));
125     return IsPosInf(Abs(operand));
126   });
127 }
128 
IsNan(XlaOp operand)129 XlaOp IsNan(XlaOp operand) {
130   auto& b = *operand.builder();
131   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
132     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNan", operand));
133     return Ne(operand, operand);
134   });
135 }
136 
IsNegZero(XlaOp operand)137 XlaOp IsNegZero(XlaOp operand) {
138   auto& b = *operand.builder();
139   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
140     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNegZero", operand));
141     TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
142 
143     // The bitwise representation of -0 in bfloat16 and IEEE 754 is 0x80...0
144     // (sign bit on, all other bits off).
145     switch (shape.element_type()) {
146       case F64:
147         return Eq(BitcastConvertType(operand, U64),
148                   ConstantR0WithType(&b, U64, uint64{1} << 63));
149       case F32:
150         return Eq(BitcastConvertType(operand, U32),
151                   ConstantR0WithType(&b, U32, uint32{1} << 31));
152       case F16:
153       case BF16:
154         // Not all XLA backends handle U16 well, so we convert to F32/U32.
155         // TODO(jlebar): It would be nice if we could stay in (B)F16/U16 for
156         // backends that *do* support it.
157         return Eq(BitcastConvertType(ConvertElementType(operand, F32), U32),
158                   ConstantR0WithType(&b, U32, uint32{1} << 31));
159       default:
160         LOG(FATAL) << "Expected real fp type.";
161     }
162   });
163 }
164 
Square(XlaOp operand)165 XlaOp Square(XlaOp operand) { return operand * operand; }
166 
Reciprocal(XlaOp operand)167 XlaOp Reciprocal(XlaOp operand) { return ScalarLike(operand, 1.0) / operand; }
168 
169 // Computes an approximation of the error function complement (1 - erf(x)).
170 //
171 // Precondition: abs(x) >= 1.  Otherwise, use ErfImpl.
172 //
173 // This follows Cephes's f32 implementation of erfc.
ErfcImpl32(XlaOp x)174 static XlaOp ErfcImpl32(XlaOp x) {
175   // Coefficients for erfc(f32), from Cephes.
176   const double kMaxlog = 88.72283905206835;
177   // erfc(x) = exp(-x^2) P(1/x^2), 1 < x < 2
178   static const std::array<float, 9> kErfcPCoefficient{
179       +2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1,
180       -5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1,
181       +3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1,
182   };
183   // erfc(x) = exp(-x^2) R(1/x^2), 2 <= x < kMaxlog
184   static const std::array<float, 8> kErfcRCoefficient{
185       -1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0,
186       +2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1,
187       -2.820767439740514E-1, +5.641895067754075E-1,
188   };
189   XlaOp abs_x = Abs(x);
190   XlaOp z = Exp(-x * x);
191   XlaOp q = ScalarLike(x, 1) / abs_x;
192   XlaOp y = q * q;
193   XlaOp p = Select(Lt(abs_x, ScalarLike(x, 2.0)),
194                    EvaluatePolynomial<float>(y, kErfcPCoefficient),
195                    EvaluatePolynomial<float>(y, kErfcRCoefficient));
196   y = z * q * p;
197   XlaOp y_clamp = Select(Lt(z, ScalarLike(x, -kMaxlog)), ScalarLike(x, 0), y);
198   return Select(Lt(x, ScalarLike(x, 0)), ScalarLike(x, 2.0) - y_clamp, y_clamp);
199 }
200 
201 // Compute a polynomial approximation of the error function.
202 //
203 // Precondition: abs(x) <= 1.  Otherwise, use ErfcImpl.
204 //
205 // This follows Cephes's f32 implementation of erf.
ErfImpl32Cephes(XlaOp x)206 static XlaOp ErfImpl32Cephes(XlaOp x) {
207   // Coefficients for by erf(f32), from Cephes.
208   //
209   // erf(x) = x P(x^2), 0 < x < 1
210   static const std::array<float, 7> kErfTCoefficient{
211       +7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3,
212       -2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1,
213       +1.128379165726710E+0,
214   };
215   return x * EvaluatePolynomial<float>(x * x, kErfTCoefficient);
216 }
217 
ErfcImpl64(XlaOp x)218 static XlaOp ErfcImpl64(XlaOp x) {
219   // Coefficients for erfc(f64), from Cephes.
220   const double kMaxlog = 7.09782712893383996843E2;
221   // erfc(x) = exp(-x^2) P(|x|) / Q(|x|), 1 < x < 8
222   static const std::array<double, 9> kErfcPCoefficient{
223       2.46196981473530512524E-10, 5.64189564831068821977E-1,
224       7.46321056442269912687E0,   4.86371970985681366614E1,
225       1.96520832956077098242E2,   5.26445194995477358631E2,
226       9.34528527171957607540E2,   1.02755188689515710272E3,
227       5.57535335369399327526E2};
228   static const std::array<double, 9> kErfcQCoefficient{
229       1.00000000000000000000E0, 1.32281951154744992508E1,
230       8.67072140885989742329E1, 3.54937778887819891062E2,
231       9.75708501743205489753E2, 1.82390916687909736289E3,
232       2.24633760818710981792E3, 1.65666309194161350182E3,
233       5.57535340817727675546E2};
234 
235   // erfc(x) = exp(-x^2) R(|x|) / S(|x|), 8 <= x < kMaxlog
236   static const std::array<double, 6> kErfcRCoefficient{
237       5.64189583547755073984E-1, 1.27536670759978104416E0,
238       5.01905042251180477414E0,  6.16021097993053585195E0,
239       7.40974269950448939160E0,  2.97886665372100240670E0};
240   static const std::array<double, 7> kErfcSCoefficient{
241       1.00000000000000000000E0, 2.26052863220117276590E0,
242       9.39603524938001434673E0, 1.20489539808096656605E1,
243       1.70814450747565897222E1, 9.60896809063285878198E0,
244       3.36907645100081516050E0};
245 
246   XlaOp z = -x * x;
247   XlaOp abs_x = Abs(x);
248   XlaOp y =
249       Select(Lt(abs_x, ScalarLike(x, 8.0)),
250              Exp(z) * EvaluatePolynomial<double>(abs_x, kErfcPCoefficient) /
251                  EvaluatePolynomial<double>(abs_x, kErfcQCoefficient),
252              Exp(z) * EvaluatePolynomial<double>(abs_x, kErfcRCoefficient) /
253                  EvaluatePolynomial<double>(abs_x, kErfcSCoefficient));
254   XlaOp y_clamp = Select(Lt(z, ScalarLike(x, -kMaxlog)), ScalarLike(x, 0), y);
255   return Select(Lt(x, ScalarLike(x, 0)), ScalarLike(x, 2.0) - y_clamp, y_clamp);
256 }
257 
258 // Compute a polynomial approximation of the error function.
259 //
260 // Precondition: abs(x) <= 1.  Otherwise, use ErfcImpl.
ErfImpl64(XlaOp x)261 static XlaOp ErfImpl64(XlaOp x) {
262   // Coefficients for by erf(f64), from Cephes.
263   //
264   // erf(x) = x T(x^2) / U(x^2), 0 < x < 1
265   static std::array<double, 5> kErfTCoefficient{
266       9.60497373987051638749E0, 9.00260197203842689217E1,
267       2.23200534594684319226E3, 7.00332514112805075473E3,
268       5.55923013010394962768E4};
269   static std::array<double, 6> kErfUCoefficient{
270       1.00000000000000000000E0, 3.35617141647503099647E1,
271       5.21357949780152679795E2, 4.59432382970980127987E3,
272       2.26290000613890934246E4, 4.92673942608635921086E4};
273   XlaOp z = x * x;
274   return x * EvaluatePolynomial<double>(z, kErfTCoefficient) /
275          EvaluatePolynomial<double>(z, kErfUCoefficient);
276 }
277 
Erfc(XlaOp x)278 XlaOp Erfc(XlaOp x) {
279   auto& b = *x.builder();
280   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
281     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erfc", x));
282     TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x));
283     // erfc(x) =
284     //   erfc_impl(x)           if x > 1
285     //   1 - erf_impl(x)        otherwise
286     if (shape.element_type() == F64) {
287       return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl64(x),
288                     ScalarLike(x, 1) - ErfImpl64(x));
289     }
290     // Erf(c)Impl don't have enough precision when run with bf16 intermediates
291     // (not surprising!), so upcast to f32 in this case.
292     return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) {
293       return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl32(x),
294                     ScalarLike(x, 1) - ErfImpl32Cephes(x));
295     });
296   });
297 }
298 
299 // Compute a polynomial approximation of the error function.
300 // This is the same approximation used by Eigen.
ErfImpl32(XlaOp x)301 static XlaOp ErfImpl32(XlaOp x) {
302   static const std::array<float, 7> kAlpha{
303       -2.72614225801306e-10f, 2.77068142495902e-08f,  -2.10102402082508e-06f,
304       -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
305       -1.60960333262415e-02f,
306   };
307 
308   static const std::array<float, 5> kBeta{
309       -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
310       -7.37332916720468e-03f, -1.42647390514189e-02f,
311   };
312 
313   x = Clamp(ScalarLike(x, -4.f), x, ScalarLike(x, 4.f));
314   auto x2 = x * x;
315   return x * EvaluatePolynomial<float>(x2, kAlpha) /
316          EvaluatePolynomial<float>(x2, kBeta);
317 }
318 
Erf(XlaOp x)319 XlaOp Erf(XlaOp x) {
320   auto& b = *x.builder();
321   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
322     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erf", x));
323     TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x));
324     // erf(x) =
325     //   erf_impl(x)            if x < 1
326     //   1 - erfc_impl(x)       otherwise
327     if (shape.element_type() == F64) {
328       return Select(Lt(Abs(x), ScalarLike(x, 1)), ErfImpl64(x),
329                     ScalarLike(x, 1) - ErfcImpl64(x));
330     }
331     // Erf(c)Impl don't have enough precision when run with bf16 intermediates
332     // (not surprising!), so upcast to f32 in this case.
333     return DoWithUpcastToF32(x, {BF16, F16},
334                              [](XlaOp x) { return ErfImpl32(x); });
335   });
336 }
337 
338 namespace {
339 
340 // Approximation for the inverse error function from
341 //   Giles, M., "Approximating the erfinv function".
342 // The approximation has the form:
343 //   w = -log((1 - x) * (1 + x))
344 //   if ( w < 5 ) {
345 //     w = w - 2.5
346 //     p = sum_{i=1}^n lq[i]*w^i
347 //   } else {
348 //     w = sqrt(w) - 3
349 //     p = sum_{i=1}^n gq[i]*w^i
350 //   }
351 //   return p*x
ErfInv32(XlaOp x)352 XlaOp ErfInv32(XlaOp x) {
353   constexpr int kDegree = 9;
354   constexpr std::array<float, 9> w_less_than_5_constants = {
355       2.81022636e-08f,  3.43273939e-07f, -3.5233877e-06f,
356       -4.39150654e-06f, 0.00021858087f,  -0.00125372503f,
357       -0.00417768164f,  0.246640727f,    1.50140941f};
358   constexpr std::array<float, 9> w_greater_than_5_constants = {
359       -0.000200214257f, 0.000100950558f, 0.00134934322f,
360       -0.00367342844f,  0.00573950773f,  -0.0076224613f,
361       0.00943887047f,   1.00167406f,     2.83297682f};
362 
363   // Compute logarithm of (1+arg) using log1p(arg) which is more precise than
364   // log(1+arg) when arg is close to zero. For more details, see
365   // https://en.cppreference.com/w/cpp/numeric/math/log1p
366   auto w = -Log1p(-x * x);
367 
368   auto lt = Lt(w, ScalarLike(x, 5.0));
369   auto coefficient = [&](int i) {
370     return Select(lt, FullLike(x, w_less_than_5_constants[i]),
371                   FullLike(x, w_greater_than_5_constants[i]));
372   };
373   w = Select(lt, w - ScalarLike(x, 2.5), Sqrt(w) - ScalarLike(x, 3.0));
374   auto p = coefficient(0);
375   for (int i = 1; i < kDegree; ++i) {
376     p = coefficient(i) + p * w;
377   }
378 
379   // Result modulo edge cases.
380   XlaOp result = p * x;
381 
382   // Handle edge cases, namely erfinv(+/-1) = +/-inf.  (The above computation is
383   // indeterminate, and can give nan or -/+inf.)
384   auto& b = *x.builder();
385   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
386     TF_ASSIGN_OR_RETURN(Shape shape, b.GetShape(x));
387     return Select(Eq(Abs(x), ScalarLike(x, 1)),
388                   x * MaxValue(&b, shape.element_type()), result);
389   });
390 }
391 
ErfInv64(XlaOp x)392 XlaOp ErfInv64(XlaOp x) {
393   constexpr std::array<double, 23> w_less_than_6_25_constants = {
394       -3.6444120640178196996e-21, -1.685059138182016589e-19,
395       1.2858480715256400167e-18,  1.115787767802518096e-17,
396       -1.333171662854620906e-16,  2.0972767875968561637e-17,
397       6.6376381343583238325e-15,  -4.0545662729752068639e-14,
398       -8.1519341976054721522e-14, 2.6335093153082322977e-12,
399       -1.2975133253453532498e-11, -5.4154120542946279317e-11,
400       1.051212273321532285e-09,   -4.1126339803469836976e-09,
401       -2.9070369957882005086e-08, 4.2347877827932403518e-07,
402       -1.3654692000834678645e-06, -1.3882523362786468719e-05,
403       0.0001867342080340571352,   -0.00074070253416626697512,
404       -0.0060336708714301490533,  0.24015818242558961693,
405       1.6536545626831027356};
406   constexpr std::array<double, 19> w_less_than_16_constants = {
407       2.2137376921775787049e-09,  9.0756561938885390979e-08,
408       -2.7517406297064545428e-07, 1.8239629214389227755e-08,
409       1.5027403968909827627e-06,  -4.013867526981545969e-06,
410       2.9234449089955446044e-06,  1.2475304481671778723e-05,
411       -4.7318229009055733981e-05, 6.8284851459573175448e-05,
412       2.4031110387097893999e-05,  -0.0003550375203628474796,
413       0.00095328937973738049703,  -0.0016882755560235047313,
414       0.0024914420961078508066,   -0.0037512085075692412107,
415       0.005370914553590063617,    1.0052589676941592334,
416       3.0838856104922207635,
417   };
418   constexpr std::array<double, 17> w_greater_than_16_constants = {
419       -2.7109920616438573243e-11, -2.5556418169965252055e-10,
420       1.5076572693500548083e-09,  -3.7894654401267369937e-09,
421       7.6157012080783393804e-09,  -1.4960026627149240478e-08,
422       2.9147953450901080826e-08,  -6.7711997758452339498e-08,
423       2.2900482228026654717e-07,  -9.9298272942317002539e-07,
424       4.5260625972231537039e-06,  -1.9681778105531670567e-05,
425       7.5995277030017761139e-05,  -0.00021503011930044477347,
426       -0.00013871931833623122026, 1.0103004648645343977,
427       4.8499064014085844221,
428   };
429   // Compute logarithm of (1+arg) using log1p(arg) which is more precise than
430   // log(1+arg) when arg is close to zero. For more details, see
431   // https://en.cppreference.com/w/cpp/numeric/math/log1p
432   auto w = -Log1p(-x * x);
433 
434   auto lt_6_25 = Lt(w, ScalarLike(x, 6.25));
435   auto lt_16 = Lt(w, ScalarLike(x, 16));
436   auto coefficient = [&](int i) {
437     auto c = FullLike(x, w_less_than_6_25_constants[i]);
438     if (i < 19) {
439       c = Select(lt_6_25, c, FullLike(x, w_less_than_16_constants[i]));
440     }
441     if (i < 17) {
442       c = Select(lt_16, c, FullLike(x, w_greater_than_16_constants[i]));
443     }
444     return c;
445   };
446   auto sqrt_w = Sqrt(w);
447   w = Select(lt_6_25, w - ScalarLike(x, 3.125),
448              sqrt_w - Select(lt_16, ScalarLike(x, 3.25), ScalarLike(x, 5.0)));
449   auto p = coefficient(0);
450   for (int i = 1; i < 17; ++i) {
451     p = coefficient(i) + p * w;
452   }
453   for (int i = 17; i < 19; ++i) {
454     p = Select(lt_16, coefficient(i) + p * w, p);
455   }
456   for (int i = 19; i < 23; ++i) {
457     p = Select(lt_6_25, coefficient(i) + p * w, p);
458   }
459   // Result modulo edge cases.
460   XlaOp result = p * x;
461 
462   // Handle edge cases, namely erfinv(+/-1) = +/-inf.  (The above computation is
463   // indeterminate, and can give nan or -/+inf.)
464   auto& b = *x.builder();
465   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
466     TF_ASSIGN_OR_RETURN(Shape shape, b.GetShape(x));
467     return Select(Eq(Abs(x), ScalarLike(x, 1)),
468                   x * MaxValue(&b, shape.element_type()), result);
469   });
470 }
471 
472 }  // namespace
473 
ErfInv(XlaOp x)474 XlaOp ErfInv(XlaOp x) {
475   auto& b = *x.builder();
476   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
477     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("ErfInv", x));
478     TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x));
479     if (shape.element_type() == F64) {
480       return ErfInv64(x);
481     }
482     return DoWithUpcastToF32(x, {BF16, F16},
483                              [](XlaOp x) { return ErfInv32(x); });
484   });
485 }
486 
namespace {
// Coefficients for the Lanczos approximation of the gamma function. The
// coefficients are uniquely determined by the choice of g and n (kLanczosGamma
// and kLanczosCoefficients.size() + 1). The coefficients below correspond to
// [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and [7,
// 9] seemed to be the least sensitive to the quality of the log function. In
// particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
// for a particularly inaccurate log function.
// NOTE(review): the last sentence names [5, 7] although [7, 9] is the set in
// use; this may be a typo carried over from the original note -- confirm.

// g in the Lanczos formula.
constexpr double kLanczosGamma = 7;
// Constant term of the Lanczos series A(z).
constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
// Numerators of the 1 / (z + k) terms of A(z); entry i belongs to k = i + 1.
constexpr std::array<double, 8> kLanczosCoefficients = {
    676.520368121885098567009190444019,
    -1259.13921672240287047156078755283,
    771.3234287776530788486528258894,
    -176.61502916214059906584551354,
    12.507343278686904814458936853,
    -0.13857109526572011689554707,
    9.984369578019570859563e-6,
    1.50563273514931155834e-7};
}  // namespace
503 
504 // Compute the Lgamma function using Lanczos' approximation from "A Precision
505 // Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
506 // series B. Vol. 1:
507 // lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + A(z)
508 // t(z) = z + kLanczosGamma + 1/2
509 // A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k))
Lgamma(XlaOp input)510 XlaOp Lgamma(XlaOp input) {
511   auto do_it = [](XlaOp input) {
512     XlaOp one_half = ScalarLike(input, 0.5);
513     XlaOp one = ScalarLike(input, 1);
514 
515     XlaOp pi = ScalarLike(input, M_PI);
516     XlaOp log_pi = ScalarLike(input, std::log(M_PI));
517     XlaOp log_sqrt_two_pi =
518         ScalarLike(input, (std::log(2) + std::log(M_PI)) / 2);
519 
520     XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5);
521     XlaOp log_lanczos_gamma_plus_one_half =
522         ScalarLike(input, std::log(kLanczosGamma + 0.5));
523 
524     XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff);
525 
526     // If the input is less than 0.5 use Euler's reflection formula:
527     // gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
528     XlaOp need_to_reflect = Lt(input, one_half);
529     XlaOp z = Select(need_to_reflect, -input, input - one);
530 
531     XlaOp x = base_lanczos_coeff;
532     for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
533       XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]);
534       XlaOp index = ScalarLike(input, i);
535       x = x + lanczos_coefficient / (z + index + one);
536     }
537 
538     // To improve accuracy on platforms with less-precise log implementations,
539     // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on
540     // the device.
541     // log(t) = log(kLanczosGamma + 0.5 + z)
542     //        = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5))
543     XlaOp t = lanczos_gamma_plus_one_half + z;
544     XlaOp log_t = log_lanczos_gamma_plus_one_half +
545                   Log1p(z / lanczos_gamma_plus_one_half);
546 
547     // Compute the final result (modulo reflection).  t(z) may be large, and we
548     // need to be careful not to overflow to infinity in the first term of
549     //
550     //   (z + 1/2) * log(t(z)) - t(z).
551     //
552     // Therefore we compute this as
553     //
554     //   (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
555     //
556     XlaOp log_y = log_sqrt_two_pi + (z + one_half - t / log_t) * log_t + Log(x);
557 
558     // Compute the reflected value, used when x < 0.5:
559     //
560     //   lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
561     //
562     // (The abs is because lgamma is the log of the absolute value of the gamma
563     // function.)
564     //
565     // We have to be careful when computing the final term above. gamma(x) goes
566     // to +/-inf at every integer x < 0, and this is controlled by the
567     // sin(pi * x) term.  The slope is large, so precision is particularly
568     // important.
569     //
570     // Because abs(sin(pi * x)) has period 1, we can equivalently use
571     // abs(sin(pi * frac(x))), where frac(x) is the fractional part of x.  This
572     // is more numerically accurate: It doesn't overflow to inf like pi * x can,
573     // and if x is an integer, it evaluates to 0 exactly, which is significant
574     // because we then take the log of this value, and log(0) is inf.
575     //
576     // We don't have a frac(x) primitive in XLA and computing it is tricky, but
577     // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for
578     // our purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
579     //
580     // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
581     // to 1.  To remedy this, we can use the fact that sin(pi * x) in the domain
582     // [0, 1] is symmetric across the line Y=0.5.
583     //
584     XlaOp abs_input = Abs(input);
585     XlaOp abs_frac_input = abs_input - Floor(abs_input);
586     // Convert values of abs_frac_input > 0.5 to (1 - frac_input) to improve
587     // precision of pi * abs_frac_input for values of abs_frac_input close to 1.
588     XlaOp reduced_frac_input =
589         Select(Gt(abs_frac_input, ScalarLike(abs_frac_input, 0.5)),
590                ScalarLike(abs_frac_input, 1) - abs_frac_input, abs_frac_input);
591     XlaOp reflection_denom = Log(Sin(pi * reduced_frac_input));
592 
593     // Avoid computing -inf - inf, which is nan.  If reflection_denom is +/-inf,
594     // then it "wins" and the result is +/-inf.
595     XlaOp reflection =
596         Select(IsFinite(reflection_denom), log_pi - reflection_denom - log_y,
597                -reflection_denom);
598     XlaOp result = Select(need_to_reflect, reflection, log_y);
599 
600     // lgamma(+/-inf) = +inf.
601     XlaOp inf_bcast = FullLike(input, std::numeric_limits<float>::infinity());
602     return Select(IsInf(input), inf_bcast, result);
603   };
604 
605   auto& b = *input.builder();
606   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
607     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Lgamma", input));
608     // F16 and BF16 don't provide sufficient precision for intermediate results
609     // here (although it's better than you might expect!), so do the
610     // computations in F32.
611     return DoWithUpcastToF32(input, {BF16, F16}, do_it);
612   });
613 }
614 
615 // Computes an approximation of the lbeta function which is equivalent to
616 // log(abs(Beta(a, b))) but avoids overflow by computing it with lgamma.
Lbeta(XlaOp a,XlaOp b)617 static XlaOp Lbeta(XlaOp a, XlaOp b) {
618   // Beta(a, b) can be computed using Gamma as per
619   // http://dlmf.nist.gov/5.12.E1 as follows:
620   //   Beta(a, b) = (Gamma(a) * Gamma(b)) / Gamma(a + b)
621   //
622   // To avoid overflow, we compute in the log domain.
623   //
624   // As per http://dlmf.nist.gov/4.8.E2 we can transform:
625   //   Log(a * b)
626   // into:
627   //   Log(a) + Log(b)
628   //
629   // Likewise, per https://dlmf.nist.gov/4.8.E4, we can turn:
630   //   Log(a - b)
631   // into:
632   //   Log(a) - Log(b)
633   //
634   // This means that we can compute Log(Beta(a, b)) by:
635   //   Log(Gamma(a)) + Log(Gamma(b)) - Log(Gamma(a + b))
636   return Lgamma(a) + Lgamma(b) - Lgamma(a + b);
637 }
638 
639 // Compute the Digamma function using Lanczos' approximation from "A Precision
640 // Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
641 // series B. Vol. 1:
642 // digamma(z + 1) = log(t(z)) + A'(z) / A(z) - kLanczosGamma / t(z)
643 // t(z) = z + kLanczosGamma + 1/2
644 // A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k))
645 // A'(z) = sigma(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
Digamma(XlaOp input)646 XlaOp Digamma(XlaOp input) {
647   auto do_it = [](XlaOp input) {
648     XlaOp zero = ScalarLike(input, 0);
649     XlaOp one_half = ScalarLike(input, 0.5);
650     XlaOp one = ScalarLike(input, 1);
651 
652     XlaOp pi = ScalarLike(input, M_PI);
653 
654     XlaOp lanczos_gamma = ScalarLike(input, kLanczosGamma);
655     XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5);
656     XlaOp log_lanczos_gamma_plus_one_half =
657         ScalarLike(input, std::log(kLanczosGamma + 0.5));
658 
659     XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff);
660 
661     // If the input is less than 0.5 use Euler's reflection formula:
662     // digamma(x) = digamma(1 - x) - pi * cot(pi * x)
663     XlaOp need_to_reflect = Lt(input, one_half);
664     XlaOp z = Select(need_to_reflect, -input, input - one);
665 
666     XlaOp num = zero;
667     XlaOp denom = base_lanczos_coeff;
668     for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
669       XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]);
670       XlaOp index = ScalarLike(input, i);
671       num = num - lanczos_coefficient / ((z + index + one) * (z + index + one));
672       denom = denom + lanczos_coefficient / (z + index + one);
673     }
674 
675     // To improve accuracy on platforms with less-precise log implementations,
676     // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on
677     // the device.
678     // log(t) = log(kLanczosGamma + 0.5 + z)
679     //        = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5))
680     XlaOp t = lanczos_gamma_plus_one_half + z;
681     XlaOp log_t = log_lanczos_gamma_plus_one_half +
682                   Log1p(z / lanczos_gamma_plus_one_half);
683 
684     XlaOp y = log_t + num / denom - lanczos_gamma / t;
685 
686     // We need to be careful how we compute cot(pi * input) below: For
687     // near-integral values of `input`, pi * input can lose precision.
688     //
689     // Input is already known to be less than 0.5 (otherwise we don't have to
690     // reflect).  We shift values smaller than -0.5 into the range [-.5, .5] to
691     // increase precision of pi * input and the resulting cotangent.
692     XlaOp reduced_input = input + Abs(Floor(input + ScalarLike(input, 0.5)));
693     XlaOp reflection =
694         y - pi * Cos(pi * reduced_input) / Sin(pi * reduced_input);
695     XlaOp real_result = Select(need_to_reflect, reflection, y);
696 
697     // Digamma has poles at negative integers and zero; return nan for those.
698     return Select(And(Le(input, zero), Eq(input, Floor(input))),
699                   FullLike(input, std::numeric_limits<float>::quiet_NaN()),
700                   real_result);
701   };
702 
703   auto& b = *input.builder();
704   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
705     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Digamma", input));
706     return DoWithUpcastToF32(input, {BF16, F16}, do_it);
707   });
708 }
709 
710 // Incomplete gamma functions
711 
712 namespace {
713 
// Selects which quantity the incomplete-gamma helpers in this namespace
// compute.
enum kIgammaMode {
  // The value of the function itself.
  VALUE = 0,
  // The derivative with respect to `a`.
  DERIVATIVE = 1,
  // The derivative of a Gamma(a, 1) sample with respect to `a`.
  SAMPLE_DERIVATIVE = 2,
};
715 
716 // Helper function for computing Igamma using a power series.
717 template <kIgammaMode mode>
IgammaSeries(XlaOp ax,XlaOp x,XlaOp a,XlaOp enabled,xla::PrimitiveType type)718 XlaOp IgammaSeries(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled,
719                    xla::PrimitiveType type) {
720   // vals: (enabled, r, c, ans, x)
721   // 'enabled' is a predication mask that says for which elements we should
722   // execute the loop body. Disabled elements have no effect in the loop body.
723   // TODO(phawkins): in general this isn't an optimal implementation on any
724   // backend. For example, on GPU, we should probably vectorize to the warp
725   // size, and then run independent loops for each warp's worth of
726   // data.
727   auto cond = [&](absl::Span<const XlaOp> vals,
728                   XlaBuilder* builder) -> StatusOr<XlaOp> {
729     XlaOp enabled = vals[0];
730     return Any(enabled);
731   };
732   auto body = [&](absl::Span<const XlaOp> vals,
733                   XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
734     XlaOp enabled = vals[0];
735     XlaOp r = vals[1];
736     XlaOp c = vals[2];
737     XlaOp ans = vals[3];
738     XlaOp x = vals[4];
739     XlaOp dc_da = vals[5];
740     XlaOp dans_da = vals[6];
741 
742     r = r + ScalarLike(r, 1);
743     dc_da = dc_da * (x / r) + (ScalarLike(r, -1) * c * x) / (r * r);
744     dans_da = dans_da + dc_da;
745     c = c * (x / r);
746     ans = ans + c;
747     XlaOp conditional;
748     if (mode == VALUE) {
749       conditional = And(enabled, Gt(c / ans, Epsilon(builder, type)));
750     } else {
751       conditional =
752           And(enabled, Gt(Abs(dc_da / dans_da), Epsilon(builder, type)));
753     }
754 
755     return std::vector<XlaOp>{
756         conditional,
757         Select(enabled, r, vals[1]),
758         Select(enabled, c, vals[2]),
759         Select(enabled, ans, vals[3]),
760         Select(enabled, x, vals[4]),
761         Select(enabled, dc_da, vals[5]),
762         Select(enabled, dans_da, vals[6]),
763     };
764   };
765   auto& b = *ax.builder();
766   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
767     std::vector<XlaOp> vals = {
768         enabled,        a, FullLike(a, 1), FullLike(a, 1), x, FullLike(a, 0),
769         FullLike(a, 0),
770     };
771 
772     TF_ASSIGN_OR_RETURN(vals, WhileLoopHelper(cond, body, vals, "igamma", &b));
773     XlaOp ans = vals[3];
774     XlaOp dans_da = vals[6];
775     if (mode == VALUE) {
776       return (ans * ax) / a;
777     }
778 
779     XlaOp dlogax_da = Log(x) - Digamma(a + ScalarLike(a, 1));
780 
781     switch (mode) {
782       case DERIVATIVE:
783         return ax * (ans * dlogax_da + dans_da) / a;
784       case SAMPLE_DERIVATIVE:
785       default:
786         return -(dans_da + ans * dlogax_da) * x / a;
787     }
788   });
789 }
790 
791 // Helper function for computing Igammac using a continued fraction.
792 template <kIgammaMode mode>
IgammacContinuedFraction(XlaOp ax,XlaOp x,XlaOp a,XlaOp enabled,xla::PrimitiveType type)793 XlaOp IgammacContinuedFraction(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled,
794                                xla::PrimitiveType type) {
795   // vals: enabled, ans, t, y, z, c, pkm1, qkm1, pkm2, qkm2
796   auto cond = [&](absl::Span<const XlaOp> vals,
797                   XlaBuilder* builder) -> StatusOr<XlaOp> {
798     XlaOp enabled = vals[0];
799     XlaOp c = vals[5];
800     return And(Lt(c, ScalarLike(c, 2000)), Any(enabled));
801   };
802   auto body = [&](absl::Span<const XlaOp> vals,
803                   XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
804     XlaOp enabled = vals[0];
805     XlaOp ans = vals[1];
806     XlaOp t = vals[2];
807     XlaOp y = vals[3];
808     XlaOp z = vals[4];
809     XlaOp c = vals[5];
810     XlaOp pkm1 = vals[6];
811     XlaOp qkm1 = vals[7];
812     XlaOp pkm2 = vals[8];
813     XlaOp qkm2 = vals[9];
814 
815     XlaOp dpkm2_da = vals[10];
816     XlaOp dqkm2_da = vals[11];
817     XlaOp dpkm1_da = vals[12];
818     XlaOp dqkm1_da = vals[13];
819     XlaOp dans_da = vals[14];
820 
821     c = c + ScalarLike(c, 1);
822     y = y + ScalarLike(y, 1);
823     z = z + ScalarLike(z, 2);
824     XlaOp yc = y * c;
825     XlaOp pk = pkm1 * z - pkm2 * yc;
826     XlaOp qk = qkm1 * z - qkm2 * yc;
827     XlaOp qk_is_nonzero = Ne(qk, ScalarLike(qk, 0));
828     XlaOp r = pk / qk;
829 
830     t = Select(qk_is_nonzero, Abs((ans - r) / r), FullLike(t, 1));
831     ans = Select(qk_is_nonzero, r, ans);
832 
833     XlaOp dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c;
834     XlaOp dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c;
835     XlaOp dans_da_new =
836         Select(qk_is_nonzero, (dpk_da - ans * dqk_da) / qk, dans_da);
837     XlaOp grad_conditional =
838         Select(qk_is_nonzero, Abs(dans_da_new - dans_da), FullLike(dans_da, 1));
839 
840     pkm2 = pkm1;
841     pkm1 = pk;
842     qkm2 = qkm1;
843     qkm1 = qk;
844 
845     dpkm2_da = dpkm1_da;
846     dqkm2_da = dqkm1_da;
847     dpkm1_da = dpk_da;
848     dqkm1_da = dqk_da;
849 
850     XlaOp rescale = Gt(Abs(pk), Reciprocal(Epsilon(builder, type)));
851     pkm2 = Select(rescale, pkm2 * Epsilon(builder, type), pkm2);
852     pkm1 = Select(rescale, pkm1 * Epsilon(builder, type), pkm1);
853     qkm2 = Select(rescale, qkm2 * Epsilon(builder, type), qkm2);
854     qkm1 = Select(rescale, qkm1 * Epsilon(builder, type), qkm1);
855 
856     dpkm2_da = Select(rescale, dpkm2_da * Epsilon(builder, type), dpkm2_da);
857     dqkm2_da = Select(rescale, dqkm2_da * Epsilon(builder, type), dqkm2_da);
858     dpkm1_da = Select(rescale, dpkm1_da * Epsilon(builder, type), dpkm1_da);
859     dqkm1_da = Select(rescale, dqkm1_da * Epsilon(builder, type), dqkm1_da);
860 
861     XlaOp conditional;
862     if (mode == VALUE) {
863       conditional = And(enabled, Gt(t, Epsilon(builder, type)));
864     } else {
865       conditional = And(enabled, Gt(grad_conditional, Epsilon(builder, type)));
866     }
867 
868     return std::vector<XlaOp>{conditional,
869                               Select(enabled, ans, vals[1]),
870                               Select(enabled, t, vals[2]),
871                               Select(enabled, y, vals[3]),
872                               Select(enabled, z, vals[4]),
873                               c,
874                               Select(enabled, pkm1, vals[6]),
875                               Select(enabled, qkm1, vals[7]),
876                               Select(enabled, pkm2, vals[8]),
877                               Select(enabled, qkm2, vals[9]),
878                               Select(enabled, dpkm2_da, vals[10]),
879                               Select(enabled, dqkm2_da, vals[11]),
880                               Select(enabled, dpkm1_da, vals[12]),
881                               Select(enabled, dqkm1_da, vals[13]),
882                               Select(enabled, dans_da_new, vals[14])};
883   };
884 
885   auto& b = *ax.builder();
886   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
887     XlaOp y = ScalarLike(a, 1) - a;
888     XlaOp z = x + y + ScalarLike(x, 1);
889     XlaOp c = ScalarLike(x, 0);
890     XlaOp pkm2 = FullLike(x, 1);
891     XlaOp qkm2 = x;
892     XlaOp pkm1 = x + ScalarLike(x, 1);
893     XlaOp qkm1 = z * x;
894     XlaOp ans = pkm1 / qkm1;
895     XlaOp t = FullLike(x, 1);
896     XlaOp dpkm2_da = FullLike(x, 0);
897     XlaOp dqkm2_da = FullLike(x, 0);
898     XlaOp dpkm1_da = FullLike(x, 0);
899     XlaOp dqkm1_da = -x;
900     XlaOp dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1;
901     std::vector<XlaOp> vals = {enabled,  ans,      t,        y,        z,
902                                c,        pkm1,     qkm1,     pkm2,     qkm2,
903                                dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da};
904 
905     TF_ASSIGN_OR_RETURN(vals, WhileLoopHelper(cond, body, vals, "igammac", &b));
906     ans = vals[1];
907     if (mode == VALUE) {
908       return ans * ax;
909     }
910 
911     dans_da = vals[14];
912     XlaOp dlogax_da = Log(x) - Digamma(a);
913 
914     switch (mode) {
915       case DERIVATIVE:
916         return ax * (ans * dlogax_da + dans_da);
917       case SAMPLE_DERIVATIVE:
918       default:
919         return -(dans_da + ans * dlogax_da) * x;
920     }
921   });
922 }
923 
924 }  // namespace
925 
Igamma(XlaOp a,XlaOp x)926 XlaOp Igamma(XlaOp a, XlaOp x) {
927   auto& b = *a.builder();
928   auto doit = [&b](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp {
929     XlaOp is_nan = Or(IsNan(a), IsNan(x));
930     XlaOp x_is_zero = Eq(x, ScalarLike(x, 0));
931     XlaOp x_is_infinity =
932         Eq(x, ScalarLike(x, std::numeric_limits<float>::infinity()));
933     XlaOp domain_error = Or(Lt(x, ScalarLike(x, 0)), Le(a, ScalarLike(a, 0)));
934     XlaOp use_igammac = And(Gt(x, ScalarLike(x, 1)), Gt(x, a));
935     XlaOp ax = a * Log(x) - x - Lgamma(a);
936     XlaOp underflow = Lt(ax, -Log(MaxFiniteValue(&b, type)));
937     ax = Exp(ax);
938     XlaOp enabled = Not(Or(Or(Or(x_is_zero, domain_error), underflow), is_nan));
939     const double nan = std::numeric_limits<double>::quiet_NaN();
940     XlaOp output = Select(
941         use_igammac,
942         ScalarLike(a, 1) - IgammacContinuedFraction<VALUE>(
943                                ax, x, a, And(enabled, use_igammac), type),
944         IgammaSeries<VALUE>(ax, x, a, And(enabled, Not(use_igammac)), type));
945     output = Select(x_is_zero, ZerosLike(output), output);
946     output = Select(x_is_infinity, FullLike(output, 1), output);
947     output = Select(Or(domain_error, is_nan), FullLike(a, nan), output);
948     return output;
949   };
950   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
951     TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a));
952     TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x));
953     if (a_shape != x_shape) {
954       return InvalidArgument(
955           "Arguments to Igamma must have equal shapes and types; got %s and %s",
956           a_shape.ToString(), x_shape.ToString());
957     }
958     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a));
959     PrimitiveType a_x_type = a_shape.element_type();
960     bool needs_upcast =
961         a_shape.element_type() == F16 || a_shape.element_type() == BF16;
962 
963     if (needs_upcast) {
964       a = ConvertElementType(a, F32);
965       x = ConvertElementType(x, F32);
966       a_x_type = F32;
967     }
968     XlaOp result = doit(a, x, a_x_type);
969     if (needs_upcast) {
970       result = ConvertElementType(result, a_shape.element_type());
971     }
972     return result;
973   });
974 }
975 
IgammaGradA(XlaOp a,XlaOp x)976 XlaOp IgammaGradA(XlaOp a, XlaOp x) {
977   auto& b = *a.builder();
978   auto doit = [&b](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp {
979     XlaOp is_nan = Or(IsNan(a), IsNan(x));
980     XlaOp x_is_zero = Eq(x, ScalarLike(x, 0));
981     XlaOp domain_error = Or(Lt(x, ScalarLike(x, 0)), Le(a, ScalarLike(a, 0)));
982     XlaOp use_igammac = And(Gt(x, ScalarLike(x, 1)), Gt(x, a));
983     XlaOp ax = a * Log(x) - x - Lgamma(a);
984     XlaOp underflow = Lt(ax, -Log(MaxFiniteValue(&b, type)));
985     ax = Exp(ax);
986     XlaOp enabled = Not(Or(Or(Or(x_is_zero, domain_error), underflow), is_nan));
987     const double nan = std::numeric_limits<double>::quiet_NaN();
988     XlaOp output = Select(use_igammac,
989                           -IgammacContinuedFraction<DERIVATIVE>(
990                               ax, x, a, And(enabled, use_igammac), type),
991                           IgammaSeries<DERIVATIVE>(
992                               ax, x, a, And(enabled, Not(use_igammac)), type));
993     output = Select(x_is_zero, ZerosLike(output), output);
994     output = Select(Or(domain_error, is_nan), FullLike(a, nan), output);
995     return output;
996   };
997   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
998     TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a));
999     TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x));
1000     if (a_shape != x_shape) {
1001       return InvalidArgument(
1002           "Arguments to IgammaGradA must have equal shapes and types; got %s "
1003           "and %s",
1004           a_shape.ToString(), x_shape.ToString());
1005     }
1006     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a));
1007     bool needs_upcast =
1008         a_shape.element_type() == F16 || a_shape.element_type() == BF16;
1009 
1010     if (needs_upcast) {
1011       a = ConvertElementType(a, F32);
1012       x = ConvertElementType(x, F32);
1013     }
1014     XlaOp result = doit(a, x, a_shape.element_type());
1015     if (needs_upcast) {
1016       result = ConvertElementType(result, a_shape.element_type());
1017     }
1018     return result;
1019   });
1020 }
1021 
1022 // Gradient of Gamma sample from Gamma(a, 1) with respect to `a`.
RandomGammaGrad(XlaOp a,XlaOp x)1023 XlaOp RandomGammaGrad(XlaOp a, XlaOp x) {
1024   auto& b = *a.builder();
1025   auto doit = [&b](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp {
1026     XlaOp is_nan = Or(IsNan(a), IsNan(x));
1027     XlaOp x_is_zero = Eq(x, ScalarLike(x, 0));
1028     XlaOp domain_error = Or(Lt(x, ScalarLike(x, 0)), Le(a, ScalarLike(a, 0)));
1029     XlaOp use_igammac = And(Gt(x, ScalarLike(x, 1)), Gt(x, a));
1030     XlaOp ax = a * Log(x) - x - Lgamma(a);
1031     XlaOp underflow = Lt(ax, -Log(MaxFiniteValue(&b, type)));
1032     ax = Exp(ax);
1033     XlaOp enabled = Not(Or(Or(Or(x_is_zero, domain_error), underflow), is_nan));
1034     const double nan = std::numeric_limits<double>::quiet_NaN();
1035     XlaOp output = Select(use_igammac,
1036                           -IgammacContinuedFraction<SAMPLE_DERIVATIVE>(
1037                               ax, x, a, And(enabled, use_igammac), type),
1038                           IgammaSeries<SAMPLE_DERIVATIVE>(
1039                               ax, x, a, And(enabled, Not(use_igammac)), type));
1040     output = Select(x_is_zero, ZerosLike(output), output);
1041     output = Select(Or(domain_error, is_nan), FullLike(a, nan), output);
1042     return output;
1043   };
1044   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1045     TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a));
1046     TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x));
1047     if (a_shape != x_shape) {
1048       return InvalidArgument(
1049           "Arguments to RandomGammaGrad must have equal shapes and types; got "
1050           "%s and %s",
1051           a_shape.ToString(), x_shape.ToString());
1052     }
1053     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("RandomGammaGrad", a));
1054     bool needs_upcast =
1055         a_shape.element_type() == F16 || a_shape.element_type() == BF16;
1056 
1057     if (needs_upcast) {
1058       a = ConvertElementType(a, F32);
1059       x = ConvertElementType(x, F32);
1060     }
1061     XlaOp result = doit(a, x, a_shape.element_type());
1062     if (needs_upcast) {
1063       result = ConvertElementType(result, a_shape.element_type());
1064     }
1065     return result;
1066   });
1067 }
1068 
Igammac(XlaOp a,XlaOp x)1069 XlaOp Igammac(XlaOp a, XlaOp x) {
1070   auto& b = *a.builder();
1071   auto doit = [&b](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp {
1072     XlaOp out_of_range = Or(Le(x, ScalarLike(x, 0)), Le(a, ScalarLike(a, 0)));
1073     XlaOp use_igamma = Or(Lt(x, ScalarLike(x, 1)), Lt(x, a));
1074     XlaOp ax = a * Log(x) - x - Lgamma(a);
1075     XlaOp underflow = Lt(ax, -Log(MaxFiniteValue(&b, type)));
1076     XlaOp enabled = Not(Or(out_of_range, underflow));
1077     ax = Exp(ax);
1078     XlaOp result =
1079         Select(use_igamma,
1080                ScalarLike(a, 1) - IgammaSeries<VALUE>(
1081                                       ax, x, a, And(enabled, use_igamma), type),
1082                IgammacContinuedFraction<VALUE>(
1083                    ax, x, a, And(enabled, Not(use_igamma)), type));
1084     XlaOp x_is_infinity =
1085         Eq(x, ScalarLike(x, std::numeric_limits<float>::infinity()));
1086     result = Select(x_is_infinity, ZerosLike(result), result);
1087     return Select(out_of_range, FullLike(a, 1), result);
1088   };
1089   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1090     TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a));
1091     TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x));
1092     if (a_shape != x_shape) {
1093       return InvalidArgument(
1094           "Arguments to Igammac must have equal shapes and types; "
1095           "got %s and %s",
1096           a_shape.ToString(), x_shape.ToString());
1097     }
1098     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igammac", a));
1099     PrimitiveType a_x_type = a_shape.element_type();
1100     bool needs_upcast =
1101         a_shape.element_type() == F16 || a_shape.element_type() == BF16;
1102 
1103     if (needs_upcast) {
1104       a = ConvertElementType(a, F32);
1105       x = ConvertElementType(x, F32);
1106       a_x_type = F32;
1107     }
1108     XlaOp result = doit(a, x, a_x_type);
1109     if (needs_upcast) {
1110       result = ConvertElementType(result, a_shape.element_type());
1111     }
1112     return result;
1113   });
1114 }
1115 
1116 // Implements Banker's rounding: numbers that are equidistant between two
1117 // integers are rounded towards even.
RoundToEven(XlaOp x)1118 XlaOp RoundToEven(XlaOp x) {
1119   auto& b = *x.builder();
1120   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1121     // Reject non-real non-fp inputs (What does it even mean to round a complex
1122     // number?  Do you round each component equally?  In that case, you should
1123     // just ask for that explicitly.)
1124     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("RoundToEven", x));
1125 
1126     auto half = ScalarLike(x, 0.5);
1127     auto one = ScalarLike(x, 1.0);
1128     auto two = ScalarLike(x, 2.0);
1129 
1130     auto round_val = Floor(x);
1131     auto fraction = x - round_val;
1132     auto nearest_even_int = round_val - two * Floor(half * x);
1133     auto is_odd = Eq(nearest_even_int, one);
1134     return Select(Or(Gt(fraction, half), And(Eq(fraction, half), is_odd)),
1135                   round_val + one, round_val);
1136   });
1137 }
1138 
1139 // Trigonometric functions.
1140 
1141 // acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1
1142 //           pi                                if x == -1
1143 // For complex:
1144 // acos(x) = -(i * log(x + i * sqrt((1 + x) * (1 - x))))
Acos(XlaOp x)1145 XlaOp Acos(XlaOp x) {
1146   XlaBuilder* b = x.builder();
1147   return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1148     TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
1149 
1150     if (primitive_util::IsComplexType(shape.element_type())) {
1151       auto one = ScalarLike(x, 1);
1152       auto imag_one = Complex(
1153           Zero(b, primitive_util::ComplexComponentType(shape.element_type())),
1154           One(b, primitive_util::ComplexComponentType(shape.element_type())));
1155 
1156       auto result =
1157           Neg(imag_one * Log(x + imag_one * Sqrt((one + x) * (one - x))));
1158       return result;
1159     }
1160     return Select(Ne(x, FullLike(x, -1)),
1161                   ScalarLike(x, 2.0) * Atan2(Sqrt(ScalarLike(x, 1.0) - x * x),
1162                                              ScalarLike(x, 1.0) + x),
1163                   FullLike(x, M_PI));
1164   });
1165 }
1166 
1167 // asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
Asin(XlaOp x)1168 XlaOp Asin(XlaOp x) {
1169   return ScalarLike(x, 2.0) *
1170          Atan2(x, ScalarLike(x, 1.0) + Sqrt(ScalarLike(x, 1.0) - x * x));
1171 }
1172 
Atan(XlaOp x)1173 XlaOp Atan(XlaOp x) { return Atan2(x, ScalarLike(x, 1.0)); }
1174 
Tan(XlaOp x)1175 XlaOp Tan(XlaOp x) {
1176   return DoWithUpcastToF32(x, {F16}, [](XlaOp x) { return Sin(x) / Cos(x); });
1177 }
1178 
1179 // Hyperbolic trigonometric functions.
1180 
1181 // acosh(x) = log(x + sqrt(x^2 - 1))      if x >= -1
1182 //          = log(x + sqrt((x+1)*(x-1)))
1183 // acosh(x) = nan                         if x < -1
1184 //
1185 // If x^2 will overflow, we approximate sqrt(x^2 - 1) == x and compute as
1186 // log(2*x) = log(2) + log(x).  (Note this works because negative x never
1187 // overflows; x < -1 simply yields nan.  This is quite different than asinh!)
Acosh(XlaOp x)1188 XlaOp Acosh(XlaOp x) {
1189   XlaBuilder* b = x.builder();
1190   return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1191     TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
1192 
1193     auto one = ScalarLike(x, 1);
1194     auto neg_one = ScalarLike(x, -1);
1195     auto nan = FullLike(x, std::numeric_limits<float>::quiet_NaN());
1196 
1197     // return
1198     //
1199     //   nan                        if x < -1
1200     //   log(x) + log(2)            if x >= sqrt_max_value
1201     //   log(x + sqrt((x+1)*(x-1))) otherwise
1202     //
1203     // TODO(jlebar): For now, we ignore the question of overflow if x is a
1204     // complex type, because we don't yet have exhaustive tests for complex trig
1205     // functions.
1206     auto naive_result = Log(x + Sqrt((x + one) * (x - one)));
1207     if (primitive_util::IsComplexType(shape.element_type())) {
1208       return naive_result;
1209     }
1210     auto overflow_result = Log(x) + Log(ScalarLike(x, 2));
1211 
1212     auto sqrt_max_value = Sqrt(MaxFiniteValue(b, shape.element_type()));
1213     return Select(Lt(x, neg_one), nan,
1214                   Select(Ge(x, sqrt_max_value), overflow_result, naive_result));
1215   });
1216 }
1217 
1218 // asinh(x) = log(x + sqrt(x^2 + 1))
1219 //
1220 // If x^2 will overflow and x is positive, we can approximate x + sqrt(x^2 + 1)
1221 // as 2*x and return log(2) + log(x).
1222 //
1223 // If x is negative, the above would give us some trouble; we can't approximate
1224 // the result as x + abs(x) = 0!  But we're saved by the fact that asinh(-x) =
1225 // -asinh(x).
Asinh(XlaOp x)1226 XlaOp Asinh(XlaOp x) {
1227   XlaBuilder* b = x.builder();
1228   auto do_it = [&](XlaOp x) -> StatusOr<XlaOp> {
1229     TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
1230     auto one = ScalarLike(x, 1);
1231 
1232     // Let a = abs(x).  Compute
1233     //
1234     //   y = log(a + sqrt(a*a + 1))  if a < sqrt_max_value, or
1235     //   y = log(a) + log(2)         otherwise
1236     //
1237     // and then return
1238     //
1239     //   y * sign(x).
1240     //
1241     // TODO(jlebar): For now, we ignore the question of overflow if x is a
1242     // complex type, because we don't yet have exhaustive tests for complex trig
1243     // functions.
1244     if (primitive_util::IsComplexType(shape.element_type())) {
1245       return Log(x + Sqrt(x * x + one));
1246     }
1247     // For small x, sqrt(x**2 + 1) will evaluate to 1 due to floating point
1248     // arithmetic. However, we would like to retain the low order term of this,
1249     // which is around 0.5 * x**2 using a binomial expansion.
1250     // Let z = sqrt(a**2 + 1)
1251     // log(a + sqrt(a**2 + 1)) =
1252     // log((a + sqrt(a**2 + 1)) * (1 + sqrt(a**2 + 1)) / (1 + sqrt(a**2 + 1))) =
1253     // log((a + a**2 + 1 + a * z + z) / (1 + z)) =
1254     // log(1 + a + a**2 / (1 + z)) =
1255     // log(1 + a + a ** 2 / (1 + sqrt(a**2 + 1)))
1256     // This rewrite retains the lower order term.
1257     auto a = Abs(x);
1258     auto small_result = Log1p(a + a * a / (one + Sqrt(a * a + one)));
1259     auto naive_result = Log(a + Sqrt(a * a + one));
1260     auto overflow_result = Log(Abs(a)) + Log(ScalarLike(a, 2));
1261     auto sqrt_max_value = Sqrt(MaxFiniteValue(b, shape.element_type()));
1262     return Sign(x) * Select(Ge(a, sqrt_max_value), overflow_result,
1263                             Select(Le(a, one), small_result, naive_result));
1264   };
1265   // These upcasts are not strictly necessary on all platforms to get within our
1266   // error tolerances, so we could relax this if it ever mattered.
1267   return DoWithUpcastToF32(x, {BF16, F16}, [&](XlaOp x) {
1268     return b->ReportErrorOrReturn(do_it(x));
1269   });
1270 }
1271 
1272 // atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1
1273 // atanh(x) = nan                          otherwise
Atanh(XlaOp x)1274 XlaOp Atanh(XlaOp x) {
1275   XlaBuilder* b = x.builder();
1276   auto do_it = [&](XlaOp x) -> StatusOr<XlaOp> {
1277     TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
1278     auto naive_result = (Log1p(x) - Log1p(-x)) * ScalarLike(x, 0.5);
1279 
1280     // TODO(jlebar): For now, we ignore the nan edge case for complex inputs,
1281     // because we don't yet have exhaustive tests for complex trig functions.
1282     if (primitive_util::IsComplexType(shape.element_type())) {
1283       return naive_result;
1284     }
1285 
1286     auto nan = FullLike(x, std::numeric_limits<float>::quiet_NaN());
1287     return Select(Gt(Abs(x), ScalarLike(x, 1)), nan, naive_result);
1288   };
1289   return DoWithUpcastToF32(x, {BF16}, [&](XlaOp x) {  //
1290     return b->ReportErrorOrReturn(do_it(x));
1291   });
1292 }
1293 
1294 // Cosh(x) = (e^x + e^-x) / 2
1295 //         = e^(x + log(1/2)) + e^(-x + log(1/2)).
1296 //
1297 // The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not
1298 // inf.
1299 //
1300 // This incorrectly overflows to inf for two f32 input values, namely
1301 // +/-89.4159851, due to rounding error when computing x +/- log(1/2).  The
1302 // correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
1303 // we deem this acceptable.
Cosh(XlaOp x)1304 XlaOp Cosh(XlaOp x) {
1305   return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) {
1306     auto log_one_half = Log(ScalarLike(x, 0.5));
1307     return Exp(x + log_one_half) + Exp(-x + log_one_half);
1308   });
1309 }
1310 
1311 // Sinh(x) = (e^x - e^-x) / 2
1312 //         = e^(x + log(1/2)) - e^(-x + log(1/2)).
1313 //
1314 // The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not
1315 // inf.
1316 //
1317 // This incorrectly overflows to +/-inf for two f32 input values, namely
1318 // +/-89.4159851, due to rounding error when computing x +/- log(1/2).  The
1319 // correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
1320 // we deem this acceptable.
Sinh(XlaOp x)1321 XlaOp Sinh(XlaOp x) {
1322   XlaBuilder* b = x.builder();
1323   auto do_it = [&](XlaOp x) -> StatusOr<XlaOp> {
1324     TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
1325     auto one_half = ScalarLike(x, 0.5);
1326     auto log_one_half = Log(ScalarLike(x, 0.5));
1327     auto large_sinh_result = Exp(x + log_one_half) - Exp(-x + log_one_half);
1328 
1329     if (primitive_util::IsComplexType(shape.element_type())) {
1330       return large_sinh_result;
1331     }
1332 
1333     // Here we use e^x = e^(x / 2) * e^(x / 2). This avoids overflow for large
1334     // values of x.
1335 
1336     // For smaller x, we get unwanted cancellations of e^x - e^-x, resulting in
1337     // 0.
1338     // Rewrite this to avoid that. We use expm1(x) because that preserves the
1339     // first order term of the taylor series of e^x.
1340     // (e^(x) - e^(-x)) / 2. =
1341     // (e^(x) - 1 + 1 - e^(-x)) / 2.
1342     // (expm1(x) + (e^(x) - 1) / e^x) / 2.
1343     // (expm1(x) + expm1(x) / (expm1(x) + 1)) / 2.
1344     auto expm1 = Expm1(x);
1345     auto one = ScalarLike(x, 1.);
1346     auto small_sinh_result = one_half * (expm1 + expm1 / (expm1 + one));
1347     return Select(Lt(Abs(x), one), small_sinh_result, large_sinh_result);
1348   };
1349   return DoWithUpcastToF32(x, {BF16, F16}, [&](XlaOp x) {
1350     return b->ReportErrorOrReturn(do_it(x));
1351   });
1352 }
1353 
MaybeConjugate(XlaOp x,bool conjugate)1354 XlaOp MaybeConjugate(XlaOp x, bool conjugate) {
1355   XlaBuilder* builder = x.builder();
1356   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1357     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
1358     auto perform_conj =
1359         primitive_util::IsComplexType(shape.element_type()) && conjugate;
1360     return perform_conj ? Conj(x) : x;
1361   });
1362 }
1363 
NextAfter(XlaOp from,XlaOp to)1364 XlaOp NextAfter(XlaOp from, XlaOp to) {
1365   auto builder = from.builder();
1366   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1367     TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(from));
1368     int bitwidth = primitive_util::BitWidth(shape.element_type());
1369     auto int_type = primitive_util::UnsignedIntegralTypeForBitWidth(bitwidth);
1370     auto from_as_int = BitcastConvertType(from, int_type);
1371     auto to_as_int = BitcastConvertType(to, int_type);
1372 
1373     // The result is NaN if either "from" or "to" are NaN.
1374     auto from_is_nan = Ne(from, from);
1375     auto to_is_nan = Ne(to, to);
1376     auto nan_input = Or(from_is_nan, to_is_nan);
1377     auto result_for_nan =
1378         Broadcast(ScalarLike(from, std::numeric_limits<double>::quiet_NaN()),
1379                   shape.dimensions());
1380     result_for_nan = BitcastConvertType(result_for_nan, int_type);
1381 
1382     // The sign bit is the MSB.
1383     const int64_t sign_mask = int64{1} << (bitwidth - 1);
1384     // Discard the sign bit to make the result non-negative.
1385     auto from_abs = And(from_as_int, ScalarLike(from_as_int, ~sign_mask));
1386     auto to_abs = And(to_as_int, ScalarLike(to_as_int, ~sign_mask));
1387 
1388     // When both "from" and "to" are equal, the result is "to".
1389     // N.B. It would not make a difference if we chose the result to be "from".
1390     auto from_and_to_are_equal = Eq(from_as_int, to_as_int);
1391     auto result_for_equal = to_as_int;
1392 
1393     // When both "from" and "to" are both 0, the result is "to". This ensures we
1394     // get a zero signed like "to".
1395     auto from_is_zero = Eq(from_abs, ZerosLike(from_abs));
1396     auto to_is_zero = Eq(to_abs, ZerosLike(to_abs));
1397     auto result_for_both_zero = to_as_int;
1398 
1399     auto from_sign = And(from_as_int, ScalarLike(from_as_int, sign_mask));
1400     auto to_sign = And(to_as_int, ScalarLike(to_as_int, sign_mask));
1401 
1402     // If from == 0 && to != 0, we need to return the smallest subnormal number
1403     // signed like "to".
1404     auto result_for_from_zero_to_non_zero =
1405         Or(to_sign, ScalarLike(from_as_int, 1));
1406 
1407     // If the sign of "from" and "to" disagree:
1408     // - we need to make the magnitude of "from" smaller so that it is closer to
1409     //   zero.
1410     //
1411     // Otherwise the signs agree:
1412     // - "from" with a magnitude larger than "to" means we need to make the
1413     //   magnitude smaller.
1414     // - "from" with a magnitude smaller than "to" means we need to make the
1415     //   magnitude larger.
1416     // - "from" with the same magnitude and sign as "to" has already been
1417     //   handled.
1418     auto signs_disagree = Ne(from_sign, to_sign);
1419     auto from_magnitude_larger_than_to = Gt(from_abs, to_abs);
1420     auto result_has_smaller_magnitude =
1421         Or(from_magnitude_larger_than_to, signs_disagree);
1422     auto magnitude_adjustment =
1423         Select(result_has_smaller_magnitude,
1424                Broadcast(ScalarLike(from_as_int, -1), shape.dimensions()),
1425                Broadcast(ScalarLike(from_as_int, 1), shape.dimensions()));
1426     auto result = Add(from_as_int, magnitude_adjustment);
1427     // Handle from == ±0.
1428     result = Select(from_is_zero,
1429                     Select(to_is_zero, result_for_both_zero,
1430                            result_for_from_zero_to_non_zero),
1431                     result);
1432     // Handle from == to.
1433     result = Select(from_and_to_are_equal, result_for_equal, result);
1434     // Handle isnan(from) || isnan(to).
1435     result = Select(nan_input, result_for_nan, result);
1436 
1437     // Cast back to the original type.
1438     return BitcastConvertType(result, shape.element_type());
1439   });
1440 }
1441 
1442 // Computes an approximation to the modified Bessel function of the first kind,
1443 // zeroth order.
1444 // The following implementation follows Cephes' F32 and F64 implementation of
1445 // i0e.
I0eImpl32(XlaOp x)1446 static XlaOp I0eImpl32(XlaOp x) {
1447   static const std::array<float, 18> kI0eCoeffsA{
1448       -1.30002500998624804212E-8f, 6.04699502254191894932E-8f,
1449       -2.67079385394061173391E-7f, 1.11738753912010371815E-6f,
1450       -4.41673835845875056359E-6f, 1.64484480707288970893E-5f,
1451       -5.75419501008210370398E-5f, 1.88502885095841655729E-4f,
1452       -5.76375574538582365885E-4f, 1.63947561694133579842E-3f,
1453       -4.32430999505057594430E-3f, 1.05464603945949983183E-2f,
1454       -2.37374148058994688156E-2f, 4.93052842396707084878E-2f,
1455       -9.49010970480476444210E-2f, 1.71620901522208775349E-1f,
1456       -3.04682672343198398683E-1f, 6.76795274409476084995E-1f};
1457 
1458   static const std::array<float, 7> kI0eCoeffsB{
1459       3.39623202570838634515E-9f, 2.26666899049817806459E-8f,
1460       2.04891858946906374183E-7f, 2.89137052083475648297E-6f,
1461       6.88975834691682398426E-5f, 3.36911647825569408990E-3f,
1462       8.04490411014108831608E-1f};
1463 
1464   x = Abs(x);
1465   auto half = xla::ScalarLike(x, 0.5);
1466   auto two = xla::ScalarLike(x, 2.0);
1467   auto thirty_two = xla::ScalarLike(x, 32.0);
1468   auto result_le_8 =
1469       EvaluateChebyshevPolynomial<float>(half * x - two, kI0eCoeffsA);
1470   auto result_gt_8 =
1471       EvaluateChebyshevPolynomial<float>(thirty_two / x - two, kI0eCoeffsB) /
1472       Sqrt(x);
1473   return Select(Le(x, xla::ScalarLike(x, 8.0)), result_le_8, result_gt_8);
1474 }
1475 
I0eImpl64(XlaOp x)1476 static XlaOp I0eImpl64(XlaOp x) {
1477   static const std::array<double, 30> kI0eCoeffsA{
1478       -4.41534164647933937950E-18, 3.33079451882223809783E-17,
1479       -2.43127984654795469359E-16, 1.71539128555513303061E-15,
1480       -1.16853328779934516808E-14, 7.67618549860493561688E-14,
1481       -4.85644678311192946090E-13, 2.95505266312963983461E-12,
1482       -1.72682629144155570723E-11, 9.67580903537323691224E-11,
1483       -5.18979560163526290666E-10, 2.65982372468238665035E-9,
1484       -1.30002500998624804212E-8,  6.04699502254191894932E-8,
1485       -2.67079385394061173391E-7,  1.11738753912010371815E-6,
1486       -4.41673835845875056359E-6,  1.64484480707288970893E-5,
1487       -5.75419501008210370398E-5,  1.88502885095841655729E-4,
1488       -5.76375574538582365885E-4,  1.63947561694133579842E-3,
1489       -4.32430999505057594430E-3,  1.05464603945949983183E-2,
1490       -2.37374148058994688156E-2,  4.93052842396707084878E-2,
1491       -9.49010970480476444210E-2,  1.71620901522208775349E-1,
1492       -3.04682672343198398683E-1,  6.76795274409476084995E-1};
1493 
1494   static const std::array<double, 25> kI0eCoeffsB{
1495       -7.23318048787475395456E-18, -4.83050448594418207126E-18,
1496       4.46562142029675999901E-17,  3.46122286769746109310E-17,
1497       -2.82762398051658348494E-16, -3.42548561967721913462E-16,
1498       1.77256013305652638360E-15,  3.81168066935262242075E-15,
1499       -9.55484669882830764870E-15, -4.15056934728722208663E-14,
1500       1.54008621752140982691E-14,  3.85277838274214270114E-13,
1501       7.18012445138366623367E-13,  -1.79417853150680611778E-12,
1502       -1.32158118404477131188E-11, -3.14991652796324136454E-11,
1503       1.18891471078464383424E-11,  4.94060238822496958910E-10,
1504       3.39623202570838634515E-9,   2.26666899049817806459E-8,
1505       2.04891858946906374183E-7,   2.89137052083475648297E-6,
1506       6.88975834691682398426E-5,   3.36911647825569408990E-3,
1507       8.04490411014108831608E-1};
1508 
1509   x = Abs(x);
1510   auto half = xla::ScalarLike(x, 0.5);
1511   auto two = xla::ScalarLike(x, 2.0);
1512   auto thirty_two = xla::ScalarLike(x, 32.0);
1513   auto result_le_8 =
1514       EvaluateChebyshevPolynomial<double>(half * x - two, kI0eCoeffsA);
1515   auto result_gt_8 =
1516       EvaluateChebyshevPolynomial<double>(thirty_two / x - two, kI0eCoeffsB) /
1517       Sqrt(x);
1518   return Select(Le(x, xla::ScalarLike(x, 8.0)), result_le_8, result_gt_8);
1519 }
1520 
BesselI0e(XlaOp x)1521 XlaOp BesselI0e(XlaOp x) {
1522   auto& b = *x.builder();
1523   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1524     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("BesselI0e", x));
1525     TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x));
1526     if (shape.element_type() == F64) {
1527       return I0eImpl64(x);
1528     }
1529     // I0eF32Impl don't have enough precision when run with bf16 intermediates
1530     // (not surprising!), so upcast to f32 in this case.
1531     return DoWithUpcastToF32(x, {BF16, F16},
1532                              [](XlaOp x) { return I0eImpl32(x); });
1533   });
1534 }
1535 
1536 // Computes an approximation to the modified Bessel function of the first kind,
1537 // first order.
1538 // The following implementation follows Cephes' F32 and F64 implementation of
1539 // i1e.
1540 
I1eImpl32(XlaOp x)1541 static XlaOp I1eImpl32(XlaOp x) {
1542   static const std::array<float, 17> kI1eCoeffsA{
1543       9.38153738649577178388E-9f, -4.44505912879632808065E-8f,
1544       2.00329475355213526229E-7f, -8.56872026469545474066E-7f,
1545       3.47025130813767847674E-6f, -1.32731636560394358279E-5f,
1546       4.78156510755005422638E-5f, -1.61760815825896745588E-4f,
1547       5.12285956168575772895E-4f, -1.51357245063125314899E-3f,
1548       4.15642294431288815669E-3f, -1.05640848946261981558E-2f,
1549       2.47264490306265168283E-2f, -5.29459812080949914269E-2f,
1550       1.02643658689847095384E-1f, -1.76416518357834055153E-1f,
1551       2.52587186443633654823E-1f};
1552 
1553   static const std::array<float, 7> kI1eCoeffsB{
1554       -3.83538038596423702205E-9f, -2.63146884688951950684E-8f,
1555       -2.51223623787020892529E-7f, -3.88256480887769039346E-6f,
1556       -1.10588938762623716291E-4f, -9.76109749136146840777E-3f,
1557       7.78576235018280120474E-1f};
1558   XlaOp z = Abs(x);
1559   auto half = xla::ScalarLike(x, 0.5);
1560   auto two = xla::ScalarLike(x, 2.0);
1561   auto thirty_two = xla::ScalarLike(x, 32.0);
1562   auto result_le_8 =
1563       z * EvaluateChebyshevPolynomial<float>(half * z - two, kI1eCoeffsA);
1564   auto result_gt_8 =
1565       EvaluateChebyshevPolynomial<float>(thirty_two / z - two, kI1eCoeffsB) /
1566       Sqrt(z);
1567   return Sign(x) *
1568          Select(Le(z, xla::ScalarLike(x, 8.0)), result_le_8, result_gt_8);
1569 }
1570 
I1eImpl64(XlaOp x)1571 static XlaOp I1eImpl64(XlaOp x) {
1572   static const std::array<double, 29> kI1eCoeffsA{
1573       2.77791411276104639959E-18, -2.11142121435816608115E-17,
1574       1.55363195773620046921E-16, -1.10559694773538630805E-15,
1575       7.60068429473540693410E-15, -5.04218550472791168711E-14,
1576       3.22379336594557470981E-13, -1.98397439776494371520E-12,
1577       1.17361862988909016308E-11, -6.66348972350202774223E-11,
1578       3.62559028155211703701E-10, -1.88724975172282928790E-9,
1579       9.38153738649577178388E-9,  -4.44505912879632808065E-8,
1580       2.00329475355213526229E-7,  -8.56872026469545474066E-7,
1581       3.47025130813767847674E-6,  -1.32731636560394358279E-5,
1582       4.78156510755005422638E-5,  -1.61760815825896745588E-4,
1583       5.12285956168575772895E-4,  -1.51357245063125314899E-3,
1584       4.15642294431288815669E-3,  -1.05640848946261981558E-2,
1585       2.47264490306265168283E-2,  -5.29459812080949914269E-2,
1586       1.02643658689847095384E-1,  -1.76416518357834055153E-1,
1587       2.52587186443633654823E-1};
1588 
1589   static const std::array<double, 25> kI1eCoeffsB{
1590       7.51729631084210481353E-18,  4.41434832307170791151E-18,
1591       -4.65030536848935832153E-17, -3.20952592199342395980E-17,
1592       2.96262899764595013876E-16,  3.30820231092092828324E-16,
1593       -1.88035477551078244854E-15, -3.81440307243700780478E-15,
1594       1.04202769841288027642E-14,  4.27244001671195135429E-14,
1595       -2.10154184277266431302E-14, -4.08355111109219731823E-13,
1596       -7.19855177624590851209E-13, 2.03562854414708950722E-12,
1597       1.41258074366137813316E-11,  3.25260358301548823856E-11,
1598       -1.89749581235054123450E-11, -5.58974346219658380687E-10,
1599       -3.83538038596423702205E-9,  -2.63146884688951950684E-8,
1600       -2.51223623787020892529E-7,  -3.88256480887769039346E-6,
1601       -1.10588938762623716291E-4,  -9.76109749136146840777E-3,
1602       7.78576235018280120474E-1};
1603 
1604   XlaOp z = Abs(x);
1605   auto half = xla::ScalarLike(x, 0.5);
1606   auto two = xla::ScalarLike(x, 2.0);
1607   auto thirty_two = xla::ScalarLike(x, 32.0);
1608   auto result_le_8 =
1609       z * EvaluateChebyshevPolynomial<double>(half * z - two, kI1eCoeffsA);
1610   auto result_gt_8 =
1611       EvaluateChebyshevPolynomial<double>(thirty_two / z - two, kI1eCoeffsB) /
1612       Sqrt(z);
1613   return Sign(x) *
1614          Select(Le(z, xla::ScalarLike(x, 8.0)), result_le_8, result_gt_8);
1615 }
1616 
BesselI1e(XlaOp x)1617 XlaOp BesselI1e(XlaOp x) {
1618   auto& b = *x.builder();
1619   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1620     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("BesselI1e", x));
1621     TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x));
1622     if (shape.element_type() == F64) {
1623       return I1eImpl64(x);
1624     }
1625     // I1eF32Impl don't have enough precision when run with bf16 intermediates
1626     // (not surprising!), so upcast to f32 in this case.
1627     return DoWithUpcastToF32(x, {BF16, F16},
1628                              [](XlaOp x) { return I1eImpl32(x); });
1629   });
1630 }
1631 
1632 // I J Thompson and A R Barnett. 1986. Coulomb and Bessel functions of complex
1633 // arguments and order. J. Comput. Phys. 64, 2 (June 1986), 490-509.
1634 // DOI=http://dx.doi.org/10.1016/0021-9991(86)90046-X
LentzThompsonBarnettAlgorithm(int64_t num_iterations,double small,double threshold,const ForEachIndexBodyFunction & nth_partial_numerator,const ForEachIndexBodyFunction & nth_partial_denominator,absl::Span<const XlaOp> inputs,absl::string_view name)1635 static XlaOp LentzThompsonBarnettAlgorithm(
1636     int64_t num_iterations, double small, double threshold,
1637     const ForEachIndexBodyFunction& nth_partial_numerator,
1638     const ForEachIndexBodyFunction& nth_partial_denominator,
1639     absl::Span<const XlaOp> inputs, absl::string_view name) {
1640   auto& b = *inputs.front().builder();
1641   return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1642     TF_RET_CHECK(num_iterations < INT32_MAX);
1643 
1644     enum {
1645       // Position in the evaluation.
1646       kIterationIdx,
1647       // Whether or not we have reached the desired tolerance.
1648       kValuesUnconvergedIdx,
1649       // Ratio between nth canonical numerator and the nth-1 canonical
1650       // numerator.
1651       kCIdx,
1652       // Ratio between nth-1 canonical denominator and the nth canonical
1653       // denominator.
1654       kDIdx,
1655       // Computed approximant in the evaluation.
1656       kHIdx,
1657       // Inputs follow all of the other state.
1658       kFirstInputIdx,
1659     };
1660     auto while_cond_fn = [num_iterations](
1661                              absl::Span<const XlaOp> values,
1662                              XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
1663       auto iteration = values[kIterationIdx];
1664       auto iterations_remain_cond =
1665           Lt(iteration, ScalarLike(iteration, num_iterations));
1666       auto values_unconverged_cond = values[kValuesUnconvergedIdx];
1667       return And(iterations_remain_cond, values_unconverged_cond);
1668     };
1669 
1670     auto while_body_fn =
1671         [small, threshold, &nth_partial_numerator, &nth_partial_denominator](
1672             absl::Span<const XlaOp> values,
1673             XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
1674       XlaOp iteration = values[kIterationIdx];
1675 
1676       TF_ASSIGN_OR_RETURN(
1677           std::vector<XlaOp> partial_numerator,
1678           nth_partial_numerator(iteration, values.subspan(kFirstInputIdx),
1679                                 body_builder));
1680       TF_RET_CHECK(partial_numerator.size() == 1);
1681 
1682       TF_ASSIGN_OR_RETURN(
1683           std::vector<XlaOp> partial_denominator,
1684           nth_partial_denominator(iteration, values.subspan(kFirstInputIdx),
1685                                   body_builder));
1686       TF_RET_CHECK(partial_denominator.size() == 1);
1687 
1688       auto c = partial_denominator[0] + partial_numerator[0] / values[kCIdx];
1689       auto small_constant = FullLike(c, small);
1690       c = Select(Lt(Abs(c), small_constant), small_constant, c);
1691 
1692       auto d = partial_denominator[0] + partial_numerator[0] * values[kDIdx];
1693       d = Select(Lt(Abs(d), small_constant), small_constant, d);
1694 
1695       d = Reciprocal(d);
1696 
1697       auto delta = c * d;
1698       auto h = values[kHIdx] * delta;
1699 
1700       std::vector<XlaOp> updated_values(values.size());
1701       updated_values[kIterationIdx] = Add(iteration, ScalarLike(iteration, 1));
1702       updated_values[kCIdx] = c;
1703       updated_values[kDIdx] = d;
1704       updated_values[kHIdx] = h;
1705       std::copy(values.begin() + kFirstInputIdx, values.end(),
1706                 updated_values.begin() + kFirstInputIdx);
1707 
1708       // If any values are greater than the tolerance, we have not converged.
1709       auto tolerance_comparison =
1710           Ge(Abs(Sub(delta, FullLike(delta, 1.0))), FullLike(delta, threshold));
1711       updated_values[kValuesUnconvergedIdx] =
1712           ReduceAll(tolerance_comparison, ConstantR0<bool>(body_builder, false),
1713                     CreateScalarOrComputation(PRED, body_builder));
1714       return updated_values;
1715     };
1716 
1717     TF_ASSIGN_OR_RETURN(std::vector<XlaOp> partial_denominator,
1718                         nth_partial_denominator(Zero(&b, U32), inputs, &b));
1719     TF_RET_CHECK(partial_denominator.size() == 1);
1720     auto h = partial_denominator[0];
1721     auto small_constant = FullLike(h, small);
1722     h = Select(Lt(Abs(h), small_constant), small_constant, h);
1723 
1724     std::vector<XlaOp> values(kFirstInputIdx + inputs.size());
1725     values[kIterationIdx] = One(&b, U32);
1726     values[kValuesUnconvergedIdx] = ConstantR0<bool>(&b, true);
1727     values[kCIdx] = h;
1728     values[kDIdx] = FullLike(h, 0.0);
1729     values[kHIdx] = h;
1730     std::copy(inputs.begin(), inputs.end(), values.begin() + kFirstInputIdx);
1731     TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn,
1732                                                 values, name, &b));
1733     return values[kHIdx];
1734   });
1735 }
1736 
RegularizedIncompleteBeta(XlaOp a,XlaOp b,XlaOp x)1737 XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x) {
1738   auto& builder = *x.builder();
1739   return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1740     TF_ASSIGN_OR_RETURN(Shape shape, builder.GetShape(a));
1741     TF_ASSIGN_OR_RETURN(Shape b_shape, builder.GetShape(b));
1742     TF_ASSIGN_OR_RETURN(Shape x_shape, builder.GetShape(x));
1743     if (b_shape.element_type() != shape.element_type() ||
1744         x_shape.element_type() != shape.element_type()) {
1745       return InvalidArgument(
1746           "Operands to RegularizedIncompleteBeta must have identical types, "
1747           "got shapes %s, %s, and %s",
1748           shape.ToString(), b_shape.ToString(), x_shape.ToString());
1749     }
1750     if (!primitive_util::IsFloatingPointType(shape.element_type())) {
1751       return InvalidArgument(
1752           "Operands to RegularizedIncompleteBeta must be real-valued "
1753           "floating-point, but got %s",
1754           PrimitiveType_Name(shape.element_type()));
1755     }
1756     PrimitiveType element_type = shape.element_type();
1757     if (element_type == F16 || element_type == BF16) {
1758       element_type = F32;
1759       a = ConvertElementType(a, F32);
1760       b = ConvertElementType(b, F32);
1761       x = ConvertElementType(x, F32);
1762     }
1763 
1764     // The partial numerator for the incomplete beta function is given
1765     // here: http://dlmf.nist.gov/8.17.E23 Note that there is a special
1766     // case: the partial numerator for the first iteration is one.
1767     auto NthPartialBetaincNumerator =
1768         [&](XlaOp iteration, absl::Span<const XlaOp> inputs,
1769             XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
1770       auto a = inputs[0];
1771       auto b = inputs[1];
1772       auto x = inputs[2];
1773       auto iteration_bcast = Broadcast(iteration, shape.dimensions());
1774       auto iteration_is_even =
1775           Eq(iteration_bcast % FullLike(iteration_bcast, 2),
1776              FullLike(iteration_bcast, 0));
1777       auto iteration_is_one = Eq(iteration_bcast, FullLike(iteration_bcast, 1));
1778       auto iteration_minus_one = iteration_bcast - FullLike(iteration_bcast, 1);
1779       auto m = iteration_minus_one / FullLike(iteration_minus_one, 2);
1780       m = ConvertElementType(m, element_type);
1781       auto one = FullLike(a, 1.0);
1782       auto two = FullLike(a, 2.0);
1783       // Partial numerator terms.
1784       auto even_numerator =
1785           -(a + m) * (a + b + m) * x / ((a + two * m) * (a + two * m + one));
1786       auto odd_numerator =
1787           m * (b - m) * x / ((a + two * m - one) * (a + two * m));
1788       auto one_numerator = ScalarLike(x, 1.0);
1789       auto numerator = Select(iteration_is_even, even_numerator, odd_numerator);
1790       return std::vector<XlaOp>{
1791           Select(iteration_is_one, one_numerator, numerator)};
1792     };
1793 
1794     auto NthPartialBetaincDenominator =
1795         [&shape](XlaOp iteration, absl::Span<const XlaOp> inputs,
1796                  XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
1797       auto x = inputs[2];
1798       auto iteration_bcast = Broadcast(iteration, shape.dimensions());
1799       return std::vector<XlaOp>{
1800           Select(Eq(iteration_bcast, ScalarLike(iteration_bcast, 0)),
1801                  ScalarLike(x, 0.0), ScalarLike(x, 1.0))};
1802     };
1803 
1804     // Determine if the inputs are out of range.
1805     auto result_is_nan =
1806         Or(Or(Or(Le(a, ScalarLike(a, 0.0)), Le(b, ScalarLike(b, 0.0))),
1807               Lt(x, ScalarLike(x, 0.0))),
1808            Gt(x, ScalarLike(x, 1.0)));
1809 
1810     // The continued fraction will converge rapidly when x < (a+1)/(a+b+2)
1811     // as per: http://dlmf.nist.gov/8.17.E23
1812     //
1813     // Otherwise, we can rewrite using the symmetry relation as per:
1814     // http://dlmf.nist.gov/8.17.E4
1815     auto converges_rapidly =
1816         Lt(x, (a + FullLike(a, 1.0)) / (a + b + FullLike(b, 2.0)));
1817     auto a_orig = a;
1818     a = Select(converges_rapidly, a, b);
1819     b = Select(converges_rapidly, b, a_orig);
1820     x = Select(converges_rapidly, x, Sub(FullLike(x, 1.0), x));
1821 
1822     XlaOp continued_fraction;
1823 
1824     // Thresholds and iteration counts taken from Cephes.
1825     if (element_type == F32) {
1826       continued_fraction = LentzThompsonBarnettAlgorithm(
1827           /*num_iterations=*/200,
1828           /*small=*/std::numeric_limits<float>::epsilon() / 2.0f,
1829           /*threshold=*/std::numeric_limits<float>::epsilon() / 2.0f,
1830           /*nth_partial_numerator=*/NthPartialBetaincNumerator,
1831           /*nth_partial_denominator=*/NthPartialBetaincDenominator, {a, b, x},
1832           "Betainc");
1833     } else {
1834       TF_RET_CHECK(element_type == F64);
1835       continued_fraction = LentzThompsonBarnettAlgorithm(
1836           /*num_iterations=*/600,
1837           /*small=*/std::numeric_limits<double>::epsilon() / 2.0f,
1838           /*threshold=*/std::numeric_limits<double>::epsilon() / 2.0f,
1839           /*nth_partial_numerator=*/NthPartialBetaincNumerator,
1840           /*nth_partial_denominator=*/NthPartialBetaincDenominator, {a, b, x},
1841           "Betainc");
1842     }
1843 
1844     // We want to compute the regularized complete beta function so we need to
1845     // combine the continued fraction with a few more terms as well as dividing
1846     // it by Beta(a, b). To avoid overflow, we compute in the log domain.
1847     // See http://dlmf.nist.gov/8.17.E22 for an easier to read version of this
1848     // formula.
1849     auto lbeta = Lbeta(a, b);
1850     auto result =
1851         continued_fraction * Exp(Log(x) * a + Log1p(-x) * b - lbeta) / a;
1852     result = Select(result_is_nan, NanValue(&builder, element_type), result);
1853 
1854     // We have an additional fixup to do if we are taking advantage of the
1855     // symmetry relation.
1856     auto out =
1857         Select(converges_rapidly, result, Sub(FullLike(result, 1.0), result));
1858     return shape.element_type() == element_type
1859                ? out
1860                : ConvertElementType(out, shape.element_type());
1861   });
1862 }
1863 
Polygamma(XlaOp n,XlaOp x)1864 XlaOp Polygamma(XlaOp n, XlaOp x) {
1865   auto& builder = *x.builder();
1866   auto doit = [](XlaOp n, XlaOp x, PrimitiveType type) -> XlaOp {
1867     XlaOp n_plus_one = n + ScalarLike(n, 1.);
1868     XlaOp sign =
1869         (ScalarLike(n, 2.) * Rem(n, ScalarLike(n, 2.)) - ScalarLike(n, 1.));
1870 
1871     const double nan = std::numeric_limits<double>::quiet_NaN();
1872 
1873     XlaOp output = Select(Eq(n, ScalarLike(n, 0.)), Digamma(x),
1874                           sign * Exp(Lgamma(n_plus_one)) * Zeta(n_plus_one, x));
1875     // Check that n is a natural number.
1876     output = Select(Or(Ne(n, Floor(n)), Lt(n, ScalarLike(n, 0.))),
1877                     ScalarLike(n, nan), output);
1878     return output;
1879   };
1880   return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1881     TF_ASSIGN_OR_RETURN(auto n_shape, builder.GetShape(n));
1882     TF_ASSIGN_OR_RETURN(auto x_shape, builder.GetShape(x));
1883     if (n_shape != x_shape) {
1884       return InvalidArgument(
1885           "Arguments to Polygamma must have equal shapes and types; "
1886           "got %s and %s",
1887           n_shape.ToString(), x_shape.ToString());
1888     }
1889     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Zeta", x));
1890     bool needs_upcast =
1891         n_shape.element_type() == F16 || x_shape.element_type() == BF16;
1892 
1893     if (needs_upcast) {
1894       n = ConvertElementType(n, F32);
1895       x = ConvertElementType(x, F32);
1896     }
1897     XlaOp result = doit(n, x, n_shape.element_type());
1898     if (needs_upcast) {
1899       result = ConvertElementType(result, n_shape.element_type());
1900     }
1901     return result;
1902   });
1903 }
1904 
Zeta(XlaOp x,XlaOp q)1905 XlaOp Zeta(XlaOp x, XlaOp q) {
1906   auto& builder = *x.builder();
1907   auto doit = [&builder](XlaOp x, XlaOp q, PrimitiveType type) -> XlaOp {
1908     // (2k) ! / B_{2k}, where B_{2k} are the Bernoulli numbers.
1909     // These are ordered in reverse.
1910     static const std::array<double, 12> kZetaCoeffs{
1911         -7.1661652561756670113e18,
1912         1.8152105401943546773e17,
1913         -4.5979787224074726105e15,
1914         1.1646782814350067249e14,
1915         -2.950130727918164224e12,
1916         7.47242496e10,
1917         -1.8924375803183791606e9,
1918         47900160.0,
1919         -1209600.0,
1920         30240.0,
1921         -720.0,
1922         12.0,
1923     };
1924 
1925     // For speed we'll always use 9 iterations for the initial series estimate,
1926     // and a 12 term expansion for the Euler-Maclaurin formula.
1927 
1928     XlaOp a = q;
1929     XlaOp neg_power = ScalarLike(a, 0.);
1930     XlaOp initial_sum = Pow(q, Neg(x));
1931     for (int i = 0; i < 9; ++i) {
1932       a = a + ScalarLike(a, 1.);
1933       neg_power = Pow(a, Neg(x));
1934       initial_sum = initial_sum + neg_power;
1935     }
1936     a = a + ScalarLike(a, 1.);
1937     neg_power = Pow(a, Neg(x));
1938     XlaOp s = initial_sum + neg_power * a / (x - ScalarLike(a, 1.));
1939     XlaOp a_inverse_square = Reciprocal(Square(a));
1940     XlaOp horner_sum = ScalarLike(a, 0.);
1941     XlaOp factor = ScalarLike(a, 1.);
1942     // Use Horner's rule for this.
1943     // Note this differs from Cephes which does a 'naive' polynomial evaluation.
1944     // Using Horner's rule allows to avoid some NaN's and Infs from happening,
1945     // resulting in more numerically stable code.
1946     for (int i = 0; i < 11; ++i) {
1947       factor =
1948           (x - ScalarLike(x, 22 - 2 * i)) * (x - ScalarLike(x, 21 - 2 * i));
1949       horner_sum = factor * a_inverse_square *
1950                    (horner_sum + ScalarLike(a, 1. / kZetaCoeffs[i]));
1951     }
1952     s = s + neg_power *
1953                 (ScalarLike(neg_power, 0.5) +
1954                  x / a * (ScalarLike(a, 1. / kZetaCoeffs[11]) + horner_sum));
1955 
1956     const double nan = std::numeric_limits<double>::quiet_NaN();
1957     const double inf = std::numeric_limits<double>::infinity();
1958     // Use the initial zeta sum without the correction term coming
1959     // from Euler-Maclaurin if it is accurate enough.
1960     XlaOp output =
1961         Select(Lt(Abs(neg_power), Abs(initial_sum) * Epsilon(&builder, type)),
1962                initial_sum, s);
1963 
1964     // This is the harmonic series.
1965     output = Select(Eq(x, ScalarLike(x, 1.)), ScalarLike(x, inf), output);
1966 
1967     // Function is not defined for x < 1.
1968     output = Select(Lt(x, ScalarLike(x, 1.)), ScalarLike(x, nan), output);
1969 
1970     // For q <= 0, x must be an integer.
1971     XlaOp x_domain_error = And(Le(q, ScalarLike(x, 0.)), Ne(x, Floor(x)));
1972     output = Select(x_domain_error, ScalarLike(x, nan), output);
1973 
1974     // For all integer q <= 0, zeta has a pole. The limit is only defined as
1975     // +inf if x is and even integer.
1976     XlaOp at_pole = And(Le(q, ScalarLike(x, 0.)), Eq(q, Floor(q)));
1977     XlaOp x_is_even_int =
1978         And(Eq(Rem(x, ScalarLike(x, 2.)), ScalarLike(x, 0.)), Eq(x, Floor(x)));
1979     output = Select(
1980         at_pole, Select(x_is_even_int, ScalarLike(x, inf), ScalarLike(x, nan)),
1981         output);
1982 
1983     return output;
1984   };
1985   return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1986     TF_ASSIGN_OR_RETURN(auto x_shape, builder.GetShape(x));
1987     TF_ASSIGN_OR_RETURN(auto q_shape, builder.GetShape(q));
1988     if (x_shape != q_shape) {
1989       return InvalidArgument(
1990           "Arguments to Zeta must have equal shapes and types; got %s and %s",
1991           x_shape.ToString(), q_shape.ToString());
1992     }
1993     TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Zeta", x));
1994     bool needs_upcast =
1995         x_shape.element_type() == F16 || x_shape.element_type() == BF16;
1996 
1997     if (needs_upcast) {
1998       x = ConvertElementType(x, F32);
1999       q = ConvertElementType(q, F32);
2000     }
2001     XlaOp result = doit(x, q, x_shape.element_type());
2002     if (needs_upcast) {
2003       result = ConvertElementType(result, x_shape.element_type());
2004     }
2005     return result;
2006   });
2007 }
2008 
2009 }  // namespace xla
2010