1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/math.h"
17
18 #include <cmath>
19
20 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
21 #include "tensorflow/compiler/xla/client/lib/constants.h"
22 #include "tensorflow/compiler/xla/client/lib/loops.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/primitive_util.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27
28 namespace xla {
29 namespace {
30
31 // Evaluate the polynomial given `x` and coefficients in decreasing order.
32 template <typename FP>
EvaluatePolynomial(XlaOp x,absl::Span<const FP> coefficients)33 XlaOp EvaluatePolynomial(XlaOp x, absl::Span<const FP> coefficients) {
34 static_assert(std::is_floating_point<FP>::value,
35 "Template-argument 'FP' must be a floating-point type");
36 XlaOp poly = ScalarLike(x, 0.0);
37 for (FP c : coefficients) {
38 poly = poly * x + ScalarLike(x, c);
39 }
40 return poly;
41 }
42
// Evaluate the chebyshev polynomial given `x` and coefficients in decreasing
// order, via the Clenshaw recurrence.
//
// NOTE(review): the recurrence below uses b0 = x*b1 - b2 + c, whereas the
// textbook Clenshaw recurrence for Chebyshev polynomials of the first kind
// uses 2*x*b1 - b2; presumably callers pre-scale `x` (or the coefficients)
// to absorb the factor of 2 -- confirm against call sites.
template <typename FP>
XlaOp EvaluateChebyshevPolynomial(XlaOp x, absl::Span<const FP> coefficients) {
  static_assert(std::is_floating_point<FP>::value,
                "Template-argument 'FP' must be a floating-point type");
  // b0/b1/b2 hold the last three Clenshaw iterates; the loop order matters.
  XlaOp b0 = ScalarLike(x, 0.0);
  XlaOp b1 = ScalarLike(x, 0.0);
  XlaOp b2 = ScalarLike(x, 0.0);
  for (FP c : coefficients) {
    b2 = b1;
    b1 = b0;
    b0 = x * b1 - b2 + ScalarLike(x, c);
  }
  // Standard Clenshaw combination of the final iterates.
  return ScalarLike(x, 0.5) * (b0 - b2);
}
59
60 } // namespace
61
62 // Returns operation(operand), except if `operand` is one of the types in
63 // upcast_types, in which case first converts it to F32, and then converts the
64 // result down to the original type.
DoWithUpcastToF32(XlaOp operand,absl::Span<const PrimitiveType> upcast_types,const std::function<XlaOp (XlaOp)> & operation)65 static XlaOp DoWithUpcastToF32(XlaOp operand,
66 absl::Span<const PrimitiveType> upcast_types,
67 const std::function<XlaOp(XlaOp)>& operation) {
68 auto& b = *operand.builder();
69 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
70 TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
71 PrimitiveType elem_ty = shape.element_type();
72 bool needs_upcast = absl::c_linear_search(upcast_types, elem_ty);
73
74 if (needs_upcast) {
75 operand = ConvertElementType(operand, F32);
76 }
77 XlaOp result = operation(operand);
78 if (needs_upcast) {
79 result = ConvertElementType(result, elem_ty);
80 }
81 return result;
82 });
83 }
84
85 // TODO(jlebar): Use this function in more places in this file to restrict the
86 // domain of other functions.
EnsureOperandIsRealFp(absl::string_view op_name,XlaOp operand)87 static Status EnsureOperandIsRealFp(absl::string_view op_name, XlaOp operand) {
88 auto& b = *operand.builder();
89 TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
90 auto elem_ty = shape.element_type();
91 if (!primitive_util::IsFloatingPointType(elem_ty)) {
92 return InvalidArgument(
93 "Operands to %s must be real-valued floating-point, but got %s",
94 op_name, PrimitiveType_Name(elem_ty));
95 }
96 return Status::OK();
97 }
98
IsPosInf(XlaOp operand)99 XlaOp IsPosInf(XlaOp operand) {
100 auto& b = *operand.builder();
101 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
102 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsPosInf", operand));
103 TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
104 // Note that this is only correct for floating-point types. If we wanted it
105 // to be correct for all types, we'd need to Gt(MaxFiniteValue).
106 return Eq(operand, MaxValue(&b, shape.element_type()));
107 });
108 }
109
IsNegInf(XlaOp operand)110 XlaOp IsNegInf(XlaOp operand) {
111 auto& b = *operand.builder();
112 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
113 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNegInf", operand));
114 TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
115 // Note that this is only correct for floating-point types. If we wanted it
116 // to be correct for all types, we'd need to Lt(MinFiniteValue).
117 return Eq(operand, MinValue(&b, shape.element_type()));
118 });
119 }
120
IsInf(XlaOp operand)121 XlaOp IsInf(XlaOp operand) {
122 auto& b = *operand.builder();
123 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
124 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsInf", operand));
125 return IsPosInf(Abs(operand));
126 });
127 }
128
IsNan(XlaOp operand)129 XlaOp IsNan(XlaOp operand) {
130 auto& b = *operand.builder();
131 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
132 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNan", operand));
133 return Ne(operand, operand);
134 });
135 }
136
IsNegZero(XlaOp operand)137 XlaOp IsNegZero(XlaOp operand) {
138 auto& b = *operand.builder();
139 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
140 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNegZero", operand));
141 TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
142
143 // The bitwise representation of -0 in bfloat16 and IEEE 754 is 0x80...0
144 // (sign bit on, all other bits off).
145 switch (shape.element_type()) {
146 case F64:
147 return Eq(BitcastConvertType(operand, U64),
148 ConstantR0WithType(&b, U64, uint64{1} << 63));
149 case F32:
150 return Eq(BitcastConvertType(operand, U32),
151 ConstantR0WithType(&b, U32, uint32{1} << 31));
152 case F16:
153 case BF16:
154 // Not all XLA backends handle U16 well, so we convert to F32/U32.
155 // TODO(jlebar): It would be nice if we could stay in (B)F16/U16 for
156 // backends that *do* support it.
157 return Eq(BitcastConvertType(ConvertElementType(operand, F32), U32),
158 ConstantR0WithType(&b, U32, uint32{1} << 31));
159 default:
160 LOG(FATAL) << "Expected real fp type.";
161 }
162 });
163 }
164
Square(XlaOp operand)165 XlaOp Square(XlaOp operand) { return operand * operand; }
166
Reciprocal(XlaOp operand)167 XlaOp Reciprocal(XlaOp operand) { return ScalarLike(operand, 1.0) / operand; }
168
169 // Computes an approximation of the error function complement (1 - erf(x)).
170 //
171 // Precondition: abs(x) >= 1. Otherwise, use ErfImpl.
172 //
173 // This follows Cephes's f32 implementation of erfc.
ErfcImpl32(XlaOp x)174 static XlaOp ErfcImpl32(XlaOp x) {
175 // Coefficients for erfc(f32), from Cephes.
176 const double kMaxlog = 88.72283905206835;
177 // erfc(x) = exp(-x^2) P(1/x^2), 1 < x < 2
178 static const std::array<float, 9> kErfcPCoefficient{
179 +2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1,
180 -5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1,
181 +3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1,
182 };
183 // erfc(x) = exp(-x^2) R(1/x^2), 2 <= x < kMaxlog
184 static const std::array<float, 8> kErfcRCoefficient{
185 -1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0,
186 +2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1,
187 -2.820767439740514E-1, +5.641895067754075E-1,
188 };
189 XlaOp abs_x = Abs(x);
190 XlaOp z = Exp(-x * x);
191 XlaOp q = ScalarLike(x, 1) / abs_x;
192 XlaOp y = q * q;
193 XlaOp p = Select(Lt(abs_x, ScalarLike(x, 2.0)),
194 EvaluatePolynomial<float>(y, kErfcPCoefficient),
195 EvaluatePolynomial<float>(y, kErfcRCoefficient));
196 y = z * q * p;
197 XlaOp y_clamp = Select(Lt(z, ScalarLike(x, -kMaxlog)), ScalarLike(x, 0), y);
198 return Select(Lt(x, ScalarLike(x, 0)), ScalarLike(x, 2.0) - y_clamp, y_clamp);
199 }
200
201 // Compute a polynomial approximation of the error function.
202 //
203 // Precondition: abs(x) <= 1. Otherwise, use ErfcImpl.
204 //
205 // This follows Cephes's f32 implementation of erf.
ErfImpl32Cephes(XlaOp x)206 static XlaOp ErfImpl32Cephes(XlaOp x) {
207 // Coefficients for by erf(f32), from Cephes.
208 //
209 // erf(x) = x P(x^2), 0 < x < 1
210 static const std::array<float, 7> kErfTCoefficient{
211 +7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3,
212 -2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1,
213 +1.128379165726710E+0,
214 };
215 return x * EvaluatePolynomial<float>(x * x, kErfTCoefficient);
216 }
217
// Computes an approximation of the error function complement (1 - erf(x)) for
// f64 inputs, as two rational approximations chosen by the magnitude of x.
//
// Precondition: abs(x) >= 1.  Otherwise, use ErfImpl64.
//
// This follows Cephes's f64 implementation of erfc.
static XlaOp ErfcImpl64(XlaOp x) {
  // Coefficients for erfc(f64), from Cephes.
  //
  // exp() of any f64 argument below -kMaxlog underflows.
  const double kMaxlog = 7.09782712893383996843E2;
  // erfc(x) = exp(-x^2) P(|x|) / Q(|x|), 1 < x < 8
  static const std::array<double, 9> kErfcPCoefficient{
      2.46196981473530512524E-10, 5.64189564831068821977E-1,
      7.46321056442269912687E0,   4.86371970985681366614E1,
      1.96520832956077098242E2,   5.26445194995477358631E2,
      9.34528527171957607540E2,   1.02755188689515710272E3,
      5.57535335369399327526E2};
  static const std::array<double, 9> kErfcQCoefficient{
      1.00000000000000000000E0, 1.32281951154744992508E1,
      8.67072140885989742329E1, 3.54937778887819891062E2,
      9.75708501743205489753E2, 1.82390916687909736289E3,
      2.24633760818710981792E3, 1.65666309194161350182E3,
      5.57535340817727675546E2};

  // erfc(x) = exp(-x^2) R(|x|) / S(|x|), 8 <= x < kMaxlog
  static const std::array<double, 6> kErfcRCoefficient{
      5.64189583547755073984E-1, 1.27536670759978104416E0,
      5.01905042251180477414E0,  6.16021097993053585195E0,
      7.40974269950448939160E0,  2.97886665372100240670E0};
  static const std::array<double, 7> kErfcSCoefficient{
      1.00000000000000000000E0, 2.26052863220117276590E0,
      9.39603524938001434673E0, 1.20489539808096656605E1,
      1.70814450747565897222E1, 9.60896809063285878198E0,
      3.36907645100081516050E0};

  // Here z is the *exponent* -x^2 (unlike ErfcImpl32, where z = exp(-x^2)),
  // which is what the underflow clamp below compares against -kMaxlog.
  XlaOp z = -x * x;
  XlaOp abs_x = Abs(x);
  XlaOp y =
      Select(Lt(abs_x, ScalarLike(x, 8.0)),
             Exp(z) * EvaluatePolynomial<double>(abs_x, kErfcPCoefficient) /
                 EvaluatePolynomial<double>(abs_x, kErfcQCoefficient),
             Exp(z) * EvaluatePolynomial<double>(abs_x, kErfcRCoefficient) /
                 EvaluatePolynomial<double>(abs_x, kErfcSCoefficient));
  // Flush the result to 0 where exp(-x^2) underflows.
  XlaOp y_clamp = Select(Lt(z, ScalarLike(x, -kMaxlog)), ScalarLike(x, 0), y);
  // erfc(-x) = 2 - erfc(x).
  return Select(Lt(x, ScalarLike(x, 0)), ScalarLike(x, 2.0) - y_clamp, y_clamp);
}
257
258 // Compute a polynomial approximation of the error function.
259 //
260 // Precondition: abs(x) <= 1. Otherwise, use ErfcImpl.
ErfImpl64(XlaOp x)261 static XlaOp ErfImpl64(XlaOp x) {
262 // Coefficients for by erf(f64), from Cephes.
263 //
264 // erf(x) = x T(x^2) / U(x^2), 0 < x < 1
265 static std::array<double, 5> kErfTCoefficient{
266 9.60497373987051638749E0, 9.00260197203842689217E1,
267 2.23200534594684319226E3, 7.00332514112805075473E3,
268 5.55923013010394962768E4};
269 static std::array<double, 6> kErfUCoefficient{
270 1.00000000000000000000E0, 3.35617141647503099647E1,
271 5.21357949780152679795E2, 4.59432382970980127987E3,
272 2.26290000613890934246E4, 4.92673942608635921086E4};
273 XlaOp z = x * x;
274 return x * EvaluatePolynomial<double>(z, kErfTCoefficient) /
275 EvaluatePolynomial<double>(z, kErfUCoefficient);
276 }
277
Erfc(XlaOp x)278 XlaOp Erfc(XlaOp x) {
279 auto& b = *x.builder();
280 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
281 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erfc", x));
282 TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x));
283 // erfc(x) =
284 // erfc_impl(x) if x > 1
285 // 1 - erf_impl(x) otherwise
286 if (shape.element_type() == F64) {
287 return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl64(x),
288 ScalarLike(x, 1) - ErfImpl64(x));
289 }
290 // Erf(c)Impl don't have enough precision when run with bf16 intermediates
291 // (not surprising!), so upcast to f32 in this case.
292 return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) {
293 return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl32(x),
294 ScalarLike(x, 1) - ErfImpl32Cephes(x));
295 });
296 });
297 }
298
// Compute a polynomial approximation of the error function.
// This is the same approximation used by Eigen.
//
// erf(x) is approximated as x * P(x^2) / Q(x^2) with x clamped to [-4, 4]
// (erf saturates to +/-1 well inside that range in f32).
static XlaOp ErfImpl32(XlaOp x) {
  // Numerator coefficients (highest degree first).
  static const std::array<float, 7> kAlpha{
      -2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
      -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
      -1.60960333262415e-02f,
  };

  // Denominator coefficients (highest degree first).
  static const std::array<float, 5> kBeta{
      -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
      -7.37332916720468e-03f, -1.42647390514189e-02f,
  };

  x = Clamp(ScalarLike(x, -4.f), x, ScalarLike(x, 4.f));
  auto x2 = x * x;
  return x * EvaluatePolynomial<float>(x2, kAlpha) /
         EvaluatePolynomial<float>(x2, kBeta);
}
318
Erf(XlaOp x)319 XlaOp Erf(XlaOp x) {
320 auto& b = *x.builder();
321 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
322 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erf", x));
323 TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x));
324 // erf(x) =
325 // erf_impl(x) if x < 1
326 // 1 - erfc_impl(x) otherwise
327 if (shape.element_type() == F64) {
328 return Select(Lt(Abs(x), ScalarLike(x, 1)), ErfImpl64(x),
329 ScalarLike(x, 1) - ErfcImpl64(x));
330 }
331 // Erf(c)Impl don't have enough precision when run with bf16 intermediates
332 // (not surprising!), so upcast to f32 in this case.
333 return DoWithUpcastToF32(x, {BF16, F16},
334 [](XlaOp x) { return ErfImpl32(x); });
335 });
336 }
337
338 namespace {
339
// Approximation for the inverse error function from
//   Giles, M., "Approximating the erfinv function".
// The approximation has the form:
//   w = -log((1 - x) * (1 + x))
//   if ( w < 5 ) {
//     w = w - 2.5
//     p = sum_{i=1}^n lq[i]*w^i
//   } else {
//     w = sqrt(w) - 3
//     p = sum_{i=1}^n gq[i]*w^i
//   }
//   return p*x
XlaOp ErfInv32(XlaOp x) {
  constexpr int kDegree = 9;
  // Polynomial coefficients for the w < 5 regime (Giles' "central" branch),
  // highest degree first.
  constexpr std::array<float, 9> w_less_than_5_constants = {
      2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
      -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
      -0.00417768164f, 0.246640727f, 1.50140941f};
  // Polynomial coefficients for the w >= 5 regime (tail branch).
  constexpr std::array<float, 9> w_greater_than_5_constants = {
      -0.000200214257f, 0.000100950558f, 0.00134934322f,
      -0.00367342844f, 0.00573950773f, -0.0076224613f,
      0.00943887047f, 1.00167406f, 2.83297682f};

  // Compute logarithm of (1+arg) using log1p(arg) which is more precise than
  // log(1+arg) when arg is close to zero. For more details, see
  // https://en.cppreference.com/w/cpp/numeric/math/log1p
  auto w = -Log1p(-x * x);

  // Select the per-element coefficient set based on the branch predicate.
  auto lt = Lt(w, ScalarLike(x, 5.0));
  auto coefficient = [&](int i) {
    return Select(lt, FullLike(x, w_less_than_5_constants[i]),
                  FullLike(x, w_greater_than_5_constants[i]));
  };
  // Re-center w per branch, then evaluate the polynomial via Horner's scheme.
  w = Select(lt, w - ScalarLike(x, 2.5), Sqrt(w) - ScalarLike(x, 3.0));
  auto p = coefficient(0);
  for (int i = 1; i < kDegree; ++i) {
    p = coefficient(i) + p * w;
  }

  // Result modulo edge cases.
  XlaOp result = p * x;

  // Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is
  // indeterminate, and can give nan or -/+inf.)
  auto& b = *x.builder();
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, b.GetShape(x));
    return Select(Eq(Abs(x), ScalarLike(x, 1)),
                  x * MaxValue(&b, shape.element_type()), result);
  });
}
391
// High-precision erfinv for f64, using three polynomial branches selected by
// w = -log1p(-x^2): w < 6.25, 6.25 <= w < 16, and w >= 16.  Coefficients are
// highest degree first; the three tables have different lengths, so the
// Horner loop below is staged to only fold in the coefficients that exist for
// each branch.
XlaOp ErfInv64(XlaOp x) {
  constexpr std::array<double, 23> w_less_than_6_25_constants = {
      -3.6444120640178196996e-21, -1.685059138182016589e-19,
      1.2858480715256400167e-18,  1.115787767802518096e-17,
      -1.333171662854620906e-16,  2.0972767875968561637e-17,
      6.6376381343583238325e-15,  -4.0545662729752068639e-14,
      -8.1519341976054721522e-14, 2.6335093153082322977e-12,
      -1.2975133253453532498e-11, -5.4154120542946279317e-11,
      1.051212273321532285e-09,   -4.1126339803469836976e-09,
      -2.9070369957882005086e-08, 4.2347877827932403518e-07,
      -1.3654692000834678645e-06, -1.3882523362786468719e-05,
      0.0001867342080340571352,   -0.00074070253416626697512,
      -0.0060336708714301490533,  0.24015818242558961693,
      1.6536545626831027356};
  constexpr std::array<double, 19> w_less_than_16_constants = {
      2.2137376921775787049e-09,  9.0756561938885390979e-08,
      -2.7517406297064545428e-07, 1.8239629214389227755e-08,
      1.5027403968909827627e-06,  -4.013867526981545969e-06,
      2.9234449089955446044e-06,  1.2475304481671778723e-05,
      -4.7318229009055733981e-05, 6.8284851459573175448e-05,
      2.4031110387097893999e-05,  -0.0003550375203628474796,
      0.00095328937973738049703,  -0.0016882755560235047313,
      0.0024914420961078508066,   -0.0037512085075692412107,
      0.005370914553590063617,    1.0052589676941592334,
      3.0838856104922207635,
  };
  constexpr std::array<double, 17> w_greater_than_16_constants = {
      -2.7109920616438573243e-11, -2.5556418169965252055e-10,
      1.5076572693500548083e-09,  -3.7894654401267369937e-09,
      7.6157012080783393804e-09,  -1.4960026627149240478e-08,
      2.9147953450901080826e-08,  -6.7711997758452339498e-08,
      2.2900482228026654717e-07,  -9.9298272942317002539e-07,
      4.5260625972231537039e-06,  -1.9681778105531670567e-05,
      7.5995277030017761139e-05,  -0.00021503011930044477347,
      -0.00013871931833623122026, 1.0103004648645343977,
      4.8499064014085844221,
  };
  // Compute logarithm of (1+arg) using log1p(arg) which is more precise than
  // log(1+arg) when arg is close to zero. For more details, see
  // https://en.cppreference.com/w/cpp/numeric/math/log1p
  auto w = -Log1p(-x * x);

  auto lt_6_25 = Lt(w, ScalarLike(x, 6.25));
  auto lt_16 = Lt(w, ScalarLike(x, 16));
  // Per-element coefficient lookup: starts from the w < 6.25 table and is
  // overridden by the other tables where those branches apply and still have
  // a coefficient at index i.
  auto coefficient = [&](int i) {
    auto c = FullLike(x, w_less_than_6_25_constants[i]);
    if (i < 19) {
      c = Select(lt_6_25, c, FullLike(x, w_less_than_16_constants[i]));
    }
    if (i < 17) {
      c = Select(lt_16, c, FullLike(x, w_greater_than_16_constants[i]));
    }
    return c;
  };
  // Re-center w per branch (3.125 / 3.25 / 5.0 offsets from Giles' paper).
  auto sqrt_w = Sqrt(w);
  w = Select(lt_6_25, w - ScalarLike(x, 3.125),
             sqrt_w - Select(lt_16, ScalarLike(x, 3.25), ScalarLike(x, 5.0)));
  // Staged Horner evaluation: all branches share the first 17 steps, the two
  // lower branches the next 2, and only the w < 6.25 branch the final 4.
  auto p = coefficient(0);
  for (int i = 1; i < 17; ++i) {
    p = coefficient(i) + p * w;
  }
  for (int i = 17; i < 19; ++i) {
    p = Select(lt_16, coefficient(i) + p * w, p);
  }
  for (int i = 19; i < 23; ++i) {
    p = Select(lt_6_25, coefficient(i) + p * w, p);
  }
  // Result modulo edge cases.
  XlaOp result = p * x;

  // Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is
  // indeterminate, and can give nan or -/+inf.)
  auto& b = *x.builder();
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, b.GetShape(x));
    return Select(Eq(Abs(x), ScalarLike(x, 1)),
                  x * MaxValue(&b, shape.element_type()), result);
  });
}
471
472 } // namespace
473
ErfInv(XlaOp x)474 XlaOp ErfInv(XlaOp x) {
475 auto& b = *x.builder();
476 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
477 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("ErfInv", x));
478 TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x));
479 if (shape.element_type() == F64) {
480 return ErfInv64(x);
481 }
482 return DoWithUpcastToF32(x, {BF16, F16},
483 [](XlaOp x) { return ErfInv32(x); });
484 });
485 }
486
487 namespace {
// Coefficients for the Lanczos approximation of the gamma function. The
// coefficients are uniquely determined by the choice of g and n (kLanczosGamma
// and kLanczosCoefficients.size() + 1). The coefficients below correspond to
// [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and [7,
// 9] seemed to be the least sensitive to the quality of the log function. In
// particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
// for a particularly inaccurate log function.
//
// These are shared by Lgamma and Digamma below.
static constexpr double kLanczosGamma = 7;  // aka g
static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
static constexpr std::array<double, 8> kLanczosCoefficients = {
    676.520368121885098567009190444019, -1259.13921672240287047156078755283,
    771.3234287776530788486528258894, -176.61502916214059906584551354,
    12.507343278686904814458936853, -0.13857109526572011689554707,
    9.984369578019570859563e-6, 1.50563273514931155834e-7};
502 } // namespace
503
// Compute the Lgamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
// lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + A(z)
// t(z) = z + kLanczosGamma + 1/2
// A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k))
XlaOp Lgamma(XlaOp input) {
  // The actual computation; run under DoWithUpcastToF32 below so that (B)F16
  // inputs get F32 intermediates.
  auto do_it = [](XlaOp input) {
    XlaOp one_half = ScalarLike(input, 0.5);
    XlaOp one = ScalarLike(input, 1);

    XlaOp pi = ScalarLike(input, M_PI);
    XlaOp log_pi = ScalarLike(input, std::log(M_PI));
    XlaOp log_sqrt_two_pi =
        ScalarLike(input, (std::log(2) + std::log(M_PI)) / 2);

    XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5);
    XlaOp log_lanczos_gamma_plus_one_half =
        ScalarLike(input, std::log(kLanczosGamma + 0.5));

    XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff);

    // If the input is less than 0.5 use Euler's reflection formula:
    // gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
    XlaOp need_to_reflect = Lt(input, one_half);
    XlaOp z = Select(need_to_reflect, -input, input - one);

    // A(z): the Lanczos partial-fraction sum.
    XlaOp x = base_lanczos_coeff;
    for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
      XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]);
      XlaOp index = ScalarLike(input, i);
      x = x + lanczos_coefficient / (z + index + one);
    }

    // To improve accuracy on platforms with less-precise log implementations,
    // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on
    // the device.
    // log(t) = log(kLanczosGamma + 0.5 + z)
    //        = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5))
    XlaOp t = lanczos_gamma_plus_one_half + z;
    XlaOp log_t = log_lanczos_gamma_plus_one_half +
                  Log1p(z / lanczos_gamma_plus_one_half);

    // Compute the final result (modulo reflection).  t(z) may be large, and we
    // need to be careful not to overflow to infinity in the first term of
    //
    //   (z + 1/2) * log(t(z)) - t(z).
    //
    // Therefore we compute this as
    //
    //   (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
    //
    XlaOp log_y = log_sqrt_two_pi + (z + one_half - t / log_t) * log_t + Log(x);

    // Compute the reflected value, used when x < 0.5:
    //
    //   lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
    //
    // (The abs is because lgamma is the log of the absolute value of the gamma
    // function.)
    //
    // We have to be careful when computing the final term above. gamma(x) goes
    // to +/-inf at every integer x < 0, and this is controlled by the
    // sin(pi * x) term.  The slope is large, so precision is particularly
    // important.
    //
    // Because abs(sin(pi * x)) has period 1, we can equivalently use
    // abs(sin(pi * frac(x))), where frac(x) is the fractional part of x.  This
    // is more numerically accurate: It doesn't overflow to inf like pi * x can,
    // and if x is an integer, it evaluates to 0 exactly, which is significant
    // because we then take the log of this value, and log(0) is inf.
    //
    // We don't have a frac(x) primitive in XLA and computing it is tricky, but
    // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for
    // our purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
    //
    // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
    // to 1.  To remedy this, we can use the fact that sin(pi * x) in the domain
    // [0, 1] is symmetric across the line Y=0.5.
    //
    XlaOp abs_input = Abs(input);
    XlaOp abs_frac_input = abs_input - Floor(abs_input);
    // Convert values of abs_frac_input > 0.5 to (1 - frac_input) to improve
    // precision of pi * abs_frac_input for values of abs_frac_input close to 1.
    XlaOp reduced_frac_input =
        Select(Gt(abs_frac_input, ScalarLike(abs_frac_input, 0.5)),
               ScalarLike(abs_frac_input, 1) - abs_frac_input, abs_frac_input);
    XlaOp reflection_denom = Log(Sin(pi * reduced_frac_input));

    // Avoid computing -inf - inf, which is nan.  If reflection_denom is +/-inf,
    // then it "wins" and the result is +/-inf.
    XlaOp reflection =
        Select(IsFinite(reflection_denom), log_pi - reflection_denom - log_y,
               -reflection_denom);
    XlaOp result = Select(need_to_reflect, reflection, log_y);

    // lgamma(+/-inf) = +inf.
    // (float infinity converts exactly to every wider FP type, so this is
    // correct for F32 and F64 alike.)
    XlaOp inf_bcast = FullLike(input, std::numeric_limits<float>::infinity());
    return Select(IsInf(input), inf_bcast, result);
  };

  auto& b = *input.builder();
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Lgamma", input));
    // F16 and BF16 don't provide sufficient precision for intermediate results
    // here (although it's better than you might expect!), so do the
    // computations in F32.
    return DoWithUpcastToF32(input, {BF16, F16}, do_it);
  });
}
614
615 // Computes an approximation of the lbeta function which is equivalent to
616 // log(abs(Beta(a, b))) but avoids overflow by computing it with lgamma.
Lbeta(XlaOp a,XlaOp b)617 static XlaOp Lbeta(XlaOp a, XlaOp b) {
618 // Beta(a, b) can be computed using Gamma as per
619 // http://dlmf.nist.gov/5.12.E1 as follows:
620 // Beta(a, b) = (Gamma(a) * Gamma(b)) / Gamma(a + b)
621 //
622 // To avoid overflow, we compute in the log domain.
623 //
624 // As per http://dlmf.nist.gov/4.8.E2 we can transform:
625 // Log(a * b)
626 // into:
627 // Log(a) + Log(b)
628 //
629 // Likewise, per https://dlmf.nist.gov/4.8.E4, we can turn:
630 // Log(a - b)
631 // into:
632 // Log(a) - Log(b)
633 //
634 // This means that we can compute Log(Beta(a, b)) by:
635 // Log(Gamma(a)) + Log(Gamma(b)) - Log(Gamma(a + b))
636 return Lgamma(a) + Lgamma(b) - Lgamma(a + b);
637 }
638
// Compute the Digamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
// digamma(z + 1) = log(t(z)) + A'(z) / A(z) - kLanczosGamma / t(z)
// t(z) = z + kLanczosGamma + 1/2
// A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k))
// A'(z) = sigma(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
XlaOp Digamma(XlaOp input) {
  // The actual computation; run under DoWithUpcastToF32 below so that (B)F16
  // inputs get F32 intermediates.
  auto do_it = [](XlaOp input) {
    XlaOp zero = ScalarLike(input, 0);
    XlaOp one_half = ScalarLike(input, 0.5);
    XlaOp one = ScalarLike(input, 1);

    XlaOp pi = ScalarLike(input, M_PI);

    XlaOp lanczos_gamma = ScalarLike(input, kLanczosGamma);
    XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5);
    XlaOp log_lanczos_gamma_plus_one_half =
        ScalarLike(input, std::log(kLanczosGamma + 0.5));

    XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff);

    // If the input is less than 0.5 use Euler's reflection formula:
    // digamma(x) = digamma(1 - x) - pi * cot(pi * x)
    XlaOp need_to_reflect = Lt(input, one_half);
    XlaOp z = Select(need_to_reflect, -input, input - one);

    // num accumulates A'(z), denom accumulates A(z).
    XlaOp num = zero;
    XlaOp denom = base_lanczos_coeff;
    for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
      XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]);
      XlaOp index = ScalarLike(input, i);
      num = num - lanczos_coefficient / ((z + index + one) * (z + index + one));
      denom = denom + lanczos_coefficient / (z + index + one);
    }

    // To improve accuracy on platforms with less-precise log implementations,
    // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on
    // the device.
    // log(t) = log(kLanczosGamma + 0.5 + z)
    //        = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5))
    XlaOp t = lanczos_gamma_plus_one_half + z;
    XlaOp log_t = log_lanczos_gamma_plus_one_half +
                  Log1p(z / lanczos_gamma_plus_one_half);

    // digamma(z + 1) per the formula in the header comment.
    XlaOp y = log_t + num / denom - lanczos_gamma / t;

    // We need to be careful how we compute cot(pi * input) below: For
    // near-integral values of `input`, pi * input can lose precision.
    //
    // Input is already known to be less than 0.5 (otherwise we don't have to
    // reflect).  We shift values smaller than -0.5 into the range [-.5, .5] to
    // increase precision of pi * input and the resulting cotangent.
    XlaOp reduced_input = input + Abs(Floor(input + ScalarLike(input, 0.5)));
    // cot(pi * x) written as cos / sin.
    XlaOp reflection =
        y - pi * Cos(pi * reduced_input) / Sin(pi * reduced_input);
    XlaOp real_result = Select(need_to_reflect, reflection, y);

    // Digamma has poles at negative integers and zero; return nan for those.
    return Select(And(Le(input, zero), Eq(input, Floor(input))),
                  FullLike(input, std::numeric_limits<float>::quiet_NaN()),
                  real_result);
  };

  auto& b = *input.builder();
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Digamma", input));
    return DoWithUpcastToF32(input, {BF16, F16}, do_it);
  });
}
709
710 // Incomplete gamma functions
711
712 namespace {
713
714 enum kIgammaMode { VALUE, DERIVATIVE, SAMPLE_DERIVATIVE };
715
// Helper function for computing Igamma using a power series.  Depending on
// `mode`, returns the series value, its derivative with respect to `a`, or
// the gradient used for implicit-reparameterization sampling.
template <kIgammaMode mode>
XlaOp IgammaSeries(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled,
                   xla::PrimitiveType type) {
  // vals: (enabled, r, c, ans, x, dc_da, dans_da)
  // 'enabled' is a predication mask that says for which elements we should
  // execute the loop body. Disabled elements have no effect in the loop body.
  // TODO(phawkins): in general this isn't an optimal implementation on any
  // backend. For example, on GPU, we should probably vectorize to the warp
  // size, and then run independent loops for each warp's worth of
  // data.
  //
  // Loop while any element's series hasn't yet converged.
  auto cond = [&](absl::Span<const XlaOp> vals,
                  XlaBuilder* builder) -> StatusOr<XlaOp> {
    XlaOp enabled = vals[0];
    return Any(enabled);
  };
  // One series step: advance r, accumulate the next term c into ans, and
  // track the derivatives dc/da and dans/da alongside.
  auto body = [&](absl::Span<const XlaOp> vals,
                  XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
    XlaOp enabled = vals[0];
    XlaOp r = vals[1];
    XlaOp c = vals[2];
    XlaOp ans = vals[3];
    XlaOp x = vals[4];
    XlaOp dc_da = vals[5];
    XlaOp dans_da = vals[6];

    r = r + ScalarLike(r, 1);
    // d/da of the term update c = c * x / r (r depends on a via its seed).
    dc_da = dc_da * (x / r) + (ScalarLike(r, -1) * c * x) / (r * r);
    dans_da = dans_da + dc_da;
    c = c * (x / r);
    ans = ans + c;
    // Convergence test: stop once the latest term (or derivative term) is
    // below machine epsilon relative to the accumulated sum.
    XlaOp conditional;
    if (mode == VALUE) {
      conditional = And(enabled, Gt(c / ans, Epsilon(builder, type)));
    } else {
      conditional =
          And(enabled, Gt(Abs(dc_da / dans_da), Epsilon(builder, type)));
    }

    // Only update state for still-enabled elements; the rest keep their
    // previous values.
    return std::vector<XlaOp>{
        conditional,
        Select(enabled, r, vals[1]),
        Select(enabled, c, vals[2]),
        Select(enabled, ans, vals[3]),
        Select(enabled, x, vals[4]),
        Select(enabled, dc_da, vals[5]),
        Select(enabled, dans_da, vals[6]),
    };
  };
  auto& b = *ax.builder();
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Initial state: r = a, c = ans = 1, derivatives zeroed.
    std::vector<XlaOp> vals = {
        enabled, a, FullLike(a, 1), FullLike(a, 1), x, FullLike(a, 0),
        FullLike(a, 0),
    };

    TF_ASSIGN_OR_RETURN(vals, WhileLoopHelper(cond, body, vals, "igamma", &b));
    XlaOp ans = vals[3];
    XlaOp dans_da = vals[6];
    if (mode == VALUE) {
      return (ans * ax) / a;
    }

    // d/da of log(ax) = log(x^a e^-x / Gamma(a+1)).
    XlaOp dlogax_da = Log(x) - Digamma(a + ScalarLike(a, 1));

    switch (mode) {
      case DERIVATIVE:
        return ax * (ans * dlogax_da + dans_da) / a;
      case SAMPLE_DERIVATIVE:
      default:
        return -(dans_da + ans * dlogax_da) * x / a;
    }
  });
}
790
// Helper function for computing Igammac using a continued fraction.
//
// Evaluates the continued fraction by tracking the numerators and
// denominators (pk, qk) of successive convergents in a predicated while
// loop, rescaling when they grow too large. Depending on `mode`, also
// propagates the derivative of the recurrence with respect to `a`. `ax` is
// the prefactor computed by the callers; `enabled` masks elements that
// should iterate; `type` selects the epsilon used for convergence and
// rescaling thresholds.
template <kIgammaMode mode>
XlaOp IgammacContinuedFraction(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled,
                               xla::PrimitiveType type) {
  // vals: enabled, ans, t, y, z, c, pkm1, qkm1, pkm2, qkm2
  // Iterate while any element is unconverged, capped at 2000 iterations
  // (`c` doubles as the iteration counter).
  auto cond = [&](absl::Span<const XlaOp> vals,
                  XlaBuilder* builder) -> StatusOr<XlaOp> {
    XlaOp enabled = vals[0];
    XlaOp c = vals[5];
    return And(Lt(c, ScalarLike(c, 2000)), Any(enabled));
  };
  auto body = [&](absl::Span<const XlaOp> vals,
                  XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
    XlaOp enabled = vals[0];
    XlaOp ans = vals[1];   // current convergent pk/qk
    XlaOp t = vals[2];     // relative change of ans in this step
    XlaOp y = vals[3];
    XlaOp z = vals[4];
    XlaOp c = vals[5];     // iteration counter
    XlaOp pkm1 = vals[6];  // previous two numerator/denominator pairs
    XlaOp qkm1 = vals[7];
    XlaOp pkm2 = vals[8];
    XlaOp qkm2 = vals[9];

    // Derivatives of the recurrence state with respect to `a`.
    XlaOp dpkm2_da = vals[10];
    XlaOp dqkm2_da = vals[11];
    XlaOp dpkm1_da = vals[12];
    XlaOp dqkm1_da = vals[13];
    XlaOp dans_da = vals[14];

    c = c + ScalarLike(c, 1);
    y = y + ScalarLike(y, 1);
    z = z + ScalarLike(z, 2);
    XlaOp yc = y * c;
    // Two-term recurrence for the convergent's numerator and denominator.
    XlaOp pk = pkm1 * z - pkm2 * yc;
    XlaOp qk = qkm1 * z - qkm2 * yc;
    XlaOp qk_is_nonzero = Ne(qk, ScalarLike(qk, 0));
    XlaOp r = pk / qk;

    // Only accept the new convergent when the denominator is nonzero;
    // otherwise force t = 1 so the element keeps iterating.
    t = Select(qk_is_nonzero, Abs((ans - r) / r), FullLike(t, 1));
    ans = Select(qk_is_nonzero, r, ans);

    // Differentiate the recurrences w.r.t. `a`; note y = 1 - a so dy/da = -1
    // and d(yc)/da = -c, which produces the -pkm1 and +pkm2*c terms.
    XlaOp dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c;
    XlaOp dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c;
    XlaOp dans_da_new =
        Select(qk_is_nonzero, (dpk_da - ans * dqk_da) / qk, dans_da);
    XlaOp grad_conditional =
        Select(qk_is_nonzero, Abs(dans_da_new - dans_da), FullLike(dans_da, 1));

    pkm2 = pkm1;
    pkm1 = pk;
    qkm2 = qkm1;
    qkm1 = qk;

    dpkm2_da = dpkm1_da;
    dqkm2_da = dqkm1_da;
    dpkm1_da = dpk_da;
    dqkm1_da = dqk_da;

    // Rescale the whole state when |pk| approaches 1/eps so the recurrence
    // doesn't overflow; the ratio pk/qk (and its derivative) is unchanged.
    XlaOp rescale = Gt(Abs(pk), Reciprocal(Epsilon(builder, type)));
    pkm2 = Select(rescale, pkm2 * Epsilon(builder, type), pkm2);
    pkm1 = Select(rescale, pkm1 * Epsilon(builder, type), pkm1);
    qkm2 = Select(rescale, qkm2 * Epsilon(builder, type), qkm2);
    qkm1 = Select(rescale, qkm1 * Epsilon(builder, type), qkm1);

    dpkm2_da = Select(rescale, dpkm2_da * Epsilon(builder, type), dpkm2_da);
    dqkm2_da = Select(rescale, dqkm2_da * Epsilon(builder, type), dqkm2_da);
    dpkm1_da = Select(rescale, dpkm1_da * Epsilon(builder, type), dpkm1_da);
    dqkm1_da = Select(rescale, dqkm1_da * Epsilon(builder, type), dqkm1_da);

    // Converged when the relative change (of the value or the derivative,
    // depending on mode) falls below machine epsilon for `type`.
    XlaOp conditional;
    if (mode == VALUE) {
      conditional = And(enabled, Gt(t, Epsilon(builder, type)));
    } else {
      conditional = And(enabled, Gt(grad_conditional, Epsilon(builder, type)));
    }

    // Predicated update; the counter `c` advances unconditionally since it
    // bounds the loop.
    return std::vector<XlaOp>{conditional,
                              Select(enabled, ans, vals[1]),
                              Select(enabled, t, vals[2]),
                              Select(enabled, y, vals[3]),
                              Select(enabled, z, vals[4]),
                              c,
                              Select(enabled, pkm1, vals[6]),
                              Select(enabled, qkm1, vals[7]),
                              Select(enabled, pkm2, vals[8]),
                              Select(enabled, qkm2, vals[9]),
                              Select(enabled, dpkm2_da, vals[10]),
                              Select(enabled, dqkm2_da, vals[11]),
                              Select(enabled, dpkm1_da, vals[12]),
                              Select(enabled, dqkm1_da, vals[13]),
                              Select(enabled, dans_da_new, vals[14])};
  };

  auto& b = *ax.builder();
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Initial convergents and their derivatives with respect to `a`.
    XlaOp y = ScalarLike(a, 1) - a;
    XlaOp z = x + y + ScalarLike(x, 1);
    XlaOp c = ScalarLike(x, 0);
    XlaOp pkm2 = FullLike(x, 1);
    XlaOp qkm2 = x;
    XlaOp pkm1 = x + ScalarLike(x, 1);
    XlaOp qkm1 = z * x;
    XlaOp ans = pkm1 / qkm1;
    XlaOp t = FullLike(x, 1);
    XlaOp dpkm2_da = FullLike(x, 0);
    XlaOp dqkm2_da = FullLike(x, 0);
    XlaOp dpkm1_da = FullLike(x, 0);
    XlaOp dqkm1_da = -x;
    XlaOp dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1;
    std::vector<XlaOp> vals = {enabled, ans, t, y, z,
                               c, pkm1, qkm1, pkm2, qkm2,
                               dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da};

    TF_ASSIGN_OR_RETURN(vals, WhileLoopHelper(cond, body, vals, "igammac", &b));
    ans = vals[1];
    if (mode == VALUE) {
      return ans * ax;
    }

    dans_da = vals[14];
    // d(log ax)/da = log(x) - digamma(a).
    XlaOp dlogax_da = Log(x) - Digamma(a);

    switch (mode) {
      case DERIVATIVE:
        // Product/chain rule applied to ans * ax.
        return ax * (ans * dlogax_da + dans_da);
      case SAMPLE_DERIVATIVE:
      default:
        // Gradient of a gamma sample with respect to `a`.
        return -(dans_da + ans * dlogax_da) * x;
    }
  });
}
923
924 } // namespace
925
// Computes the regularized lower incomplete gamma function P(a, x).
//
// Dispatches elementwise between the power series (small x) and
// 1 - continued_fraction (large x). F16/BF16 inputs are computed in F32 and
// cast back. Returns NaN for NaN inputs or domain errors (x < 0 || a <= 0),
// 0 for x == 0, and 1 for x == +inf.
XlaOp Igamma(XlaOp a, XlaOp x) {
  auto& b = *a.builder();
  auto doit = [&b](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp {
    XlaOp is_nan = Or(IsNan(a), IsNan(x));
    XlaOp x_is_zero = Eq(x, ScalarLike(x, 0));
    XlaOp x_is_infinity =
        Eq(x, ScalarLike(x, std::numeric_limits<float>::infinity()));
    XlaOp domain_error = Or(Lt(x, ScalarLike(x, 0)), Le(a, ScalarLike(a, 0)));
    // The continued fraction is used for x > 1 && x > a; the series elsewhere.
    XlaOp use_igammac = And(Gt(x, ScalarLike(x, 1)), Gt(x, a));
    // ax = log of the shared prefactor x^a * e^-x / Gamma(a).
    XlaOp ax = a * Log(x) - x - Lgamma(a);
    // If the prefactor would underflow to zero, disable those elements; the
    // Selects below supply their results.
    XlaOp underflow = Lt(ax, -Log(MaxFiniteValue(&b, type)));
    ax = Exp(ax);
    XlaOp enabled = Not(Or(Or(Or(x_is_zero, domain_error), underflow), is_nan));
    const double nan = std::numeric_limits<double>::quiet_NaN();
    XlaOp output = Select(
        use_igammac,
        ScalarLike(a, 1) - IgammacContinuedFraction<VALUE>(
                               ax, x, a, And(enabled, use_igammac), type),
        IgammaSeries<VALUE>(ax, x, a, And(enabled, Not(use_igammac)), type));
    // Patch up special cases; order matters: NaN/domain errors win last.
    output = Select(x_is_zero, ZerosLike(output), output);
    output = Select(x_is_infinity, FullLike(output, 1), output);
    output = Select(Or(domain_error, is_nan), FullLike(a, nan), output);
    return output;
  };
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a));
    TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x));
    if (a_shape != x_shape) {
      return InvalidArgument(
          "Arguments to Igamma must have equal shapes and types; got %s and %s",
          a_shape.ToString(), x_shape.ToString());
    }
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a));
    // `a_x_type` tracks the type actually used for the computation so that
    // Epsilon/MaxFiniteValue inside `doit` match the (possibly upcast)
    // operands.
    PrimitiveType a_x_type = a_shape.element_type();
    bool needs_upcast =
        a_shape.element_type() == F16 || a_shape.element_type() == BF16;

    if (needs_upcast) {
      a = ConvertElementType(a, F32);
      x = ConvertElementType(x, F32);
      a_x_type = F32;
    }
    XlaOp result = doit(a, x, a_x_type);
    if (needs_upcast) {
      result = ConvertElementType(result, a_shape.element_type());
    }
    return result;
  });
}
975
IgammaGradA(XlaOp a,XlaOp x)976 XlaOp IgammaGradA(XlaOp a, XlaOp x) {
977 auto& b = *a.builder();
978 auto doit = [&b](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp {
979 XlaOp is_nan = Or(IsNan(a), IsNan(x));
980 XlaOp x_is_zero = Eq(x, ScalarLike(x, 0));
981 XlaOp domain_error = Or(Lt(x, ScalarLike(x, 0)), Le(a, ScalarLike(a, 0)));
982 XlaOp use_igammac = And(Gt(x, ScalarLike(x, 1)), Gt(x, a));
983 XlaOp ax = a * Log(x) - x - Lgamma(a);
984 XlaOp underflow = Lt(ax, -Log(MaxFiniteValue(&b, type)));
985 ax = Exp(ax);
986 XlaOp enabled = Not(Or(Or(Or(x_is_zero, domain_error), underflow), is_nan));
987 const double nan = std::numeric_limits<double>::quiet_NaN();
988 XlaOp output = Select(use_igammac,
989 -IgammacContinuedFraction<DERIVATIVE>(
990 ax, x, a, And(enabled, use_igammac), type),
991 IgammaSeries<DERIVATIVE>(
992 ax, x, a, And(enabled, Not(use_igammac)), type));
993 output = Select(x_is_zero, ZerosLike(output), output);
994 output = Select(Or(domain_error, is_nan), FullLike(a, nan), output);
995 return output;
996 };
997 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
998 TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a));
999 TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x));
1000 if (a_shape != x_shape) {
1001 return InvalidArgument(
1002 "Arguments to IgammaGradA must have equal shapes and types; got %s "
1003 "and %s",
1004 a_shape.ToString(), x_shape.ToString());
1005 }
1006 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a));
1007 bool needs_upcast =
1008 a_shape.element_type() == F16 || a_shape.element_type() == BF16;
1009
1010 if (needs_upcast) {
1011 a = ConvertElementType(a, F32);
1012 x = ConvertElementType(x, F32);
1013 }
1014 XlaOp result = doit(a, x, a_shape.element_type());
1015 if (needs_upcast) {
1016 result = ConvertElementType(result, a_shape.element_type());
1017 }
1018 return result;
1019 });
1020 }
1021
1022 // Gradient of Gamma sample from Gamma(a, 1) with respect to `a`.
RandomGammaGrad(XlaOp a,XlaOp x)1023 XlaOp RandomGammaGrad(XlaOp a, XlaOp x) {
1024 auto& b = *a.builder();
1025 auto doit = [&b](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp {
1026 XlaOp is_nan = Or(IsNan(a), IsNan(x));
1027 XlaOp x_is_zero = Eq(x, ScalarLike(x, 0));
1028 XlaOp domain_error = Or(Lt(x, ScalarLike(x, 0)), Le(a, ScalarLike(a, 0)));
1029 XlaOp use_igammac = And(Gt(x, ScalarLike(x, 1)), Gt(x, a));
1030 XlaOp ax = a * Log(x) - x - Lgamma(a);
1031 XlaOp underflow = Lt(ax, -Log(MaxFiniteValue(&b, type)));
1032 ax = Exp(ax);
1033 XlaOp enabled = Not(Or(Or(Or(x_is_zero, domain_error), underflow), is_nan));
1034 const double nan = std::numeric_limits<double>::quiet_NaN();
1035 XlaOp output = Select(use_igammac,
1036 -IgammacContinuedFraction<SAMPLE_DERIVATIVE>(
1037 ax, x, a, And(enabled, use_igammac), type),
1038 IgammaSeries<SAMPLE_DERIVATIVE>(
1039 ax, x, a, And(enabled, Not(use_igammac)), type));
1040 output = Select(x_is_zero, ZerosLike(output), output);
1041 output = Select(Or(domain_error, is_nan), FullLike(a, nan), output);
1042 return output;
1043 };
1044 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1045 TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a));
1046 TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x));
1047 if (a_shape != x_shape) {
1048 return InvalidArgument(
1049 "Arguments to RandomGammaGrad must have equal shapes and types; got "
1050 "%s and %s",
1051 a_shape.ToString(), x_shape.ToString());
1052 }
1053 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("RandomGammaGrad", a));
1054 bool needs_upcast =
1055 a_shape.element_type() == F16 || a_shape.element_type() == BF16;
1056
1057 if (needs_upcast) {
1058 a = ConvertElementType(a, F32);
1059 x = ConvertElementType(x, F32);
1060 }
1061 XlaOp result = doit(a, x, a_shape.element_type());
1062 if (needs_upcast) {
1063 result = ConvertElementType(result, a_shape.element_type());
1064 }
1065 return result;
1066 });
1067 }
1068
// Computes the regularized upper incomplete gamma function Q(a, x).
//
// The dispatch mirrors Igamma: 1 - series for small x, the continued
// fraction otherwise. Out-of-range inputs (x <= 0 || a <= 0) yield 1, and
// x == +inf yields 0. F16/BF16 inputs are computed in F32 and cast back.
XlaOp Igammac(XlaOp a, XlaOp x) {
  auto& b = *a.builder();
  auto doit = [&b](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp {
    XlaOp out_of_range = Or(Le(x, ScalarLike(x, 0)), Le(a, ScalarLike(a, 0)));
    XlaOp use_igamma = Or(Lt(x, ScalarLike(x, 1)), Lt(x, a));
    // Log of the shared prefactor x^a * e^-x / Gamma(a); disable elements
    // where it would underflow to zero.
    XlaOp ax = a * Log(x) - x - Lgamma(a);
    XlaOp underflow = Lt(ax, -Log(MaxFiniteValue(&b, type)));
    XlaOp enabled = Not(Or(out_of_range, underflow));
    ax = Exp(ax);
    XlaOp result =
        Select(use_igamma,
               ScalarLike(a, 1) - IgammaSeries<VALUE>(
                                      ax, x, a, And(enabled, use_igamma), type),
               IgammacContinuedFraction<VALUE>(
                   ax, x, a, And(enabled, Not(use_igamma)), type));
    XlaOp x_is_infinity =
        Eq(x, ScalarLike(x, std::numeric_limits<float>::infinity()));
    // Patch up special cases; out-of-range takes priority last.
    result = Select(x_is_infinity, ZerosLike(result), result);
    return Select(out_of_range, FullLike(a, 1), result);
  };
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a));
    TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x));
    if (a_shape != x_shape) {
      return InvalidArgument(
          "Arguments to Igammac must have equal shapes and types; "
          "got %s and %s",
          a_shape.ToString(), x_shape.ToString());
    }
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igammac", a));
    // `a_x_type` tracks the type actually used for the computation so that
    // Epsilon/MaxFiniteValue inside `doit` match the (possibly upcast)
    // operands.
    PrimitiveType a_x_type = a_shape.element_type();
    bool needs_upcast =
        a_shape.element_type() == F16 || a_shape.element_type() == BF16;

    if (needs_upcast) {
      a = ConvertElementType(a, F32);
      x = ConvertElementType(x, F32);
      a_x_type = F32;
    }
    XlaOp result = doit(a, x, a_x_type);
    if (needs_upcast) {
      result = ConvertElementType(result, a_shape.element_type());
    }
    return result;
  });
}
1115
1116 // Implements Banker's rounding: numbers that are equidistant between two
1117 // integers are rounded towards even.
RoundToEven(XlaOp x)1118 XlaOp RoundToEven(XlaOp x) {
1119 auto& b = *x.builder();
1120 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1121 // Reject non-real non-fp inputs (What does it even mean to round a complex
1122 // number? Do you round each component equally? In that case, you should
1123 // just ask for that explicitly.)
1124 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("RoundToEven", x));
1125
1126 auto half = ScalarLike(x, 0.5);
1127 auto one = ScalarLike(x, 1.0);
1128 auto two = ScalarLike(x, 2.0);
1129
1130 auto round_val = Floor(x);
1131 auto fraction = x - round_val;
1132 auto nearest_even_int = round_val - two * Floor(half * x);
1133 auto is_odd = Eq(nearest_even_int, one);
1134 return Select(Or(Gt(fraction, half), And(Eq(fraction, half), is_odd)),
1135 round_val + one, round_val);
1136 });
1137 }
1138
1139 // Trigonometric functions.
1140
1141 // acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1
1142 // pi if x == -1
1143 // For complex:
1144 // acos(x) = -(i * log(x + i * sqrt((1 + x) * (1 - x))))
Acos(XlaOp x)1145 XlaOp Acos(XlaOp x) {
1146 XlaBuilder* b = x.builder();
1147 return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1148 TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
1149
1150 if (primitive_util::IsComplexType(shape.element_type())) {
1151 auto one = ScalarLike(x, 1);
1152 auto imag_one = Complex(
1153 Zero(b, primitive_util::ComplexComponentType(shape.element_type())),
1154 One(b, primitive_util::ComplexComponentType(shape.element_type())));
1155
1156 auto result =
1157 Neg(imag_one * Log(x + imag_one * Sqrt((one + x) * (one - x))));
1158 return result;
1159 }
1160 return Select(Ne(x, FullLike(x, -1)),
1161 ScalarLike(x, 2.0) * Atan2(Sqrt(ScalarLike(x, 1.0) - x * x),
1162 ScalarLike(x, 1.0) + x),
1163 FullLike(x, M_PI));
1164 });
1165 }
1166
1167 // asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
Asin(XlaOp x)1168 XlaOp Asin(XlaOp x) {
1169 return ScalarLike(x, 2.0) *
1170 Atan2(x, ScalarLike(x, 1.0) + Sqrt(ScalarLike(x, 1.0) - x * x));
1171 }
1172
Atan(XlaOp x)1173 XlaOp Atan(XlaOp x) { return Atan2(x, ScalarLike(x, 1.0)); }
1174
Tan(XlaOp x)1175 XlaOp Tan(XlaOp x) {
1176 return DoWithUpcastToF32(x, {F16}, [](XlaOp x) { return Sin(x) / Cos(x); });
1177 }
1178
1179 // Hyperbolic trigonometric functions.
1180
1181 // acosh(x) = log(x + sqrt(x^2 - 1)) if x >= -1
1182 // = log(x + sqrt((x+1)*(x-1)))
1183 // acosh(x) = nan if x < -1
1184 //
1185 // If x^2 will overflow, we approximate sqrt(x^2 - 1) == x and compute as
1186 // log(2*x) = log(2) + log(x). (Note this works because negative x never
1187 // overflows; x < -1 simply yields nan. This is quite different than asinh!)
Acosh(XlaOp x)1188 XlaOp Acosh(XlaOp x) {
1189 XlaBuilder* b = x.builder();
1190 return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1191 TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
1192
1193 auto one = ScalarLike(x, 1);
1194 auto neg_one = ScalarLike(x, -1);
1195 auto nan = FullLike(x, std::numeric_limits<float>::quiet_NaN());
1196
1197 // return
1198 //
1199 // nan if x < -1
1200 // log(x) + log(2) if x >= sqrt_max_value
1201 // log(x + sqrt((x+1)*(x-1))) otherwise
1202 //
1203 // TODO(jlebar): For now, we ignore the question of overflow if x is a
1204 // complex type, because we don't yet have exhaustive tests for complex trig
1205 // functions.
1206 auto naive_result = Log(x + Sqrt((x + one) * (x - one)));
1207 if (primitive_util::IsComplexType(shape.element_type())) {
1208 return naive_result;
1209 }
1210 auto overflow_result = Log(x) + Log(ScalarLike(x, 2));
1211
1212 auto sqrt_max_value = Sqrt(MaxFiniteValue(b, shape.element_type()));
1213 return Select(Lt(x, neg_one), nan,
1214 Select(Ge(x, sqrt_max_value), overflow_result, naive_result));
1215 });
1216 }
1217
// asinh(x) = log(x + sqrt(x^2 + 1))
//
// If x^2 will overflow and x is positive, we can approximate x + sqrt(x^2 + 1)
// as 2*x and return log(2) + log(x).
//
// If x is negative, the above would give us some trouble; we can't approximate
// the result as x + abs(x) = 0! But we're saved by the fact that asinh(-x) =
// -asinh(x).
XlaOp Asinh(XlaOp x) {
  XlaBuilder* b = x.builder();
  // Builds the asinh graph for one (possibly F32-upcast) operand.
  auto do_it = [&](XlaOp x) -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
    auto one = ScalarLike(x, 1);

    // Let a = abs(x). Compute
    //
    //   y = log(a + sqrt(a*a + 1))  if a < sqrt_max_value, or
    //   y = log(a) + log(2)         otherwise
    //
    // and then return
    //
    //   y * sign(x).
    //
    // TODO(jlebar): For now, we ignore the question of overflow if x is a
    // complex type, because we don't yet have exhaustive tests for complex
    // trig functions.
    if (primitive_util::IsComplexType(shape.element_type())) {
      return Log(x + Sqrt(x * x + one));
    }
    // For small x, sqrt(x**2 + 1) will evaluate to 1 due to floating point
    // arithmetic. However, we would like to retain the low order term of this,
    // which is around 0.5 * x**2 using a binomial expansion.
    // Let z = sqrt(a**2 + 1)
    // log(a + sqrt(a**2 + 1)) =
    // log((a + sqrt(a**2 + 1)) * (1 + sqrt(a**2 + 1)) / (1 + sqrt(a**2 + 1))) =
    // log((a + a**2 + 1 + a * z + z) / (1 + z)) =
    // log(1 + a + a**2 / (1 + z)) =
    // log(1 + a + a ** 2 / (1 + sqrt(a**2 + 1)))
    // This rewrite retains the lower order term.
    auto a = Abs(x);
    // Log1p keeps precision where log(1 + tiny) would round to 0.
    auto small_result = Log1p(a + a * a / (one + Sqrt(a * a + one)));
    auto naive_result = Log(a + Sqrt(a * a + one));
    // NOTE(review): Abs(a) is redundant here since a = Abs(x) already.
    auto overflow_result = Log(Abs(a)) + Log(ScalarLike(a, 2));
    auto sqrt_max_value = Sqrt(MaxFiniteValue(b, shape.element_type()));
    // Dispatch per magnitude, then restore the sign (asinh is odd).
    return Sign(x) * Select(Ge(a, sqrt_max_value), overflow_result,
                            Select(Le(a, one), small_result, naive_result));
  };
  // These upcasts are not strictly necessary on all platforms to get within our
  // error tolerances, so we could relax this if it ever mattered.
  return DoWithUpcastToF32(x, {BF16, F16}, [&](XlaOp x) {
    return b->ReportErrorOrReturn(do_it(x));
  });
}
1271
1272 // atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1
1273 // atanh(x) = nan otherwise
Atanh(XlaOp x)1274 XlaOp Atanh(XlaOp x) {
1275 XlaBuilder* b = x.builder();
1276 auto do_it = [&](XlaOp x) -> StatusOr<XlaOp> {
1277 TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
1278 auto naive_result = (Log1p(x) - Log1p(-x)) * ScalarLike(x, 0.5);
1279
1280 // TODO(jlebar): For now, we ignore the nan edge case for complex inputs,
1281 // because we don't yet have exhaustive tests for complex trig functions.
1282 if (primitive_util::IsComplexType(shape.element_type())) {
1283 return naive_result;
1284 }
1285
1286 auto nan = FullLike(x, std::numeric_limits<float>::quiet_NaN());
1287 return Select(Gt(Abs(x), ScalarLike(x, 1)), nan, naive_result);
1288 };
1289 return DoWithUpcastToF32(x, {BF16}, [&](XlaOp x) { //
1290 return b->ReportErrorOrReturn(do_it(x));
1291 });
1292 }
1293
1294 // Cosh(x) = (e^x + e^-x) / 2
1295 // = e^(x + log(1/2)) + e^(-x + log(1/2)).
1296 //
1297 // The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not
1298 // inf.
1299 //
1300 // This incorrectly overflows to inf for two f32 input values, namely
1301 // +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
1302 // correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
1303 // we deem this acceptable.
Cosh(XlaOp x)1304 XlaOp Cosh(XlaOp x) {
1305 return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) {
1306 auto log_one_half = Log(ScalarLike(x, 0.5));
1307 return Exp(x + log_one_half) + Exp(-x + log_one_half);
1308 });
1309 }
1310
// Sinh(x) = (e^x - e^-x) / 2
//         = e^(x + log(1/2)) - e^(-x + log(1/2)).
//
// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not
// inf.
//
// This incorrectly overflows to +/-inf for two f32 input values, namely
// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
// we deem this acceptable.
XlaOp Sinh(XlaOp x) {
  XlaBuilder* b = x.builder();
  // Builds the sinh graph for one (possibly F32-upcast) operand.
  auto do_it = [&](XlaOp x) -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
    auto one_half = ScalarLike(x, 0.5);
    auto log_one_half = Log(ScalarLike(x, 0.5));
    // Overflow-safe formulation, used for |x| >= 1 and all complex inputs.
    auto large_sinh_result = Exp(x + log_one_half) - Exp(-x + log_one_half);

    if (primitive_util::IsComplexType(shape.element_type())) {
      return large_sinh_result;
    }

    // Here we use e^x = e^(x / 2) * e^(x / 2). This avoids overflow for large
    // values of x.
    // NOTE(review): the comment above describes an e^(x/2) factoring that the
    // code below does not use; the large case is handled by
    // `large_sinh_result` instead — consider removing it.

    // For smaller x, we get unwanted cancellations of e^x - e^-x, resulting in
    // 0.
    // Rewrite this to avoid that. We use expm1(x) because that preserves the
    // first order term of the taylor series of e^x.
    // (e^(x) - e^(-x)) / 2. =
    // (e^(x) - 1 + 1 - e^(-x)) / 2.
    // (expm1(x) + (e^(x) - 1) / e^x) / 2.
    // (expm1(x) + expm1(x) / (expm1(x) + 1)) / 2.
    auto expm1 = Expm1(x);
    auto one = ScalarLike(x, 1.);
    auto small_sinh_result = one_half * (expm1 + expm1 / (expm1 + one));
    return Select(Lt(Abs(x), one), small_sinh_result, large_sinh_result);
  };
  return DoWithUpcastToF32(x, {BF16, F16}, [&](XlaOp x) {
    return b->ReportErrorOrReturn(do_it(x));
  });
}
1353
MaybeConjugate(XlaOp x,bool conjugate)1354 XlaOp MaybeConjugate(XlaOp x, bool conjugate) {
1355 XlaBuilder* builder = x.builder();
1356 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1357 TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
1358 auto perform_conj =
1359 primitive_util::IsComplexType(shape.element_type()) && conjugate;
1360 return perform_conj ? Conj(x) : x;
1361 });
1362 }
1363
// Returns, elementwise, the next representable floating-point value after
// `from` in the direction of `to` (cf. std::nextafter), implemented by
// bitcasting to the same-width unsigned integer type and stepping the
// magnitude by one unit in the last place.
XlaOp NextAfter(XlaOp from, XlaOp to) {
  auto builder = from.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(from));
    int bitwidth = primitive_util::BitWidth(shape.element_type());
    auto int_type = primitive_util::UnsignedIntegralTypeForBitWidth(bitwidth);
    auto from_as_int = BitcastConvertType(from, int_type);
    auto to_as_int = BitcastConvertType(to, int_type);

    // The result is NaN if either "from" or "to" are NaN.
    auto from_is_nan = Ne(from, from);  // NaN is the only value != itself.
    auto to_is_nan = Ne(to, to);
    auto nan_input = Or(from_is_nan, to_is_nan);
    auto result_for_nan =
        Broadcast(ScalarLike(from, std::numeric_limits<double>::quiet_NaN()),
                  shape.dimensions());
    result_for_nan = BitcastConvertType(result_for_nan, int_type);

    // The sign bit is the MSB.
    const int64_t sign_mask = int64{1} << (bitwidth - 1);
    // Discard the sign bit to make the result non-negative.
    auto from_abs = And(from_as_int, ScalarLike(from_as_int, ~sign_mask));
    auto to_abs = And(to_as_int, ScalarLike(to_as_int, ~sign_mask));

    // When both "from" and "to" are equal, the result is "to".
    // N.B. It would not make a difference if we chose the result to be "from".
    auto from_and_to_are_equal = Eq(from_as_int, to_as_int);
    auto result_for_equal = to_as_int;

    // When both "from" and "to" are both 0, the result is "to". This ensures we
    // get a zero signed like "to".
    auto from_is_zero = Eq(from_abs, ZerosLike(from_abs));
    auto to_is_zero = Eq(to_abs, ZerosLike(to_abs));
    auto result_for_both_zero = to_as_int;

    auto from_sign = And(from_as_int, ScalarLike(from_as_int, sign_mask));
    auto to_sign = And(to_as_int, ScalarLike(to_as_int, sign_mask));

    // If from == 0 && to != 0, we need to return the smallest subnormal number
    // signed like "to".
    auto result_for_from_zero_to_non_zero =
        Or(to_sign, ScalarLike(from_as_int, 1));

    // If the sign of "from" and "to" disagree:
    // - we need to make the magnitude of "from" smaller so that it is closer to
    //   zero.
    //
    // Otherwise the signs agree:
    // - "from" with a magnitude larger than "to" means we need to make the
    //   magnitude smaller.
    // - "from" with a magnitude smaller than "to" means we need to make the
    //   magnitude larger.
    // - "from" with the same magnitude and sign as "to" has already been
    //   handled.
    auto signs_disagree = Ne(from_sign, to_sign);
    auto from_magnitude_larger_than_to = Gt(from_abs, to_abs);
    auto result_has_smaller_magnitude =
        Or(from_magnitude_larger_than_to, signs_disagree);
    // Adding/subtracting 1 from the integer representation moves one ulp away
    // from/toward zero for that sign.
    auto magnitude_adjustment =
        Select(result_has_smaller_magnitude,
               Broadcast(ScalarLike(from_as_int, -1), shape.dimensions()),
               Broadcast(ScalarLike(from_as_int, 1), shape.dimensions()));
    auto result = Add(from_as_int, magnitude_adjustment);
    // Handle from == ±0.
    result = Select(from_is_zero,
                    Select(to_is_zero, result_for_both_zero,
                           result_for_from_zero_to_non_zero),
                    result);
    // Handle from == to.
    result = Select(from_and_to_are_equal, result_for_equal, result);
    // Handle isnan(from) || isnan(to).
    result = Select(nan_input, result_for_nan, result);

    // Cast back to the original type.
    return BitcastConvertType(result, shape.element_type());
  });
}
1441
1442 // Computes an approximation to the modified Bessel function of the first kind,
1443 // zeroth order.
1444 // The following implementation follows Cephes' F32 and F64 implementation of
1445 // i0e.
I0eImpl32(XlaOp x)1446 static XlaOp I0eImpl32(XlaOp x) {
1447 static const std::array<float, 18> kI0eCoeffsA{
1448 -1.30002500998624804212E-8f, 6.04699502254191894932E-8f,
1449 -2.67079385394061173391E-7f, 1.11738753912010371815E-6f,
1450 -4.41673835845875056359E-6f, 1.64484480707288970893E-5f,
1451 -5.75419501008210370398E-5f, 1.88502885095841655729E-4f,
1452 -5.76375574538582365885E-4f, 1.63947561694133579842E-3f,
1453 -4.32430999505057594430E-3f, 1.05464603945949983183E-2f,
1454 -2.37374148058994688156E-2f, 4.93052842396707084878E-2f,
1455 -9.49010970480476444210E-2f, 1.71620901522208775349E-1f,
1456 -3.04682672343198398683E-1f, 6.76795274409476084995E-1f};
1457
1458 static const std::array<float, 7> kI0eCoeffsB{
1459 3.39623202570838634515E-9f, 2.26666899049817806459E-8f,
1460 2.04891858946906374183E-7f, 2.89137052083475648297E-6f,
1461 6.88975834691682398426E-5f, 3.36911647825569408990E-3f,
1462 8.04490411014108831608E-1f};
1463
1464 x = Abs(x);
1465 auto half = xla::ScalarLike(x, 0.5);
1466 auto two = xla::ScalarLike(x, 2.0);
1467 auto thirty_two = xla::ScalarLike(x, 32.0);
1468 auto result_le_8 =
1469 EvaluateChebyshevPolynomial<float>(half * x - two, kI0eCoeffsA);
1470 auto result_gt_8 =
1471 EvaluateChebyshevPolynomial<float>(thirty_two / x - two, kI0eCoeffsB) /
1472 Sqrt(x);
1473 return Select(Le(x, xla::ScalarLike(x, 8.0)), result_le_8, result_gt_8);
1474 }
1475
I0eImpl64(XlaOp x)1476 static XlaOp I0eImpl64(XlaOp x) {
1477 static const std::array<double, 30> kI0eCoeffsA{
1478 -4.41534164647933937950E-18, 3.33079451882223809783E-17,
1479 -2.43127984654795469359E-16, 1.71539128555513303061E-15,
1480 -1.16853328779934516808E-14, 7.67618549860493561688E-14,
1481 -4.85644678311192946090E-13, 2.95505266312963983461E-12,
1482 -1.72682629144155570723E-11, 9.67580903537323691224E-11,
1483 -5.18979560163526290666E-10, 2.65982372468238665035E-9,
1484 -1.30002500998624804212E-8, 6.04699502254191894932E-8,
1485 -2.67079385394061173391E-7, 1.11738753912010371815E-6,
1486 -4.41673835845875056359E-6, 1.64484480707288970893E-5,
1487 -5.75419501008210370398E-5, 1.88502885095841655729E-4,
1488 -5.76375574538582365885E-4, 1.63947561694133579842E-3,
1489 -4.32430999505057594430E-3, 1.05464603945949983183E-2,
1490 -2.37374148058994688156E-2, 4.93052842396707084878E-2,
1491 -9.49010970480476444210E-2, 1.71620901522208775349E-1,
1492 -3.04682672343198398683E-1, 6.76795274409476084995E-1};
1493
1494 static const std::array<double, 25> kI0eCoeffsB{
1495 -7.23318048787475395456E-18, -4.83050448594418207126E-18,
1496 4.46562142029675999901E-17, 3.46122286769746109310E-17,
1497 -2.82762398051658348494E-16, -3.42548561967721913462E-16,
1498 1.77256013305652638360E-15, 3.81168066935262242075E-15,
1499 -9.55484669882830764870E-15, -4.15056934728722208663E-14,
1500 1.54008621752140982691E-14, 3.85277838274214270114E-13,
1501 7.18012445138366623367E-13, -1.79417853150680611778E-12,
1502 -1.32158118404477131188E-11, -3.14991652796324136454E-11,
1503 1.18891471078464383424E-11, 4.94060238822496958910E-10,
1504 3.39623202570838634515E-9, 2.26666899049817806459E-8,
1505 2.04891858946906374183E-7, 2.89137052083475648297E-6,
1506 6.88975834691682398426E-5, 3.36911647825569408990E-3,
1507 8.04490411014108831608E-1};
1508
1509 x = Abs(x);
1510 auto half = xla::ScalarLike(x, 0.5);
1511 auto two = xla::ScalarLike(x, 2.0);
1512 auto thirty_two = xla::ScalarLike(x, 32.0);
1513 auto result_le_8 =
1514 EvaluateChebyshevPolynomial<double>(half * x - two, kI0eCoeffsA);
1515 auto result_gt_8 =
1516 EvaluateChebyshevPolynomial<double>(thirty_two / x - two, kI0eCoeffsB) /
1517 Sqrt(x);
1518 return Select(Le(x, xla::ScalarLike(x, 8.0)), result_le_8, result_gt_8);
1519 }
1520
BesselI0e(XlaOp x)1521 XlaOp BesselI0e(XlaOp x) {
1522 auto& b = *x.builder();
1523 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1524 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("BesselI0e", x));
1525 TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x));
1526 if (shape.element_type() == F64) {
1527 return I0eImpl64(x);
1528 }
1529 // I0eF32Impl don't have enough precision when run with bf16 intermediates
1530 // (not surprising!), so upcast to f32 in this case.
1531 return DoWithUpcastToF32(x, {BF16, F16},
1532 [](XlaOp x) { return I0eImpl32(x); });
1533 });
1534 }
1535
1536 // Computes an approximation to the modified Bessel function of the first kind,
1537 // first order.
1538 // The following implementation follows Cephes' F32 and F64 implementation of
1539 // i1e.
1540
I1eImpl32(XlaOp x)1541 static XlaOp I1eImpl32(XlaOp x) {
1542 static const std::array<float, 17> kI1eCoeffsA{
1543 9.38153738649577178388E-9f, -4.44505912879632808065E-8f,
1544 2.00329475355213526229E-7f, -8.56872026469545474066E-7f,
1545 3.47025130813767847674E-6f, -1.32731636560394358279E-5f,
1546 4.78156510755005422638E-5f, -1.61760815825896745588E-4f,
1547 5.12285956168575772895E-4f, -1.51357245063125314899E-3f,
1548 4.15642294431288815669E-3f, -1.05640848946261981558E-2f,
1549 2.47264490306265168283E-2f, -5.29459812080949914269E-2f,
1550 1.02643658689847095384E-1f, -1.76416518357834055153E-1f,
1551 2.52587186443633654823E-1f};
1552
1553 static const std::array<float, 7> kI1eCoeffsB{
1554 -3.83538038596423702205E-9f, -2.63146884688951950684E-8f,
1555 -2.51223623787020892529E-7f, -3.88256480887769039346E-6f,
1556 -1.10588938762623716291E-4f, -9.76109749136146840777E-3f,
1557 7.78576235018280120474E-1f};
1558 XlaOp z = Abs(x);
1559 auto half = xla::ScalarLike(x, 0.5);
1560 auto two = xla::ScalarLike(x, 2.0);
1561 auto thirty_two = xla::ScalarLike(x, 32.0);
1562 auto result_le_8 =
1563 z * EvaluateChebyshevPolynomial<float>(half * z - two, kI1eCoeffsA);
1564 auto result_gt_8 =
1565 EvaluateChebyshevPolynomial<float>(thirty_two / z - two, kI1eCoeffsB) /
1566 Sqrt(z);
1567 return Sign(x) *
1568 Select(Le(z, xla::ScalarLike(x, 8.0)), result_le_8, result_gt_8);
1569 }
1570
I1eImpl64(XlaOp x)1571 static XlaOp I1eImpl64(XlaOp x) {
1572 static const std::array<double, 29> kI1eCoeffsA{
1573 2.77791411276104639959E-18, -2.11142121435816608115E-17,
1574 1.55363195773620046921E-16, -1.10559694773538630805E-15,
1575 7.60068429473540693410E-15, -5.04218550472791168711E-14,
1576 3.22379336594557470981E-13, -1.98397439776494371520E-12,
1577 1.17361862988909016308E-11, -6.66348972350202774223E-11,
1578 3.62559028155211703701E-10, -1.88724975172282928790E-9,
1579 9.38153738649577178388E-9, -4.44505912879632808065E-8,
1580 2.00329475355213526229E-7, -8.56872026469545474066E-7,
1581 3.47025130813767847674E-6, -1.32731636560394358279E-5,
1582 4.78156510755005422638E-5, -1.61760815825896745588E-4,
1583 5.12285956168575772895E-4, -1.51357245063125314899E-3,
1584 4.15642294431288815669E-3, -1.05640848946261981558E-2,
1585 2.47264490306265168283E-2, -5.29459812080949914269E-2,
1586 1.02643658689847095384E-1, -1.76416518357834055153E-1,
1587 2.52587186443633654823E-1};
1588
1589 static const std::array<double, 25> kI1eCoeffsB{
1590 7.51729631084210481353E-18, 4.41434832307170791151E-18,
1591 -4.65030536848935832153E-17, -3.20952592199342395980E-17,
1592 2.96262899764595013876E-16, 3.30820231092092828324E-16,
1593 -1.88035477551078244854E-15, -3.81440307243700780478E-15,
1594 1.04202769841288027642E-14, 4.27244001671195135429E-14,
1595 -2.10154184277266431302E-14, -4.08355111109219731823E-13,
1596 -7.19855177624590851209E-13, 2.03562854414708950722E-12,
1597 1.41258074366137813316E-11, 3.25260358301548823856E-11,
1598 -1.89749581235054123450E-11, -5.58974346219658380687E-10,
1599 -3.83538038596423702205E-9, -2.63146884688951950684E-8,
1600 -2.51223623787020892529E-7, -3.88256480887769039346E-6,
1601 -1.10588938762623716291E-4, -9.76109749136146840777E-3,
1602 7.78576235018280120474E-1};
1603
1604 XlaOp z = Abs(x);
1605 auto half = xla::ScalarLike(x, 0.5);
1606 auto two = xla::ScalarLike(x, 2.0);
1607 auto thirty_two = xla::ScalarLike(x, 32.0);
1608 auto result_le_8 =
1609 z * EvaluateChebyshevPolynomial<double>(half * z - two, kI1eCoeffsA);
1610 auto result_gt_8 =
1611 EvaluateChebyshevPolynomial<double>(thirty_two / z - two, kI1eCoeffsB) /
1612 Sqrt(z);
1613 return Sign(x) *
1614 Select(Le(z, xla::ScalarLike(x, 8.0)), result_le_8, result_gt_8);
1615 }
1616
BesselI1e(XlaOp x)1617 XlaOp BesselI1e(XlaOp x) {
1618 auto& b = *x.builder();
1619 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1620 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("BesselI1e", x));
1621 TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x));
1622 if (shape.element_type() == F64) {
1623 return I1eImpl64(x);
1624 }
1625 // I1eF32Impl don't have enough precision when run with bf16 intermediates
1626 // (not surprising!), so upcast to f32 in this case.
1627 return DoWithUpcastToF32(x, {BF16, F16},
1628 [](XlaOp x) { return I1eImpl32(x); });
1629 });
1630 }
1631
1632 // I J Thompson and A R Barnett. 1986. Coulomb and Bessel functions of complex
1633 // arguments and order. J. Comput. Phys. 64, 2 (June 1986), 490-509.
1634 // DOI=http://dx.doi.org/10.1016/0021-9991(86)90046-X
// Evaluates a continued fraction b0 + a1/(b1 + a2/(b2 + ...)) with the
// modified Lentz method: instead of the classical three-term recurrences it
// tracks the ratios C_n = h_n-numerator ratio and D_n = denominator ratio and
// updates the approximant h via h_n = h_{n-1} * C_n * D_n, clamping
// near-zero intermediates to `small` to avoid division by zero.
//
// Args:
//   num_iterations: hard cap on loop iterations (must fit in int32).
//   small: substitute magnitude for near-zero C/D denominators.
//   threshold: convergence tolerance on |C_n * D_n - 1|.
//   nth_partial_numerator / nth_partial_denominator: callbacks producing the
//     elementwise a_n and b_n for iteration n, given the loop-carried inputs.
//   inputs: extra operands threaded through the while loop to the callbacks.
//   name: name attached to the generated XLA while loop.
static XlaOp LentzThompsonBarnettAlgorithm(
    int64_t num_iterations, double small, double threshold,
    const ForEachIndexBodyFunction& nth_partial_numerator,
    const ForEachIndexBodyFunction& nth_partial_denominator,
    absl::Span<const XlaOp> inputs, absl::string_view name) {
  auto& b = *inputs.front().builder();
  return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RET_CHECK(num_iterations < INT32_MAX);

    // Layout of the while loop's carried state vector.
    enum {
      // Position in the evaluation.
      kIterationIdx,
      // Whether or not we have reached the desired tolerance.
      kValuesUnconvergedIdx,
      // Ratio between nth canonical numerator and the nth-1 canonical
      // numerator.
      kCIdx,
      // Ratio between nth-1 canonical denominator and the nth canonical
      // denominator.
      kDIdx,
      // Computed approximant in the evaluation.
      kHIdx,
      // Inputs follow all of the other state.
      kFirstInputIdx,
    };
    // Loop while we still have iterations left AND some element has not yet
    // converged.
    auto while_cond_fn = [num_iterations](
                             absl::Span<const XlaOp> values,
                             XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
      auto iteration = values[kIterationIdx];
      auto iterations_remain_cond =
          Lt(iteration, ScalarLike(iteration, num_iterations));
      auto values_unconverged_cond = values[kValuesUnconvergedIdx];
      return And(iterations_remain_cond, values_unconverged_cond);
    };

    auto while_body_fn =
        [small, threshold, &nth_partial_numerator, &nth_partial_denominator](
            absl::Span<const XlaOp> values,
            XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
      XlaOp iteration = values[kIterationIdx];

      // Each callback returns exactly one elementwise term for this
      // iteration.
      TF_ASSIGN_OR_RETURN(
          std::vector<XlaOp> partial_numerator,
          nth_partial_numerator(iteration, values.subspan(kFirstInputIdx),
                                body_builder));
      TF_RET_CHECK(partial_numerator.size() == 1);

      TF_ASSIGN_OR_RETURN(
          std::vector<XlaOp> partial_denominator,
          nth_partial_denominator(iteration, values.subspan(kFirstInputIdx),
                                  body_builder));
      TF_RET_CHECK(partial_denominator.size() == 1);

      // C_n = b_n + a_n / C_{n-1}, clamped away from zero.
      auto c = partial_denominator[0] + partial_numerator[0] / values[kCIdx];
      auto small_constant = FullLike(c, small);
      c = Select(Lt(Abs(c), small_constant), small_constant, c);

      // D_n = 1 / (b_n + a_n * D_{n-1}), with the pre-reciprocal value also
      // clamped away from zero.
      auto d = partial_denominator[0] + partial_numerator[0] * values[kDIdx];
      d = Select(Lt(Abs(d), small_constant), small_constant, d);

      d = Reciprocal(d);

      // delta = C_n * D_n; h_n = h_{n-1} * delta.
      auto delta = c * d;
      auto h = values[kHIdx] * delta;

      std::vector<XlaOp> updated_values(values.size());
      updated_values[kIterationIdx] = Add(iteration, ScalarLike(iteration, 1));
      updated_values[kCIdx] = c;
      updated_values[kDIdx] = d;
      updated_values[kHIdx] = h;
      std::copy(values.begin() + kFirstInputIdx, values.end(),
                updated_values.begin() + kFirstInputIdx);

      // If any values are greater than the tolerance, we have not converged.
      auto tolerance_comparison =
          Ge(Abs(Sub(delta, FullLike(delta, 1.0))), FullLike(delta, threshold));
      updated_values[kValuesUnconvergedIdx] =
          ReduceAll(tolerance_comparison, ConstantR0<bool>(body_builder, false),
                    CreateScalarOrComputation(PRED, body_builder));
      return updated_values;
    };

    // Initialize h_0 = b_0 (clamped away from zero); the loop starts at n=1.
    TF_ASSIGN_OR_RETURN(std::vector<XlaOp> partial_denominator,
                        nth_partial_denominator(Zero(&b, U32), inputs, &b));
    TF_RET_CHECK(partial_denominator.size() == 1);
    auto h = partial_denominator[0];
    auto small_constant = FullLike(h, small);
    h = Select(Lt(Abs(h), small_constant), small_constant, h);

    std::vector<XlaOp> values(kFirstInputIdx + inputs.size());
    values[kIterationIdx] = One(&b, U32);
    values[kValuesUnconvergedIdx] = ConstantR0<bool>(&b, true);
    values[kCIdx] = h;
    values[kDIdx] = FullLike(h, 0.0);
    values[kHIdx] = h;
    std::copy(inputs.begin(), inputs.end(), values.begin() + kFirstInputIdx);
    TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn,
                                                values, name, &b));
    return values[kHIdx];
  });
}
1736
// Computes the regularized incomplete beta function I_x(a, b) elementwise via
// the continued fraction of DLMF 8.17.22/8.17.23, evaluated with the modified
// Lentz algorithm above. Inputs a, b, x must share one shape and a real
// floating-point element type; f16/bf16 are computed in f32 and downcast.
// Returns NaN when a <= 0, b <= 0, or x is outside [0, 1].
XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x) {
  auto& builder = *x.builder();
  return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder.GetShape(a));
    TF_ASSIGN_OR_RETURN(Shape b_shape, builder.GetShape(b));
    TF_ASSIGN_OR_RETURN(Shape x_shape, builder.GetShape(x));
    if (b_shape.element_type() != shape.element_type() ||
        x_shape.element_type() != shape.element_type()) {
      return InvalidArgument(
          "Operands to RegularizedIncompleteBeta must have identical types, "
          "got shapes %s, %s, and %s",
          shape.ToString(), b_shape.ToString(), x_shape.ToString());
    }
    if (!primitive_util::IsFloatingPointType(shape.element_type())) {
      return InvalidArgument(
          "Operands to RegularizedIncompleteBeta must be real-valued "
          "floating-point, but got %s",
          PrimitiveType_Name(shape.element_type()));
    }
    PrimitiveType element_type = shape.element_type();
    // Half-precision intermediates are too lossy; compute in f32 and convert
    // back at the end.
    if (element_type == F16 || element_type == BF16) {
      element_type = F32;
      a = ConvertElementType(a, F32);
      b = ConvertElementType(b, F32);
      x = ConvertElementType(x, F32);
    }

    // The partial numerator for the incomplete beta function is given
    // here: http://dlmf.nist.gov/8.17.E23 Note that there is a special
    // case: the partial numerator for the first iteration is one.
    auto NthPartialBetaincNumerator =
        [&](XlaOp iteration, absl::Span<const XlaOp> inputs,
            XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
      // inputs are the loop-carried {a, b, x} in that order.
      auto a = inputs[0];
      auto b = inputs[1];
      auto x = inputs[2];
      auto iteration_bcast = Broadcast(iteration, shape.dimensions());
      auto iteration_is_even =
          Eq(iteration_bcast % FullLike(iteration_bcast, 2),
             FullLike(iteration_bcast, 0));
      auto iteration_is_one = Eq(iteration_bcast, FullLike(iteration_bcast, 1));
      auto iteration_minus_one = iteration_bcast - FullLike(iteration_bcast, 1);
      // m = floor((iteration - 1) / 2), the index in DLMF 8.17.23's d_{2m}
      // and d_{2m+1} terms.
      auto m = iteration_minus_one / FullLike(iteration_minus_one, 2);
      m = ConvertElementType(m, element_type);
      auto one = FullLike(a, 1.0);
      auto two = FullLike(a, 2.0);
      // Partial numerator terms.
      auto even_numerator =
          -(a + m) * (a + b + m) * x / ((a + two * m) * (a + two * m + one));
      auto odd_numerator =
          m * (b - m) * x / ((a + two * m - one) * (a + two * m));
      auto one_numerator = ScalarLike(x, 1.0);
      auto numerator = Select(iteration_is_even, even_numerator, odd_numerator);
      return std::vector<XlaOp>{
          Select(iteration_is_one, one_numerator, numerator)};
    };

    // Every partial denominator is 1, except b_0 which is 0.
    auto NthPartialBetaincDenominator =
        [&shape](XlaOp iteration, absl::Span<const XlaOp> inputs,
                 XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
      auto x = inputs[2];
      auto iteration_bcast = Broadcast(iteration, shape.dimensions());
      return std::vector<XlaOp>{
          Select(Eq(iteration_bcast, ScalarLike(iteration_bcast, 0)),
                 ScalarLike(x, 0.0), ScalarLike(x, 1.0))};
    };

    // Determine if the inputs are out of range.
    auto result_is_nan =
        Or(Or(Or(Le(a, ScalarLike(a, 0.0)), Le(b, ScalarLike(b, 0.0))),
              Lt(x, ScalarLike(x, 0.0))),
           Gt(x, ScalarLike(x, 1.0)));

    // The continued fraction will converge rapidly when x < (a+1)/(a+b+2)
    // as per: http://dlmf.nist.gov/8.17.E23
    //
    // Otherwise, we can rewrite using the symmetry relation as per:
    // http://dlmf.nist.gov/8.17.E4
    auto converges_rapidly =
        Lt(x, (a + FullLike(a, 1.0)) / (a + b + FullLike(b, 2.0)));
    auto a_orig = a;
    // Elementwise swap (a, b) -> (b, a) and x -> 1 - x where the original
    // fraction would converge slowly; undone via 1 - result below.
    a = Select(converges_rapidly, a, b);
    b = Select(converges_rapidly, b, a_orig);
    x = Select(converges_rapidly, x, Sub(FullLike(x, 1.0), x));

    XlaOp continued_fraction;

    // Thresholds and iteration counts taken from Cephes.
    if (element_type == F32) {
      continued_fraction = LentzThompsonBarnettAlgorithm(
          /*num_iterations=*/200,
          /*small=*/std::numeric_limits<float>::epsilon() / 2.0f,
          /*threshold=*/std::numeric_limits<float>::epsilon() / 2.0f,
          /*nth_partial_numerator=*/NthPartialBetaincNumerator,
          /*nth_partial_denominator=*/NthPartialBetaincDenominator, {a, b, x},
          "Betainc");
    } else {
      TF_RET_CHECK(element_type == F64);
      continued_fraction = LentzThompsonBarnettAlgorithm(
          /*num_iterations=*/600,
          /*small=*/std::numeric_limits<double>::epsilon() / 2.0f,
          /*threshold=*/std::numeric_limits<double>::epsilon() / 2.0f,
          /*nth_partial_numerator=*/NthPartialBetaincNumerator,
          /*nth_partial_denominator=*/NthPartialBetaincDenominator, {a, b, x},
          "Betainc");
    }

    // We want to compute the regularized complete beta function so we need to
    // combine the continued fraction with a few more terms as well as dividing
    // it by Beta(a, b). To avoid overflow, we compute in the log domain.
    // See http://dlmf.nist.gov/8.17.E22 for an easier to read version of this
    // formula.
    auto lbeta = Lbeta(a, b);
    auto result =
        continued_fraction * Exp(Log(x) * a + Log1p(-x) * b - lbeta) / a;
    result = Select(result_is_nan, NanValue(&builder, element_type), result);

    // We have an additional fixup to do if we are taking advantage of the
    // symmetry relation.
    auto out =
        Select(converges_rapidly, result, Sub(FullLike(result, 1.0), result));
    // Downcast back to the caller's element type if we upcast above.
    return shape.element_type() == element_type
               ? out
               : ConvertElementType(out, shape.element_type());
  });
}
1863
Polygamma(XlaOp n,XlaOp x)1864 XlaOp Polygamma(XlaOp n, XlaOp x) {
1865 auto& builder = *x.builder();
1866 auto doit = [](XlaOp n, XlaOp x, PrimitiveType type) -> XlaOp {
1867 XlaOp n_plus_one = n + ScalarLike(n, 1.);
1868 XlaOp sign =
1869 (ScalarLike(n, 2.) * Rem(n, ScalarLike(n, 2.)) - ScalarLike(n, 1.));
1870
1871 const double nan = std::numeric_limits<double>::quiet_NaN();
1872
1873 XlaOp output = Select(Eq(n, ScalarLike(n, 0.)), Digamma(x),
1874 sign * Exp(Lgamma(n_plus_one)) * Zeta(n_plus_one, x));
1875 // Check that n is a natural number.
1876 output = Select(Or(Ne(n, Floor(n)), Lt(n, ScalarLike(n, 0.))),
1877 ScalarLike(n, nan), output);
1878 return output;
1879 };
1880 return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1881 TF_ASSIGN_OR_RETURN(auto n_shape, builder.GetShape(n));
1882 TF_ASSIGN_OR_RETURN(auto x_shape, builder.GetShape(x));
1883 if (n_shape != x_shape) {
1884 return InvalidArgument(
1885 "Arguments to Polygamma must have equal shapes and types; "
1886 "got %s and %s",
1887 n_shape.ToString(), x_shape.ToString());
1888 }
1889 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Zeta", x));
1890 bool needs_upcast =
1891 n_shape.element_type() == F16 || x_shape.element_type() == BF16;
1892
1893 if (needs_upcast) {
1894 n = ConvertElementType(n, F32);
1895 x = ConvertElementType(x, F32);
1896 }
1897 XlaOp result = doit(n, x, n_shape.element_type());
1898 if (needs_upcast) {
1899 result = ConvertElementType(result, n_shape.element_type());
1900 }
1901 return result;
1902 });
1903 }
1904
// Computes the Hurwitz zeta function zeta(x, q) = sum_{k>=0} (k + q)^-x,
// elementwise, using a direct partial sum plus an Euler-Maclaurin correction.
// x and q must have identical shapes and real floating-point types; f16/bf16
// are computed in f32 and downcast. Special cases (poles, domain errors)
// return inf or NaN as appropriate.
XlaOp Zeta(XlaOp x, XlaOp q) {
  auto& builder = *x.builder();
  auto doit = [&builder](XlaOp x, XlaOp q, PrimitiveType type) -> XlaOp {
    // (2k) ! / B_{2k}, where B_{2k} are the Bernoulli numbers.
    // These are ordered in reverse.
    static const std::array<double, 12> kZetaCoeffs{
        -7.1661652561756670113e18,
        1.8152105401943546773e17,
        -4.5979787224074726105e15,
        1.1646782814350067249e14,
        -2.950130727918164224e12,
        7.47242496e10,
        -1.8924375803183791606e9,
        47900160.0,
        -1209600.0,
        30240.0,
        -720.0,
        12.0,
    };

    // For speed we'll always use 9 iterations for the initial series estimate,
    // and a 12 term expansion for the Euler-Maclaurin formula.

    // Partial sum of (q + i)^-x for i = 0..9; `a` tracks q + i and
    // `neg_power` the most recent term.
    XlaOp a = q;
    XlaOp neg_power = ScalarLike(a, 0.);
    XlaOp initial_sum = Pow(q, Neg(x));
    for (int i = 0; i < 9; ++i) {
      a = a + ScalarLike(a, 1.);
      neg_power = Pow(a, Neg(x));
      initial_sum = initial_sum + neg_power;
    }
    a = a + ScalarLike(a, 1.);
    neg_power = Pow(a, Neg(x));
    // Leading Euler-Maclaurin term: integral tail a^(1-x) / (x - 1).
    XlaOp s = initial_sum + neg_power * a / (x - ScalarLike(a, 1.));
    XlaOp a_inverse_square = Reciprocal(Square(a));
    XlaOp horner_sum = ScalarLike(a, 0.);
    XlaOp factor = ScalarLike(a, 1.);
    // Use Horner's rule for this.
    // Note this differs from Cephes which does a 'naive' polynomial evaluation.
    // Using Horner's rule allows to avoid some NaN's and Infs from happening,
    // resulting in more numerically stable code.
    for (int i = 0; i < 11; ++i) {
      // Rising factor (x + 2i')(x + 2i'+1)-style terms folded into the
      // Horner recurrence, paired with 1/kZetaCoeffs[i].
      factor =
          (x - ScalarLike(x, 22 - 2 * i)) * (x - ScalarLike(x, 21 - 2 * i));
      horner_sum = factor * a_inverse_square *
                   (horner_sum + ScalarLike(a, 1. / kZetaCoeffs[i]));
    }
    // Combine the 1/2 correction and the Bernoulli series into s.
    s = s + neg_power *
                (ScalarLike(neg_power, 0.5) +
                 x / a * (ScalarLike(a, 1. / kZetaCoeffs[11]) + horner_sum));

    const double nan = std::numeric_limits<double>::quiet_NaN();
    const double inf = std::numeric_limits<double>::infinity();
    // Use the initial zeta sum without the correction term coming
    // from Euler-Maclaurin if it is accurate enough.
    XlaOp output =
        Select(Lt(Abs(neg_power), Abs(initial_sum) * Epsilon(&builder, type)),
               initial_sum, s);

    // This is the harmonic series.
    output = Select(Eq(x, ScalarLike(x, 1.)), ScalarLike(x, inf), output);

    // Function is not defined for x < 1.
    output = Select(Lt(x, ScalarLike(x, 1.)), ScalarLike(x, nan), output);

    // For q <= 0, x must be an integer.
    XlaOp x_domain_error = And(Le(q, ScalarLike(x, 0.)), Ne(x, Floor(x)));
    output = Select(x_domain_error, ScalarLike(x, nan), output);

    // For all integer q <= 0, zeta has a pole. The limit is only defined as
    // +inf if x is an even integer.
    XlaOp at_pole = And(Le(q, ScalarLike(x, 0.)), Eq(q, Floor(q)));
    XlaOp x_is_even_int =
        And(Eq(Rem(x, ScalarLike(x, 2.)), ScalarLike(x, 0.)), Eq(x, Floor(x)));
    output = Select(
        at_pole, Select(x_is_even_int, ScalarLike(x, inf), ScalarLike(x, nan)),
        output);

    return output;
  };
  return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto x_shape, builder.GetShape(x));
    TF_ASSIGN_OR_RETURN(auto q_shape, builder.GetShape(q));
    if (x_shape != q_shape) {
      return InvalidArgument(
          "Arguments to Zeta must have equal shapes and types; got %s and %s",
          x_shape.ToString(), q_shape.ToString());
    }
    TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Zeta", x));
    // Half-precision intermediates are too lossy; compute in f32.
    bool needs_upcast =
        x_shape.element_type() == F16 || x_shape.element_type() == BF16;

    if (needs_upcast) {
      x = ConvertElementType(x, F32);
      q = ConvertElementType(q, F32);
    }
    XlaOp result = doit(x, q, x_shape.element_type());
    if (needs_upcast) {
      result = ConvertElementType(result, x_shape.element_type());
    }
    return result;
  });
}
2008
2009 } // namespace xla
2010