1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/math.h"
17
18 #include <cmath>
19
20 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
21 #include "tensorflow/compiler/xla/client/lib/constants.h"
22 #include "tensorflow/compiler/xla/client/lib/loops.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/primitive_util.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27
28 namespace xla {
29 namespace {
30
31 // Evaluate the polynomial given `x` and coefficients in decreasing order.
32 template <typename FP>
EvaluatePolynomial(XlaOp x,absl::Span<const FP> coefficients)33 XlaOp EvaluatePolynomial(XlaOp x, absl::Span<const FP> coefficients) {
34 static_assert(std::is_floating_point<FP>::value,
35 "Template-argument 'FP' must be a floating-point type");
36 XlaOp poly = ScalarLike(x, 0.0);
37 for (FP c : coefficients) {
38 poly = poly * x + ScalarLike(x, c);
39 }
40 return poly;
41 }
42
43 // Evaluate the chebyshev polynomial given `x` and coefficients in decreasing
44 // order.
45 template <typename FP>
EvaluateChebyshevPolynomial(XlaOp x,absl::Span<const FP> coefficients)46 XlaOp EvaluateChebyshevPolynomial(XlaOp x, absl::Span<const FP> coefficients) {
47 static_assert(std::is_floating_point<FP>::value,
48 "Template-argument 'FP' must be a floating-point type");
49 XlaOp b0 = ScalarLike(x, 0.0);
50 XlaOp b1 = ScalarLike(x, 0.0);
51 XlaOp b2 = ScalarLike(x, 0.0);
52 for (FP c : coefficients) {
53 b2 = b1;
54 b1 = b0;
55 b0 = x * b1 - b2 + ScalarLike(x, c);
56 }
57 return ScalarLike(x, 0.5) * (b0 - b2);
58 }
59
60 } // namespace
61
62 // Returns operation(operand), except if `operand` is one of the types in
63 // upcast_types, in which case first converts it to F32, and then converts the
64 // result down to the original type.
DoWithUpcastToF32(XlaOp operand,absl::Span<const PrimitiveType> upcast_types,const std::function<XlaOp (XlaOp)> & operation)65 static XlaOp DoWithUpcastToF32(XlaOp operand,
66 absl::Span<const PrimitiveType> upcast_types,
67 const std::function<XlaOp(XlaOp)>& operation) {
68 auto& b = *operand.builder();
69 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
70 TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
71 PrimitiveType elem_ty = shape.element_type();
72 bool needs_upcast = absl::c_linear_search(upcast_types, elem_ty);
73
74 if (needs_upcast) {
75 operand = ConvertElementType(operand, F32);
76 }
77 XlaOp result = operation(operand);
78 if (needs_upcast) {
79 result = ConvertElementType(result, elem_ty);
80 }
81 return result;
82 });
83 }
84
85 // TODO(jlebar): Use this function in more places in this file to restrict the
86 // domain of other functions.
EnsureOperandIsRealFp(absl::string_view op_name,XlaOp operand)87 static Status EnsureOperandIsRealFp(absl::string_view op_name, XlaOp operand) {
88 auto& b = *operand.builder();
89 TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
90 auto elem_ty = shape.element_type();
91 if (!primitive_util::IsFloatingPointType(elem_ty)) {
92 return InvalidArgument(
93 "Operands to %s must be real-valued floating-point, but got %s",
94 op_name, PrimitiveType_Name(elem_ty));
95 }
96 return Status::OK();
97 }
98
IsPosInf(XlaOp operand)99 XlaOp IsPosInf(XlaOp operand) {
100 auto& b = *operand.builder();
101 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
102 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsPosInf", operand));
103 TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
104 // Note that this is only correct for floating-point types. If we wanted it
105 // to be correct for all types, we'd need to Gt(MaxFiniteValue).
106 return Eq(operand, MaxValue(&b, shape.element_type()));
107 });
108 }
109
IsNegInf(XlaOp operand)110 XlaOp IsNegInf(XlaOp operand) {
111 auto& b = *operand.builder();
112 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
113 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNegInf", operand));
114 TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
115 // Note that this is only correct for floating-point types. If we wanted it
116 // to be correct for all types, we'd need to Lt(MinFiniteValue).
117 return Eq(operand, MinValue(&b, shape.element_type()));
118 });
119 }
120
IsInf(XlaOp operand)121 XlaOp IsInf(XlaOp operand) {
122 auto& b = *operand.builder();
123 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
124 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsInf", operand));
125 return IsPosInf(Abs(operand));
126 });
127 }
128
IsNan(XlaOp operand)129 XlaOp IsNan(XlaOp operand) {
130 auto& b = *operand.builder();
131 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
132 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNan", operand));
133 return Ne(operand, operand);
134 });
135 }
136
IsNegZero(XlaOp operand)137 XlaOp IsNegZero(XlaOp operand) {
138 auto& b = *operand.builder();
139 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
140 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNegZero", operand));
141 TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
142
143 // The bitwise representation of -0 in bfloat16 and IEEE 754 is 0x80...0
144 // (sign bit on, all other bits off).
145 switch (shape.element_type()) {
146 case F64:
147 return Eq(BitcastConvertType(operand, U64),
148 ConstantR0WithType(&b, U64, uint64{1} << 63));
149 case F32:
150 return Eq(BitcastConvertType(operand, U32),
151 ConstantR0WithType(&b, U32, uint32{1} << 31));
152 case F16:
153 case BF16:
154 // Not all XLA backends handle U16 well, so we convert to F32/U32.
155 // TODO(jlebar): It would be nice if we could stay in (B)F16/U16 for
156 // backends that *do* support it.
157 return Eq(BitcastConvertType(ConvertElementType(operand, F32), U32),
158 ConstantR0WithType(&b, U32, uint32{1} << 31));
159 default:
160 LOG(FATAL) << "Expected real fp type.";
161 }
162 });
163 }
164
Square(XlaOp operand)165 XlaOp Square(XlaOp operand) { return operand * operand; }
166
Reciprocal(XlaOp operand)167 XlaOp Reciprocal(XlaOp operand) { return ScalarLike(operand, 1.0) / operand; }
168
169 // Computes an approximation of the error function complement (1 - erf(x)).
170 //
171 // Precondition: abs(x) >= 1. Otherwise, use ErfImpl.
172 //
173 // This follows Cephes's f32 implementation of erfc.
ErfcImpl32(XlaOp x)174 static XlaOp ErfcImpl32(XlaOp x) {
175 // Coefficients for erfc(f32), from Cephes.
176 const double kMaxlog = 88.72283905206835;
177 // erfc(x) = exp(-x^2) P(1/x^2), 1 < x < 2
178 static const std::array<float, 9> kErfcPCoefficient{
179 +2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1,
180 -5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1,
181 +3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1,
182 };
183 // erfc(x) = exp(-x^2) R(1/x^2), 2 <= x < kMaxlog
184 static const std::array<float, 8> kErfcRCoefficient{
185 -1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0,
186 +2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1,
187 -2.820767439740514E-1, +5.641895067754075E-1,
188 };
189 XlaOp abs_x = Abs(x);
190 XlaOp z = Exp(-x * x);
191 XlaOp q = ScalarLike(x, 1) / abs_x;
192 XlaOp y = q * q;
193 XlaOp p = Select(Lt(abs_x, ScalarLike(x, 2.0)),
194 EvaluatePolynomial<float>(y, kErfcPCoefficient),
195 EvaluatePolynomial<float>(y, kErfcRCoefficient));
196 y = z * q * p;
197 XlaOp y_clamp = Select(Lt(z, ScalarLike(x, -kMaxlog)), ScalarLike(x, 0), y);
198 return Select(Lt(x, ScalarLike(x, 0)), ScalarLike(x, 2.0) - y_clamp, y_clamp);
199 }
200
201 // Compute a polynomial approximation of the error function.
202 //
203 // Precondition: abs(x) <= 1. Otherwise, use ErfcImpl.
204 //
205 // This follows Cephes's f32 implementation of erf.
ErfImpl32Cephes(XlaOp x)206 static XlaOp ErfImpl32Cephes(XlaOp x) {
207 // Coefficients for by erf(f32), from Cephes.
208 //
209 // erf(x) = x P(x^2), 0 < x < 1
210 static const std::array<float, 7> kErfTCoefficient{
211 +7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3,
212 -2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1,
213 +1.128379165726710E+0,
214 };
215 return x * EvaluatePolynomial<float>(x * x, kErfTCoefficient);
216 }
217
ErfcImpl64(XlaOp x)218 static XlaOp ErfcImpl64(XlaOp x) {
219 // Coefficients for erfc(f64), from Cephes.
220 const double kMaxlog = 7.09782712893383996843E2;
221 // erfc(x) = exp(-x^2) P(|x|) / Q(|x|), 1 < x < 8
222 static const std::array<double, 9> kErfcPCoefficient{
223 2.46196981473530512524E-10, 5.64189564831068821977E-1,
224 7.46321056442269912687E0, 4.86371970985681366614E1,
225 1.96520832956077098242E2, 5.26445194995477358631E2,
226 9.34528527171957607540E2, 1.02755188689515710272E3,
227 5.57535335369399327526E2};
228 static const std::array<double, 9> kErfcQCoefficient{
229 1.00000000000000000000E0, 1.32281951154744992508E1,
230 8.67072140885989742329E1, 3.54937778887819891062E2,
231 9.75708501743205489753E2, 1.82390916687909736289E3,
232 2.24633760818710981792E3, 1.65666309194161350182E3,
233 5.57535340817727675546E2};
234
235 // erfc(x) = exp(-x^2) R(|x|) / S(|x|), 8 <= x < kMaxlog
236 static const std::array<double, 6> kErfcRCoefficient{
237 5.64189583547755073984E-1, 1.27536670759978104416E0,
238 5.01905042251180477414E0, 6.16021097993053585195E0,
239 7.40974269950448939160E0, 2.97886665372100240670E0};
240 static const std::array<double, 7> kErfcSCoefficient{
241 1.00000000000000000000E0, 2.26052863220117276590E0,
242 9.39603524938001434673E0, 1.20489539808096656605E1,
243 1.70814450747565897222E1, 9.60896809063285878198E0,
244 3.36907645100081516050E0};
245
246 XlaOp z = -x * x;
247 XlaOp abs_x = Abs(x);
248 XlaOp y =
249 Select(Lt(abs_x, ScalarLike(x, 8.0)),
250 Exp(z) * EvaluatePolynomial<double>(abs_x, kErfcPCoefficient) /
251 EvaluatePolynomial<double>(abs_x, kErfcQCoefficient),
252 Exp(z) * EvaluatePolynomial<double>(abs_x, kErfcRCoefficient) /
253 EvaluatePolynomial<double>(abs_x, kErfcSCoefficient));
254 XlaOp y_clamp = Select(Lt(z, ScalarLike(x, -kMaxlog)), ScalarLike(x, 0), y);
255 return Select(Lt(x, ScalarLike(x, 0)), ScalarLike(x, 2.0) - y_clamp, y_clamp);
256 }
257
258 // Compute a polynomial approximation of the error function.
259 //
260 // Precondition: abs(x) <= 1. Otherwise, use ErfcImpl.
ErfImpl64(XlaOp x)261 static XlaOp ErfImpl64(XlaOp x) {
262 // Coefficients for by erf(f64), from Cephes.
263 //
264 // erf(x) = x T(x^2) / U(x^2), 0 < x < 1
265 static std::array<double, 5> kErfTCoefficient{
266 9.60497373987051638749E0, 9.00260197203842689217E1,
267 2.23200534594684319226E3, 7.00332514112805075473E3,
268 5.55923013010394962768E4};
269 static std::array<double, 6> kErfUCoefficient{
270 1.00000000000000000000E0, 3.35617141647503099647E1,
271 5.21357949780152679795E2, 4.59432382970980127987E3,
272 2.26290000613890934246E4, 4.92673942608635921086E4};
273 XlaOp z = x * x;
274 return x * EvaluatePolynomial<double>(z, kErfTCoefficient) /
275 EvaluatePolynomial<double>(z, kErfUCoefficient);
276 }
277
Erfc(XlaOp x)278 XlaOp Erfc(XlaOp x) {
279 auto& b = *x.builder();
280 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
281 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erfc", x));
282 TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x));
283 // erfc(x) =
284 // erfc_impl(x) if x > 1
285 // 1 - erf_impl(x) otherwise
286 if (shape.element_type() == F64) {
287 return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl64(x),
288 ScalarLike(x, 1) - ErfImpl64(x));
289 }
290 // Erf(c)Impl don't have enough precision when run with bf16 intermediates
291 // (not surprising!), so upcast to f32 in this case.
292 return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) {
293 return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl32(x),
294 ScalarLike(x, 1) - ErfImpl32Cephes(x));
295 });
296 });
297 }
298
299 // Compute a polynomial approximation of the error function.
300 // This is the same approximation used by Eigen.
ErfImpl32(XlaOp x)301 static XlaOp ErfImpl32(XlaOp x) {
302 static const std::array<float, 7> kAlpha{
303 -2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
304 -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
305 -1.60960333262415e-02f,
306 };
307
308 static const std::array<float, 5> kBeta{
309 -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
310 -7.37332916720468e-03f, -1.42647390514189e-02f,
311 };
312
313 x = Clamp(ScalarLike(x, -4.f), x, ScalarLike(x, 4.f));
314 auto x2 = x * x;
315 return x * EvaluatePolynomial<float>(x2, kAlpha) /
316 EvaluatePolynomial<float>(x2, kBeta);
317 }
318
Erf(XlaOp x)319 XlaOp Erf(XlaOp x) {
320 auto& b = *x.builder();
321 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
322 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erf", x));
323 TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x));
324 // erf(x) =
325 // erf_impl(x) if x < 1
326 // 1 - erfc_impl(x) otherwise
327 if (shape.element_type() == F64) {
328 return Select(Lt(Abs(x), ScalarLike(x, 1)), ErfImpl64(x),
329 ScalarLike(x, 1) - ErfcImpl64(x));
330 }
331 // Erf(c)Impl don't have enough precision when run with bf16 intermediates
332 // (not surprising!), so upcast to f32 in this case.
333 return DoWithUpcastToF32(x, {BF16, F16},
334 [](XlaOp x) { return ErfImpl32(x); });
335 });
336 }
337
338 namespace {
339
340 // Approximation for the inverse error function from
341 // Giles, M., "Approximating the erfinv function".
342 // The approximation has the form:
343 // w = -log((1 - x) * (1 + x))
344 // if ( w < 5 ) {
345 // w = w - 2.5
346 // p = sum_{i=1}^n lq[i]*w^i
347 // } else {
348 // w = sqrt(w) - 3
349 // p = sum_{i=1}^n gq[i]*w^i
350 // }
351 // return p*x
ErfInv32(XlaOp x)352 XlaOp ErfInv32(XlaOp x) {
353 constexpr int kDegree = 9;
354 constexpr std::array<float, 9> w_less_than_5_constants = {
355 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
356 -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
357 -0.00417768164f, 0.246640727f, 1.50140941f};
358 constexpr std::array<float, 9> w_greater_than_5_constants = {
359 -0.000200214257f, 0.000100950558f, 0.00134934322f,
360 -0.00367342844f, 0.00573950773f, -0.0076224613f,
361 0.00943887047f, 1.00167406f, 2.83297682f};
362
363 // Compute logarithm of (1+arg) using log1p(arg) which is more precise than
364 // log(1+arg) when arg is close to zero. For more details, see
365 // https://en.cppreference.com/w/cpp/numeric/math/log1p
366 auto w = -Log1p(-x * x);
367
368 auto lt = Lt(w, ScalarLike(x, 5.0));
369 auto coefficient = [&](int i) {
370 return Select(lt, FullLike(x, w_less_than_5_constants[i]),
371 FullLike(x, w_greater_than_5_constants[i]));
372 };
373 w = Select(lt, w - ScalarLike(x, 2.5), Sqrt(w) - ScalarLike(x, 3.0));
374 auto p = coefficient(0);
375 for (int i = 1; i < kDegree; ++i) {
376 p = coefficient(i) + p * w;
377 }
378
379 // Result modulo edge cases.
380 XlaOp result = p * x;
381
382 // Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is
383 // indeterminate, and can give nan or -/+inf.)
384 auto& b = *x.builder();
385 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
386 TF_ASSIGN_OR_RETURN(Shape shape, b.GetShape(x));
387 return Select(Eq(Abs(x), ScalarLike(x, 1)),
388 x * MaxValue(&b, shape.element_type()), result);
389 });
390 }
391
ErfInv64(XlaOp x)392 XlaOp ErfInv64(XlaOp x) {
393 constexpr std::array<double, 23> w_less_than_6_25_constants = {
394 -3.6444120640178196996e-21, -1.685059138182016589e-19,
395 1.2858480715256400167e-18, 1.115787767802518096e-17,
396 -1.333171662854620906e-16, 2.0972767875968561637e-17,
397 6.6376381343583238325e-15, -4.0545662729752068639e-14,
398 -8.1519341976054721522e-14, 2.6335093153082322977e-12,
399 -1.2975133253453532498e-11, -5.4154120542946279317e-11,
400 1.051212273321532285e-09, -4.1126339803469836976e-09,
401 -2.9070369957882005086e-08, 4.2347877827932403518e-07,
402 -1.3654692000834678645e-06, -1.3882523362786468719e-05,
403 0.0001867342080340571352, -0.00074070253416626697512,
404 -0.0060336708714301490533, 0.24015818242558961693,
405 1.6536545626831027356};
406 constexpr std::array<double, 19> w_less_than_16_constants = {
407 2.2137376921775787049e-09, 9.0756561938885390979e-08,
408 -2.7517406297064545428e-07, 1.8239629214389227755e-08,
409 1.5027403968909827627e-06, -4.013867526981545969e-06,
410 2.9234449089955446044e-06, 1.2475304481671778723e-05,
411 -4.7318229009055733981e-05, 6.8284851459573175448e-05,
412 2.4031110387097893999e-05, -0.0003550375203628474796,
413 0.00095328937973738049703, -0.0016882755560235047313,
414 0.0024914420961078508066, -0.0037512085075692412107,
415 0.005370914553590063617, 1.0052589676941592334,
416 3.0838856104922207635,
417 };
418 constexpr std::array<double, 17> w_greater_than_16_constants = {
419 -2.7109920616438573243e-11, -2.5556418169965252055e-10,
420 1.5076572693500548083e-09, -3.7894654401267369937e-09,
421 7.6157012080783393804e-09, -1.4960026627149240478e-08,
422 2.9147953450901080826e-08, -6.7711997758452339498e-08,
423 2.2900482228026654717e-07, -9.9298272942317002539e-07,
424 4.5260625972231537039e-06, -1.9681778105531670567e-05,
425 7.5995277030017761139e-05, -0.00021503011930044477347,
426 -0.00013871931833623122026, 1.0103004648645343977,
427 4.8499064014085844221,
428 };
429 // Compute logarithm of (1+arg) using log1p(arg) which is more precise than
430 // log(1+arg) when arg is close to zero. For more details, see
431 // https://en.cppreference.com/w/cpp/numeric/math/log1p
432 auto w = -Log1p(-x * x);
433
434 auto lt_6_25 = Lt(w, ScalarLike(x, 6.25));
435 auto lt_16 = Lt(w, ScalarLike(x, 16));
436 auto coefficient = [&](int i) {
437 auto c = FullLike(x, w_less_than_6_25_constants[i]);
438 if (i < 19) {
439 c = Select(lt_6_25, c, FullLike(x, w_less_than_16_constants[i]));
440 }
441 if (i < 17) {
442 c = Select(lt_16, c, FullLike(x, w_greater_than_16_constants[i]));
443 }
444 return c;
445 };
446 auto sqrt_w = Sqrt(w);
447 w = Select(lt_6_25, w - ScalarLike(x, 3.125),
448 sqrt_w - Select(lt_16, ScalarLike(x, 3.25), ScalarLike(x, 5.0)));
449 auto p = coefficient(0);
450 for (int i = 1; i < 17; ++i) {
451 p = coefficient(i) + p * w;
452 }
453 for (int i = 17; i < 19; ++i) {
454 p = Select(lt_16, coefficient(i) + p * w, p);
455 }
456 for (int i = 19; i < 23; ++i) {
457 p = Select(lt_6_25, coefficient(i) + p * w, p);
458 }
459 // Result modulo edge cases.
460 XlaOp result = p * x;
461
462 // Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is
463 // indeterminate, and can give nan or -/+inf.)
464 auto& b = *x.builder();
465 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
466 TF_ASSIGN_OR_RETURN(Shape shape, b.GetShape(x));
467 return Select(Eq(Abs(x), ScalarLike(x, 1)),
468 x * MaxValue(&b, shape.element_type()), result);
469 });
470 }
471
472 } // namespace
473
ErfInv(XlaOp x)474 XlaOp ErfInv(XlaOp x) {
475 auto& b = *x.builder();
476 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
477 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("ErfInv", x));
478 TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x));
479 if (shape.element_type() == F64) {
480 return ErfInv64(x);
481 }
482 return DoWithUpcastToF32(x, {BF16, F16},
483 [](XlaOp x) { return ErfInv32(x); });
484 });
485 }
486
namespace {

// Coefficients for the Lanczos approximation of the gamma function. The
// coefficients are uniquely determined by the choice of g and n (kLanczosGamma
// and kLanczosCoefficients.size() + 1). The coefficients below correspond to
// [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and [7,
// 9] seemed to be the least sensitive to the quality of the log function. In
// particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
// for a particularly inaccurate log function.

// The Lanczos parameter g.
constexpr double kLanczosGamma = 7;

// The constant leading term of the Lanczos series.
constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;

// Numerators of the partial-fraction terms of the Lanczos series; entry i is
// divided by (z + i + 1) where it is used in this file.
constexpr std::array<double, 8> kLanczosCoefficients = {
    676.520368121885098567009190444019, -1259.13921672240287047156078755283,
    771.3234287776530788486528258894,   -176.61502916214059906584551354,
    12.507343278686904814458936853,     -0.13857109526572011689554707,
    9.984369578019570859563e-6,         1.50563273514931155834e-7};

}  // namespace
503
504 // Compute the Lgamma function using Lanczos' approximation from "A Precision
505 // Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
506 // series B. Vol. 1:
507 // lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + A(z)
508 // t(z) = z + kLanczosGamma + 1/2
509 // A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k))
Lgamma(XlaOp input)510 XlaOp Lgamma(XlaOp input) {
511 auto do_it = [](XlaOp input) {
512 XlaOp one_half = ScalarLike(input, 0.5);
513 XlaOp one = ScalarLike(input, 1);
514
515 XlaOp pi = ScalarLike(input, M_PI);
516 XlaOp log_pi = ScalarLike(input, std::log(M_PI));
517 XlaOp log_sqrt_two_pi =
518 ScalarLike(input, (std::log(2) + std::log(M_PI)) / 2);
519
520 XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5);
521 XlaOp log_lanczos_gamma_plus_one_half =
522 ScalarLike(input, std::log(kLanczosGamma + 0.5));
523
524 XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff);
525
526 // If the input is less than 0.5 use Euler's reflection formula:
527 // gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
528 XlaOp need_to_reflect = Lt(input, one_half);
529 XlaOp z = Select(need_to_reflect, -input, input - one);
530
531 XlaOp x = base_lanczos_coeff;
532 for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
533 XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]);
534 XlaOp index = ScalarLike(input, i);
535 x = x + lanczos_coefficient / (z + index + one);
536 }
537
538 // To improve accuracy on platforms with less-precise log implementations,
539 // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on
540 // the device.
541 // log(t) = log(kLanczosGamma + 0.5 + z)
542 // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5))
543 XlaOp t = lanczos_gamma_plus_one_half + z;
544 XlaOp log_t = log_lanczos_gamma_plus_one_half +
545 Log1p(z / lanczos_gamma_plus_one_half);
546
547 // Compute the final result (modulo reflection). t(z) may be large, and we
548 // need to be careful not to overflow to infinity in the first term of
549 //
550 // (z + 1/2) * log(t(z)) - t(z).
551 //
552 // Therefore we compute this as
553 //
554 // (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
555 //
556 XlaOp log_y = log_sqrt_two_pi + (z + one_half - t / log_t) * log_t + Log(x);
557
558 // Compute the reflected value, used when x < 0.5:
559 //
560 // lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
561 //
562 // (The abs is because lgamma is the log of the absolute value of the gamma
563 // function.)
564 //
565 // We have to be careful when computing the final term above. gamma(x) goes
566 // to +/-inf at every integer x < 0, and this is controlled by the
567 // sin(pi * x) term. The slope is large, so precision is particularly
568 // important.
569 //
570 // Because abs(sin(pi * x)) has period 1, we can equivalently use
571 // abs(sin(pi * frac(x))), where frac(x) is the fractional part of x. This
572 // is more numerically accurate: It doesn't overflow to inf like pi * x can,
573 // and if x is an integer, it evaluates to 0 exactly, which is significant
574 // because we then take the log of this value, and log(0) is inf.
575 //
576 // We don't have a frac(x) primitive in XLA and computing it is tricky, but
577 // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for
578 // our purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
579 //
580 // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
581 // to 1. To remedy this, we can use the fact that sin(pi * x) in the domain
582 // [0, 1] is symmetric across the line Y=0.5.
583 //
584 XlaOp abs_input = Abs(input);
585 XlaOp abs_frac_input = abs_input - Floor(abs_input);
586 // Convert values of abs_frac_input > 0.5 to (1 - frac_input) to improve
587 // precision of pi * abs_frac_input for values of abs_frac_input close to 1.
588 XlaOp reduced_frac_input =
589 Select(Gt(abs_frac_input, ScalarLike(abs_frac_input, 0.5)),
590 ScalarLike(abs_frac_input, 1) - abs_frac_input, abs_frac_input);
591 XlaOp reflection_denom = Log(Sin(pi * reduced_frac_input));
592
593 // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf,
594 // then it "wins" and the result is +/-inf.
595 XlaOp reflection =
596 Select(IsFinite(reflection_denom), log_pi - reflection_denom - log_y,
597 -reflection_denom);
598 XlaOp result = Select(need_to_reflect, reflection, log_y);
599
600 // lgamma(+/-inf) = +inf.
601 XlaOp inf_bcast = FullLike(input, std::numeric_limits<float>::infinity());
602 return Select(IsInf(input), inf_bcast, result);
603 };
604
605 auto& b = *input.builder();
606 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
607 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Lgamma", input));
608 // F16 and BF16 don't provide sufficient precision for intermediate results
609 // here (although it's better than you might expect!), so do the
610 // computations in F32.
611 return DoWithUpcastToF32(input, {BF16, F16}, do_it);
612 });
613 }
614
615 // Computes an approximation of the lbeta function which is equivalent to
616 // log(abs(Beta(a, b))) but avoids overflow by computing it with lgamma.
Lbeta(XlaOp a,XlaOp b)617 static XlaOp Lbeta(XlaOp a, XlaOp b) {
618 // Beta(a, b) can be computed using Gamma as per
619 // http://dlmf.nist.gov/5.12.E1 as follows:
620 // Beta(a, b) = (Gamma(a) * Gamma(b)) / Gamma(a + b)
621 //
622 // To avoid overflow, we compute in the log domain.
623 //
624 // As per http://dlmf.nist.gov/4.8.E2 we can transform:
625 // Log(a * b)
626 // into:
627 // Log(a) + Log(b)
628 //
629 // Likewise, per https://dlmf.nist.gov/4.8.E4, we can turn:
630 // Log(a - b)
631 // into:
632 // Log(a) - Log(b)
633 //
634 // This means that we can compute Log(Beta(a, b)) by:
635 // Log(Gamma(a)) + Log(Gamma(b)) - Log(Gamma(a + b))
636 return Lgamma(a) + Lgamma(b) - Lgamma(a + b);
637 }
638
639 // Compute the Digamma function using Lanczos' approximation from "A Precision
640 // Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
641 // series B. Vol. 1:
642 // digamma(z + 1) = log(t(z)) + A'(z) / A(z) - kLanczosGamma / t(z)
643 // t(z) = z + kLanczosGamma + 1/2
644 // A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k))
645 // A'(z) = sigma(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
Digamma(XlaOp input)646 XlaOp Digamma(XlaOp input) {
647 auto do_it = [](XlaOp input) {
648 XlaOp zero = ScalarLike(input, 0);
649 XlaOp one_half = ScalarLike(input, 0.5);
650 XlaOp one = ScalarLike(input, 1);
651
652 XlaOp pi = ScalarLike(input, M_PI);
653
654 XlaOp lanczos_gamma = ScalarLike(input, kLanczosGamma);
655 XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5);
656 XlaOp log_lanczos_gamma_plus_one_half =
657 ScalarLike(input, std::log(kLanczosGamma + 0.5));
658
659 XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff);
660
661 // If the input is less than 0.5 use Euler's reflection formula:
662 // digamma(x) = digamma(1 - x) - pi * cot(pi * x)
663 XlaOp need_to_reflect = Lt(input, one_half);
664 XlaOp z = Select(need_to_reflect, -input, input - one);
665
666 XlaOp num = zero;
667 XlaOp denom = base_lanczos_coeff;
668 for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
669 XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]);
670 XlaOp index = ScalarLike(input, i);
671 num = num - lanczos_coefficient / ((z + index + one) * (z + index + one));
672 denom = denom + lanczos_coefficient / (z + index + one);
673 }
674
675 // To improve accuracy on platforms with less-precise log implementations,
676 // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on
677 // the device.
678 // log(t) = log(kLanczosGamma + 0.5 + z)
679 // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5))
680 XlaOp t = lanczos_gamma_plus_one_half + z;
681 XlaOp log_t = log_lanczos_gamma_plus_one_half +
682 Log1p(z / lanczos_gamma_plus_one_half);
683
684 XlaOp y = log_t + num / denom - lanczos_gamma / t;
685
686 // We need to be careful how we compute cot(pi * input) below: For
687 // near-integral values of `input`, pi * input can lose precision.
688 //
689 // Input is already known to be less than 0.5 (otherwise we don't have to
690 // reflect). We shift values smaller than -0.5 into the range [-.5, .5] to
691 // increase precision of pi * input and the resulting cotangent.
692 XlaOp reduced_input = input + Abs(Floor(input + ScalarLike(input, 0.5)));
693 XlaOp reflection =
694 y - pi * Cos(pi * reduced_input) / Sin(pi * reduced_input);
695 XlaOp real_result = Select(need_to_reflect, reflection, y);
696
697 // Digamma has poles at negative integers and zero; return nan for those.
698 return Select(And(Le(input, zero), Eq(input, Floor(input))),
699 FullLike(input, std::numeric_limits<float>::quiet_NaN()),
700 real_result);
701 };
702
703 auto& b = *input.builder();
704 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
705 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Digamma", input));
706 return DoWithUpcastToF32(input, {BF16, F16}, do_it);
707 });
708 }
709
710 // Incomplete gamma functions
711
712 namespace {
713
// Selects which quantity the incomplete-gamma helpers below produce. It is
// used as a non-type template argument.
enum kIgammaMode {
  // The function value itself.
  VALUE,
  // A derivative-based result built from the value and its derivative terms.
  // NOTE(review): presumably the derivative with respect to `a` -- confirm
  // against the callers.
  DERIVATIVE,
  // A second derivative-based result with a different scaling and sign.
  // NOTE(review): presumably the derivative of a gamma sample -- confirm
  // against the callers.
  SAMPLE_DERIVATIVE,
};
715
716 // Helper function for computing Igamma using a power series.
717 template <kIgammaMode mode>
IgammaSeries(XlaOp ax,XlaOp x,XlaOp a,XlaOp enabled,xla::PrimitiveType type)718 XlaOp IgammaSeries(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled,
719 xla::PrimitiveType type) {
720 // vals: (enabled, r, c, ans, x)
721 // 'enabled' is a predication mask that says for which elements we should
722 // execute the loop body. Disabled elements have no effect in the loop body.
723 // TODO(phawkins): in general this isn't an optimal implementation on any
724 // backend. For example, on GPU, we should probably vectorize to the warp
725 // size, and then run independent loops for each warp's worth of
726 // data.
727 auto cond = [&](absl::Span<const XlaOp> vals,
728 XlaBuilder* builder) -> StatusOr<XlaOp> {
729 XlaOp enabled = vals[0];
730 return Any(enabled);
731 };
732 auto body = [&](absl::Span<const XlaOp> vals,
733 XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
734 XlaOp enabled = vals[0];
735 XlaOp r = vals[1];
736 XlaOp c = vals[2];
737 XlaOp ans = vals[3];
738 XlaOp x = vals[4];
739 XlaOp dc_da = vals[5];
740 XlaOp dans_da = vals[6];
741
742 r = r + ScalarLike(r, 1);
743 dc_da = dc_da * (x / r) + (ScalarLike(r, -1) * c * x) / (r * r);
744 dans_da = dans_da + dc_da;
745 c = c * (x / r);
746 ans = ans + c;
747 XlaOp conditional;
748 if (mode == VALUE) {
749 conditional = And(enabled, Gt(c / ans, Epsilon(builder, type)));
750 } else {
751 conditional =
752 And(enabled, Gt(Abs(dc_da / dans_da), Epsilon(builder, type)));
753 }
754
755 return std::vector<XlaOp>{
756 conditional,
757 Select(enabled, r, vals[1]),
758 Select(enabled, c, vals[2]),
759 Select(enabled, ans, vals[3]),
760 Select(enabled, x, vals[4]),
761 Select(enabled, dc_da, vals[5]),
762 Select(enabled, dans_da, vals[6]),
763 };
764 };
765 auto& b = *ax.builder();
766 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
767 std::vector<XlaOp> vals = {
768 enabled, a, FullLike(a, 1), FullLike(a, 1), x, FullLike(a, 0),
769 FullLike(a, 0),
770 };
771
772 TF_ASSIGN_OR_RETURN(vals, WhileLoopHelper(cond, body, vals, "igamma", &b));
773 XlaOp ans = vals[3];
774 XlaOp dans_da = vals[6];
775 if (mode == VALUE) {
776 return (ans * ax) / a;
777 }
778
779 XlaOp dlogax_da = Log(x) - Digamma(a + ScalarLike(a, 1));
780
781 switch (mode) {
782 case DERIVATIVE:
783 return ax * (ans * dlogax_da + dans_da) / a;
784 case SAMPLE_DERIVATIVE:
785 default:
786 return -(dans_da + ans * dlogax_da) * x / a;
787 }
788 });
789 }
790
791 // Helper function for computing Igammac using a continued fraction.
792 template <kIgammaMode mode>
IgammacContinuedFraction(XlaOp ax,XlaOp x,XlaOp a,XlaOp enabled,xla::PrimitiveType type)793 XlaOp IgammacContinuedFraction(XlaOp ax, XlaOp x, XlaOp a, XlaOp enabled,
794 xla::PrimitiveType type) {
795 // vals: enabled, ans, t, y, z, c, pkm1, qkm1, pkm2, qkm2
796 auto cond = [&](absl::Span<const XlaOp> vals,
797 XlaBuilder* builder) -> StatusOr<XlaOp> {
798 XlaOp enabled = vals[0];
799 XlaOp c = vals[5];
800 return And(Lt(c, ScalarLike(c, 2000)), Any(enabled));
801 };
802 auto body = [&](absl::Span<const XlaOp> vals,
803 XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
804 XlaOp enabled = vals[0];
805 XlaOp ans = vals[1];
806 XlaOp t = vals[2];
807 XlaOp y = vals[3];
808 XlaOp z = vals[4];
809 XlaOp c = vals[5];
810 XlaOp pkm1 = vals[6];
811 XlaOp qkm1 = vals[7];
812 XlaOp pkm2 = vals[8];
813 XlaOp qkm2 = vals[9];
814
815 XlaOp dpkm2_da = vals[10];
816 XlaOp dqkm2_da = vals[11];
817 XlaOp dpkm1_da = vals[12];
818 XlaOp dqkm1_da = vals[13];
819 XlaOp dans_da = vals[14];
820
821 c = c + ScalarLike(c, 1);
822 y = y + ScalarLike(y, 1);
823 z = z + ScalarLike(z, 2);
824 XlaOp yc = y * c;
825 XlaOp pk = pkm1 * z - pkm2 * yc;
826 XlaOp qk = qkm1 * z - qkm2 * yc;
827 XlaOp qk_is_nonzero = Ne(qk, ScalarLike(qk, 0));
828 XlaOp r = pk / qk;
829
830 t = Select(qk_is_nonzero, Abs((ans - r) / r), FullLike(t, 1));
831 ans = Select(qk_is_nonzero, r, ans);
832
833 XlaOp dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c;
834 XlaOp dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c;
835 XlaOp dans_da_new =
836 Select(qk_is_nonzero, (dpk_da - ans * dqk_da) / qk, dans_da);
837 XlaOp grad_conditional =
838 Select(qk_is_nonzero, Abs(dans_da_new - dans_da), FullLike(dans_da, 1));
839
840 pkm2 = pkm1;
841 pkm1 = pk;
842 qkm2 = qkm1;
843 qkm1 = qk;
844
845 dpkm2_da = dpkm1_da;
846 dqkm2_da = dqkm1_da;
847 dpkm1_da = dpk_da;
848 dqkm1_da = dqk_da;
849
850 XlaOp rescale = Gt(Abs(pk), Reciprocal(Epsilon(builder, type)));
851 pkm2 = Select(rescale, pkm2 * Epsilon(builder, type), pkm2);
852 pkm1 = Select(rescale, pkm1 * Epsilon(builder, type), pkm1);
853 qkm2 = Select(rescale, qkm2 * Epsilon(builder, type), qkm2);
854 qkm1 = Select(rescale, qkm1 * Epsilon(builder, type), qkm1);
855
856 dpkm2_da = Select(rescale, dpkm2_da * Epsilon(builder, type), dpkm2_da);
857 dqkm2_da = Select(rescale, dqkm2_da * Epsilon(builder, type), dqkm2_da);
858 dpkm1_da = Select(rescale, dpkm1_da * Epsilon(builder, type), dpkm1_da);
859 dqkm1_da = Select(rescale, dqkm1_da * Epsilon(builder, type), dqkm1_da);
860
861 XlaOp conditional;
862 if (mode == VALUE) {
863 conditional = And(enabled, Gt(t, Epsilon(builder, type)));
864 } else {
865 conditional = And(enabled, Gt(grad_conditional, Epsilon(builder, type)));
866 }
867
868 return std::vector<XlaOp>{conditional,
869 Select(enabled, ans, vals[1]),
870 Select(enabled, t, vals[2]),
871 Select(enabled, y, vals[3]),
872 Select(enabled, z, vals[4]),
873 c,
874 Select(enabled, pkm1, vals[6]),
875 Select(enabled, qkm1, vals[7]),
876 Select(enabled, pkm2, vals[8]),
877 Select(enabled, qkm2, vals[9]),
878 Select(enabled, dpkm2_da, vals[10]),
879 Select(enabled, dqkm2_da, vals[11]),
880 Select(enabled, dpkm1_da, vals[12]),
881 Select(enabled, dqkm1_da, vals[13]),
882 Select(enabled, dans_da_new, vals[14])};
883 };
884
885 auto& b = *ax.builder();
886 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
887 XlaOp y = ScalarLike(a, 1) - a;
888 XlaOp z = x + y + ScalarLike(x, 1);
889 XlaOp c = ScalarLike(x, 0);
890 XlaOp pkm2 = FullLike(x, 1);
891 XlaOp qkm2 = x;
892 XlaOp pkm1 = x + ScalarLike(x, 1);
893 XlaOp qkm1 = z * x;
894 XlaOp ans = pkm1 / qkm1;
895 XlaOp t = FullLike(x, 1);
896 XlaOp dpkm2_da = FullLike(x, 0);
897 XlaOp dqkm2_da = FullLike(x, 0);
898 XlaOp dpkm1_da = FullLike(x, 0);
899 XlaOp dqkm1_da = -x;
900 XlaOp dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1;
901 std::vector<XlaOp> vals = {enabled, ans, t, y, z,
902 c, pkm1, qkm1, pkm2, qkm2,
903 dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da};
904
905 TF_ASSIGN_OR_RETURN(vals, WhileLoopHelper(cond, body, vals, "igammac", &b));
906 ans = vals[1];
907 if (mode == VALUE) {
908 return ans * ax;
909 }
910
911 dans_da = vals[14];
912 XlaOp dlogax_da = Log(x) - Digamma(a);
913
914 switch (mode) {
915 case DERIVATIVE:
916 return ax * (ans * dlogax_da + dans_da);
917 case SAMPLE_DERIVATIVE:
918 default:
919 return -(dans_da + ans * dlogax_da) * x;
920 }
921 });
922 }
923
924 } // namespace
925
Igamma(XlaOp a,XlaOp x)926 XlaOp Igamma(XlaOp a, XlaOp x) {
927 auto& b = *a.builder();
928 auto doit = [&b](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp {
929 XlaOp is_nan = Or(IsNan(a), IsNan(x));
930 XlaOp x_is_zero = Eq(x, ScalarLike(x, 0));
931 XlaOp domain_error = Or(Lt(x, ScalarLike(x, 0)), Le(a, ScalarLike(a, 0)));
932 XlaOp use_igammac = And(Gt(x, ScalarLike(x, 1)), Gt(x, a));
933 XlaOp ax = a * Log(x) - x - Lgamma(a);
934 XlaOp underflow = Lt(ax, -Log(MaxFiniteValue(&b, type)));
935 ax = Exp(ax);
936 XlaOp enabled = Not(Or(Or(Or(x_is_zero, domain_error), underflow), is_nan));
937 const double nan = std::numeric_limits<double>::quiet_NaN();
938 XlaOp output = Select(
939 use_igammac,
940 ScalarLike(a, 1) - IgammacContinuedFraction<VALUE>(
941 ax, x, a, And(enabled, use_igammac), type),
942 IgammaSeries<VALUE>(ax, x, a, And(enabled, Not(use_igammac)), type));
943 output = Select(x_is_zero, ZerosLike(output), output);
944 output = Select(Or(domain_error, is_nan), FullLike(a, nan), output);
945 return output;
946 };
947 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
948 TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a));
949 TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x));
950 if (a_shape != x_shape) {
951 return InvalidArgument(
952 "Arguments to Igamma must have equal shapes and types; got %s and %s",
953 a_shape.ToString(), x_shape.ToString());
954 }
955 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a));
956 PrimitiveType a_x_type = a_shape.element_type();
957 bool needs_upcast =
958 a_shape.element_type() == F16 || a_shape.element_type() == BF16;
959
960 if (needs_upcast) {
961 a = ConvertElementType(a, F32);
962 x = ConvertElementType(x, F32);
963 a_x_type = F32;
964 }
965 XlaOp result = doit(a, x, a_x_type);
966 if (needs_upcast) {
967 result = ConvertElementType(result, a_shape.element_type());
968 }
969 return result;
970 });
971 }
972
IgammaGradA(XlaOp a,XlaOp x)973 XlaOp IgammaGradA(XlaOp a, XlaOp x) {
974 auto& b = *a.builder();
975 auto doit = [&b](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp {
976 XlaOp is_nan = Or(IsNan(a), IsNan(x));
977 XlaOp x_is_zero = Eq(x, ScalarLike(x, 0));
978 XlaOp domain_error = Or(Lt(x, ScalarLike(x, 0)), Le(a, ScalarLike(a, 0)));
979 XlaOp use_igammac = And(Gt(x, ScalarLike(x, 1)), Gt(x, a));
980 XlaOp ax = a * Log(x) - x - Lgamma(a);
981 XlaOp underflow = Lt(ax, -Log(MaxFiniteValue(&b, type)));
982 ax = Exp(ax);
983 XlaOp enabled = Not(Or(Or(Or(x_is_zero, domain_error), underflow), is_nan));
984 const double nan = std::numeric_limits<double>::quiet_NaN();
985 XlaOp output = Select(use_igammac,
986 -IgammacContinuedFraction<DERIVATIVE>(
987 ax, x, a, And(enabled, use_igammac), type),
988 IgammaSeries<DERIVATIVE>(
989 ax, x, a, And(enabled, Not(use_igammac)), type));
990 output = Select(x_is_zero, ZerosLike(output), output);
991 output = Select(Or(domain_error, is_nan), FullLike(a, nan), output);
992 return output;
993 };
994 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
995 TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a));
996 TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x));
997 if (a_shape != x_shape) {
998 return InvalidArgument(
999 "Arguments to IgammaGradA must have equal shapes and types; got %s "
1000 "and %s",
1001 a_shape.ToString(), x_shape.ToString());
1002 }
1003 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a));
1004 bool needs_upcast =
1005 a_shape.element_type() == F16 || a_shape.element_type() == BF16;
1006
1007 if (needs_upcast) {
1008 a = ConvertElementType(a, F32);
1009 x = ConvertElementType(x, F32);
1010 }
1011 XlaOp result = doit(a, x, a_shape.element_type());
1012 if (needs_upcast) {
1013 result = ConvertElementType(result, a_shape.element_type());
1014 }
1015 return result;
1016 });
1017 }
1018
1019 // Gradient of Gamma sample from Gamma(a, 1) with respect to `a`.
RandomGammaGrad(XlaOp a,XlaOp x)1020 XlaOp RandomGammaGrad(XlaOp a, XlaOp x) {
1021 auto& b = *a.builder();
1022 auto doit = [&b](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp {
1023 XlaOp is_nan = Or(IsNan(a), IsNan(x));
1024 XlaOp x_is_zero = Eq(x, ScalarLike(x, 0));
1025 XlaOp domain_error = Or(Lt(x, ScalarLike(x, 0)), Le(a, ScalarLike(a, 0)));
1026 XlaOp use_igammac = And(Gt(x, ScalarLike(x, 1)), Gt(x, a));
1027 XlaOp ax = a * Log(x) - x - Lgamma(a);
1028 XlaOp underflow = Lt(ax, -Log(MaxFiniteValue(&b, type)));
1029 ax = Exp(ax);
1030 XlaOp enabled = Not(Or(Or(Or(x_is_zero, domain_error), underflow), is_nan));
1031 const double nan = std::numeric_limits<double>::quiet_NaN();
1032 XlaOp output = Select(use_igammac,
1033 -IgammacContinuedFraction<SAMPLE_DERIVATIVE>(
1034 ax, x, a, And(enabled, use_igammac), type),
1035 IgammaSeries<SAMPLE_DERIVATIVE>(
1036 ax, x, a, And(enabled, Not(use_igammac)), type));
1037 output = Select(x_is_zero, ZerosLike(output), output);
1038 output = Select(Or(domain_error, is_nan), FullLike(a, nan), output);
1039 return output;
1040 };
1041 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1042 TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a));
1043 TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x));
1044 if (a_shape != x_shape) {
1045 return InvalidArgument(
1046 "Arguments to RandomGammaGrad must have equal shapes and types; got "
1047 "%s and %s",
1048 a_shape.ToString(), x_shape.ToString());
1049 }
1050 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("RandomGammaGrad", a));
1051 bool needs_upcast =
1052 a_shape.element_type() == F16 || a_shape.element_type() == BF16;
1053
1054 if (needs_upcast) {
1055 a = ConvertElementType(a, F32);
1056 x = ConvertElementType(x, F32);
1057 }
1058 XlaOp result = doit(a, x, a_shape.element_type());
1059 if (needs_upcast) {
1060 result = ConvertElementType(result, a_shape.element_type());
1061 }
1062 return result;
1063 });
1064 }
1065
Igammac(XlaOp a,XlaOp x)1066 XlaOp Igammac(XlaOp a, XlaOp x) {
1067 auto& b = *a.builder();
1068 auto doit = [&b](XlaOp a, XlaOp x, PrimitiveType type) -> XlaOp {
1069 XlaOp out_of_range = Or(Le(x, ScalarLike(x, 0)), Le(a, ScalarLike(a, 0)));
1070 XlaOp use_igamma = Or(Lt(x, ScalarLike(x, 1)), Lt(x, a));
1071 XlaOp ax = a * Log(x) - x - Lgamma(a);
1072 XlaOp underflow = Lt(ax, -Log(MaxFiniteValue(&b, type)));
1073 XlaOp enabled = Not(Or(out_of_range, underflow));
1074 ax = Exp(ax);
1075 XlaOp result =
1076 Select(use_igamma,
1077 ScalarLike(a, 1) - IgammaSeries<VALUE>(
1078 ax, x, a, And(enabled, use_igamma), type),
1079 IgammacContinuedFraction<VALUE>(
1080 ax, x, a, And(enabled, Not(use_igamma)), type));
1081 return Select(out_of_range, FullLike(a, 1), result);
1082 };
1083 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1084 TF_ASSIGN_OR_RETURN(auto a_shape, b.GetShape(a));
1085 TF_ASSIGN_OR_RETURN(auto x_shape, b.GetShape(x));
1086 if (a_shape != x_shape) {
1087 return InvalidArgument(
1088 "Arguments to Igammac must have equal shapes and types; "
1089 "got %s and %s",
1090 a_shape.ToString(), x_shape.ToString());
1091 }
1092 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igammac", a));
1093 PrimitiveType a_x_type = a_shape.element_type();
1094 bool needs_upcast =
1095 a_shape.element_type() == F16 || a_shape.element_type() == BF16;
1096
1097 if (needs_upcast) {
1098 a = ConvertElementType(a, F32);
1099 x = ConvertElementType(x, F32);
1100 a_x_type = F32;
1101 }
1102 XlaOp result = doit(a, x, a_x_type);
1103 if (needs_upcast) {
1104 result = ConvertElementType(result, a_shape.element_type());
1105 }
1106 return result;
1107 });
1108 }
1109 // Implements Banker's rounding: numbers that are equidistant between two
1110 // integers are rounded towards even.
RoundToEven(XlaOp x)1111 XlaOp RoundToEven(XlaOp x) {
1112 auto& b = *x.builder();
1113 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1114 // Reject non-real non-fp inputs (What does it even mean to round a complex
1115 // number? Do you round each component equally? In that case, you should
1116 // just ask for that explicitly.)
1117 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("RoundToEven", x));
1118
1119 auto half = ScalarLike(x, 0.5);
1120 auto one = ScalarLike(x, 1.0);
1121 auto two = ScalarLike(x, 2.0);
1122
1123 auto round_val = Floor(x);
1124 auto fraction = x - round_val;
1125 auto nearest_even_int = round_val - two * Floor(half * x);
1126 auto is_odd = Eq(nearest_even_int, one);
1127 return Select(Or(Gt(fraction, half), And(Eq(fraction, half), is_odd)),
1128 round_val + one, round_val);
1129 });
1130 }
1131
1132 // Trigonometric functions.
1133
1134 // acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1
1135 // pi if x == -1
1136 // For complex:
1137 // acos(x) = -(i * log(x + i * sqrt((1 + x) * (1 - x))))
Acos(XlaOp x)1138 XlaOp Acos(XlaOp x) {
1139 XlaBuilder* b = x.builder();
1140 return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1141 TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
1142
1143 if (primitive_util::IsComplexType(shape.element_type())) {
1144 auto one = ScalarLike(x, 1);
1145 auto imag_one = Complex(
1146 Zero(b, primitive_util::ComplexComponentType(shape.element_type())),
1147 One(b, primitive_util::ComplexComponentType(shape.element_type())));
1148
1149 auto result =
1150 Neg(imag_one * Log(x + imag_one * Sqrt((one + x) * (one - x))));
1151 return result;
1152 }
1153 return Select(Ne(x, FullLike(x, -1)),
1154 ScalarLike(x, 2.0) * Atan2(Sqrt(ScalarLike(x, 1.0) - x * x),
1155 ScalarLike(x, 1.0) + x),
1156 FullLike(x, M_PI));
1157 });
1158 }
1159
1160 // asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
Asin(XlaOp x)1161 XlaOp Asin(XlaOp x) {
1162 return ScalarLike(x, 2.0) *
1163 Atan2(x, ScalarLike(x, 1.0) + Sqrt(ScalarLike(x, 1.0) - x * x));
1164 }
1165
Atan(XlaOp x)1166 XlaOp Atan(XlaOp x) { return Atan2(x, ScalarLike(x, 1.0)); }
1167
Tan(XlaOp x)1168 XlaOp Tan(XlaOp x) {
1169 return DoWithUpcastToF32(x, {F16}, [](XlaOp x) { return Sin(x) / Cos(x); });
1170 }
1171
1172 // Hyperbolic trigonometric functions.
1173
1174 // acosh(x) = log(x + sqrt(x^2 - 1)) if x >= -1
1175 // = log(x + sqrt((x+1)*(x-1)))
1176 // acosh(x) = nan if x < -1
1177 //
1178 // If x^2 will overflow, we approximate sqrt(x^2 - 1) == x and compute as
1179 // log(2*x) = log(2) + log(x). (Note this works because negative x never
1180 // overflows; x < -1 simply yields nan. This is quite different than asinh!)
Acosh(XlaOp x)1181 XlaOp Acosh(XlaOp x) {
1182 XlaBuilder* b = x.builder();
1183 return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1184 TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
1185
1186 auto one = ScalarLike(x, 1);
1187 auto neg_one = ScalarLike(x, -1);
1188 auto nan = FullLike(x, std::numeric_limits<float>::quiet_NaN());
1189
1190 // return
1191 //
1192 // nan if x < -1
1193 // log(x) + log(2) if x >= sqrt_max_value
1194 // log(x + sqrt((x+1)*(x-1))) otherwise
1195 //
1196 // TODO(jlebar): For now, we ignore the question of overflow if x is a
1197 // complex type, because we don't yet have exhaustive tests for complex trig
1198 // functions.
1199 auto naive_result = Log(x + Sqrt((x + one) * (x - one)));
1200 if (primitive_util::IsComplexType(shape.element_type())) {
1201 return naive_result;
1202 }
1203 auto overflow_result = Log(x) + Log(ScalarLike(x, 2));
1204
1205 auto sqrt_max_value = Sqrt(MaxFiniteValue(b, shape.element_type()));
1206 return Select(Lt(x, neg_one), nan,
1207 Select(Ge(x, sqrt_max_value), overflow_result, naive_result));
1208 });
1209 }
1210
1211 // asinh(x) = log(x + sqrt(x^2 + 1))
1212 //
1213 // If x^2 will overflow and x is positive, we can approximate x + sqrt(x^2 + 1)
1214 // as 2*x and return log(2) + log(x).
1215 //
1216 // If x is negative, the above would give us some trouble; we can't approximate
1217 // the result as x + abs(x) = 0! But we're saved by the fact that asinh(-x) =
1218 // -asinh(x).
Asinh(XlaOp x)1219 XlaOp Asinh(XlaOp x) {
1220 XlaBuilder* b = x.builder();
1221 auto do_it = [&](XlaOp x) -> StatusOr<XlaOp> {
1222 TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
1223 auto one = ScalarLike(x, 1);
1224
1225 // Let a = abs(x). Compute
1226 //
1227 // y = log(a + sqrt(a*a + 1)) if a < sqrt_max_value, or
1228 // y = log(a) + log(2) otherwise
1229 //
1230 // and then return
1231 //
1232 // y * sign(x).
1233 //
1234 // TODO(jlebar): For now, we ignore the question of overflow if x is a
1235 // complex type, because we don't yet have exhaustive tests for complex trig
1236 // functions.
1237 if (primitive_util::IsComplexType(shape.element_type())) {
1238 return Log(x + Sqrt(x * x + one));
1239 }
1240 // For small x, sqrt(x**2 + 1) will evaluate to 1 due to floating point
1241 // arithmetic. However, we would like to retain the low order term of this,
1242 // which is around 0.5 * x**2 using a binomial expansion.
1243 // Let z = sqrt(a**2 + 1)
1244 // log(a + sqrt(a**2 + 1)) =
1245 // log((a + sqrt(a**2 + 1)) * (1 + sqrt(a**2 + 1)) / (1 + sqrt(a**2 + 1))) =
1246 // log((a + a**2 + 1 + a * z + z) / (1 + z)) =
1247 // log(1 + a + a**2 / (1 + z)) =
1248 // log(1 + a + a ** 2 / (1 + sqrt(a**2 + 1)))
1249 // This rewrite retains the lower order term.
1250 auto a = Abs(x);
1251 auto small_result = Log1p(a + a * a / (one + Sqrt(a * a + one)));
1252 auto naive_result = Log(a + Sqrt(a * a + one));
1253 auto overflow_result = Log(Abs(a)) + Log(ScalarLike(a, 2));
1254 auto sqrt_max_value = Sqrt(MaxFiniteValue(b, shape.element_type()));
1255 return Sign(x) * Select(Ge(a, sqrt_max_value), overflow_result,
1256 Select(Le(a, one), small_result, naive_result));
1257 };
1258 // These upcasts are not strictly necessary on all platforms to get within our
1259 // error tolerances, so we could relax this if it ever mattered.
1260 return DoWithUpcastToF32(x, {BF16, F16}, [&](XlaOp x) {
1261 return b->ReportErrorOrReturn(do_it(x));
1262 });
1263 }
1264
1265 // atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1
1266 // atanh(x) = nan otherwise
Atanh(XlaOp x)1267 XlaOp Atanh(XlaOp x) {
1268 XlaBuilder* b = x.builder();
1269 auto do_it = [&](XlaOp x) -> StatusOr<XlaOp> {
1270 TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
1271 auto naive_result = (Log1p(x) - Log1p(-x)) * ScalarLike(x, 0.5);
1272
1273 // TODO(jlebar): For now, we ignore the nan edge case for complex inputs,
1274 // because we don't yet have exhaustive tests for complex trig functions.
1275 if (primitive_util::IsComplexType(shape.element_type())) {
1276 return naive_result;
1277 }
1278
1279 auto nan = FullLike(x, std::numeric_limits<float>::quiet_NaN());
1280 return Select(Gt(Abs(x), ScalarLike(x, 1)), nan, naive_result);
1281 };
1282 return DoWithUpcastToF32(x, {BF16}, [&](XlaOp x) { //
1283 return b->ReportErrorOrReturn(do_it(x));
1284 });
1285 }
1286
1287 // Cosh(x) = (e^x + e^-x) / 2
1288 // = e^(x + log(1/2)) + e^(-x + log(1/2)).
1289 //
1290 // The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not
1291 // inf.
1292 //
1293 // This incorrectly overflows to inf for two f32 input values, namely
1294 // +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
1295 // correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
1296 // we deem this acceptable.
Cosh(XlaOp x)1297 XlaOp Cosh(XlaOp x) {
1298 return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) {
1299 auto log_one_half = Log(ScalarLike(x, 0.5));
1300 return Exp(x + log_one_half) + Exp(-x + log_one_half);
1301 });
1302 }
1303
1304 // Sinh(x) = (e^x - e^-x) / 2
1305 // = e^(x + log(1/2)) - e^(-x + log(1/2)).
1306 //
1307 // The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not
1308 // inf.
1309 //
1310 // This incorrectly overflows to +/-inf for two f32 input values, namely
1311 // +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
1312 // correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
1313 // we deem this acceptable.
Sinh(XlaOp x)1314 XlaOp Sinh(XlaOp x) {
1315 XlaBuilder* b = x.builder();
1316 auto do_it = [&](XlaOp x) -> StatusOr<XlaOp> {
1317 TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
1318 auto one_half = ScalarLike(x, 0.5);
1319 auto log_one_half = Log(ScalarLike(x, 0.5));
1320 auto large_sinh_result = Exp(x + log_one_half) - Exp(-x + log_one_half);
1321
1322 if (primitive_util::IsComplexType(shape.element_type())) {
1323 return large_sinh_result;
1324 }
1325
1326 // Here we use e^x = e^(x / 2) * e^(x / 2). This avoids overflow for large
1327 // values of x.
1328
1329 // For smaller x, we get unwanted cancellations of e^x - e^-x, resulting in
1330 // 0.
1331 // Rewrite this to avoid that. We use expm1(x) because that preserves the
1332 // first order term of the taylor series of e^x.
1333 // (e^(x) - e^(-x)) / 2. =
1334 // (e^(x) - 1 + 1 - e^(-x)) / 2.
1335 // (expm1(x) + (e^(x) - 1) / e^x) / 2.
1336 // (expm1(x) + expm1(x) / (expm1(x) + 1)) / 2.
1337 auto expm1 = Expm1(x);
1338 auto one = ScalarLike(x, 1.);
1339 auto small_sinh_result = one_half * (expm1 + expm1 / (expm1 + one));
1340 return Select(Lt(Abs(x), one), small_sinh_result, large_sinh_result);
1341 };
1342 return DoWithUpcastToF32(x, {BF16, F16}, [&](XlaOp x) {
1343 return b->ReportErrorOrReturn(do_it(x));
1344 });
1345 }
1346
MaybeConjugate(XlaOp x,bool conjugate)1347 XlaOp MaybeConjugate(XlaOp x, bool conjugate) {
1348 XlaBuilder* builder = x.builder();
1349 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1350 TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
1351 auto perform_conj =
1352 primitive_util::IsComplexType(shape.element_type()) && conjugate;
1353 return perform_conj ? Conj(x) : x;
1354 });
1355 }
1356
NextAfter(XlaOp from,XlaOp to)1357 XlaOp NextAfter(XlaOp from, XlaOp to) {
1358 auto builder = from.builder();
1359 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1360 TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(from));
1361 int bitwidth = primitive_util::BitWidth(shape.element_type());
1362 auto int_type = primitive_util::UnsignedIntegralTypeForBitWidth(bitwidth);
1363 auto from_as_int = BitcastConvertType(from, int_type);
1364 auto to_as_int = BitcastConvertType(to, int_type);
1365
1366 // The result is NaN if either "from" or "to" are NaN.
1367 auto from_is_nan = Ne(from, from);
1368 auto to_is_nan = Ne(to, to);
1369 auto nan_input = Or(from_is_nan, to_is_nan);
1370 auto result_for_nan =
1371 Broadcast(ScalarLike(from, std::numeric_limits<double>::quiet_NaN()),
1372 shape.dimensions());
1373 result_for_nan = BitcastConvertType(result_for_nan, int_type);
1374
1375 // The sign bit is the MSB.
1376 const int64 sign_mask = int64{1} << (bitwidth - 1);
1377 // Discard the sign bit to make the result non-negative.
1378 auto from_abs = And(from_as_int, ScalarLike(from_as_int, ~sign_mask));
1379 auto to_abs = And(to_as_int, ScalarLike(to_as_int, ~sign_mask));
1380
1381 // When both "from" and "to" are equal, the result is "to".
1382 // N.B. It would not make a difference if we chose the result to be "from".
1383 auto from_and_to_are_equal = Eq(from_as_int, to_as_int);
1384 auto result_for_equal = to_as_int;
1385
1386 // When both "from" and "to" are both 0, the result is "to". This ensures we
1387 // get a zero signed like "to".
1388 auto from_is_zero = Eq(from_abs, ZerosLike(from_abs));
1389 auto to_is_zero = Eq(to_abs, ZerosLike(to_abs));
1390 auto result_for_both_zero = to_as_int;
1391
1392 auto from_sign = And(from_as_int, ScalarLike(from_as_int, sign_mask));
1393 auto to_sign = And(to_as_int, ScalarLike(to_as_int, sign_mask));
1394
1395 // If from == 0 && to != 0, we need to return the smallest subnormal number
1396 // signed like "to".
1397 auto result_for_from_zero_to_non_zero =
1398 Or(to_sign, ScalarLike(from_as_int, 1));
1399
1400 // If the sign of "from" and "to" disagree:
1401 // - we need to make the magnitude of "from" smaller so that it is closer to
1402 // zero.
1403 //
1404 // Otherwise the signs agree:
1405 // - "from" with a magnitude larger than "to" means we need to make the
1406 // magnitude smaller.
1407 // - "from" with a magnitude smaller than "to" means we need to make the
1408 // magnitude larger.
1409 // - "from" with the same magnitude and sign as "to" has already been
1410 // handled.
1411 auto signs_disagree = Ne(from_sign, to_sign);
1412 auto from_magnitude_larger_than_to = Gt(from_abs, to_abs);
1413 auto result_has_smaller_magnitude =
1414 Or(from_magnitude_larger_than_to, signs_disagree);
1415 auto magnitude_adjustment =
1416 Select(result_has_smaller_magnitude,
1417 Broadcast(ScalarLike(from_as_int, -1), shape.dimensions()),
1418 Broadcast(ScalarLike(from_as_int, 1), shape.dimensions()));
1419 auto result = Add(from_as_int, magnitude_adjustment);
1420 // Handle from == ±0.
1421 result = Select(from_is_zero,
1422 Select(to_is_zero, result_for_both_zero,
1423 result_for_from_zero_to_non_zero),
1424 result);
1425 // Handle from == to.
1426 result = Select(from_and_to_are_equal, result_for_equal, result);
1427 // Handle isnan(from) || isnan(to).
1428 result = Select(nan_input, result_for_nan, result);
1429
1430 // Cast back to the original type.
1431 return BitcastConvertType(result, shape.element_type());
1432 });
1433 }
1434
1435 // Computes an approximation to the modified Bessel function of the first kind,
1436 // zeroth order.
1437 // The following implementation follows Cephes' F32 and F64 implementation of
1438 // i0e.
I0eImpl32(XlaOp x)1439 static XlaOp I0eImpl32(XlaOp x) {
1440 static const std::array<float, 18> kI0eCoeffsA{
1441 -1.30002500998624804212E-8f, 6.04699502254191894932E-8f,
1442 -2.67079385394061173391E-7f, 1.11738753912010371815E-6f,
1443 -4.41673835845875056359E-6f, 1.64484480707288970893E-5f,
1444 -5.75419501008210370398E-5f, 1.88502885095841655729E-4f,
1445 -5.76375574538582365885E-4f, 1.63947561694133579842E-3f,
1446 -4.32430999505057594430E-3f, 1.05464603945949983183E-2f,
1447 -2.37374148058994688156E-2f, 4.93052842396707084878E-2f,
1448 -9.49010970480476444210E-2f, 1.71620901522208775349E-1f,
1449 -3.04682672343198398683E-1f, 6.76795274409476084995E-1f};
1450
1451 static const std::array<float, 7> kI0eCoeffsB{
1452 3.39623202570838634515E-9f, 2.26666899049817806459E-8f,
1453 2.04891858946906374183E-7f, 2.89137052083475648297E-6f,
1454 6.88975834691682398426E-5f, 3.36911647825569408990E-3f,
1455 8.04490411014108831608E-1f};
1456
1457 x = Abs(x);
1458 auto half = xla::ScalarLike(x, 0.5);
1459 auto two = xla::ScalarLike(x, 2.0);
1460 auto thirty_two = xla::ScalarLike(x, 32.0);
1461 auto result_le_8 =
1462 EvaluateChebyshevPolynomial<float>(half * x - two, kI0eCoeffsA);
1463 auto result_gt_8 =
1464 EvaluateChebyshevPolynomial<float>(thirty_two / x - two, kI0eCoeffsB) /
1465 Sqrt(x);
1466 return Select(Le(x, xla::ScalarLike(x, 8.0)), result_le_8, result_gt_8);
1467 }
1468
I0eImpl64(XlaOp x)1469 static XlaOp I0eImpl64(XlaOp x) {
1470 static const std::array<double, 30> kI0eCoeffsA{
1471 -4.41534164647933937950E-18, 3.33079451882223809783E-17,
1472 -2.43127984654795469359E-16, 1.71539128555513303061E-15,
1473 -1.16853328779934516808E-14, 7.67618549860493561688E-14,
1474 -4.85644678311192946090E-13, 2.95505266312963983461E-12,
1475 -1.72682629144155570723E-11, 9.67580903537323691224E-11,
1476 -5.18979560163526290666E-10, 2.65982372468238665035E-9,
1477 -1.30002500998624804212E-8, 6.04699502254191894932E-8,
1478 -2.67079385394061173391E-7, 1.11738753912010371815E-6,
1479 -4.41673835845875056359E-6, 1.64484480707288970893E-5,
1480 -5.75419501008210370398E-5, 1.88502885095841655729E-4,
1481 -5.76375574538582365885E-4, 1.63947561694133579842E-3,
1482 -4.32430999505057594430E-3, 1.05464603945949983183E-2,
1483 -2.37374148058994688156E-2, 4.93052842396707084878E-2,
1484 -9.49010970480476444210E-2, 1.71620901522208775349E-1,
1485 -3.04682672343198398683E-1, 6.76795274409476084995E-1};
1486
1487 static const std::array<double, 25> kI0eCoeffsB{
1488 -7.23318048787475395456E-18, -4.83050448594418207126E-18,
1489 4.46562142029675999901E-17, 3.46122286769746109310E-17,
1490 -2.82762398051658348494E-16, -3.42548561967721913462E-16,
1491 1.77256013305652638360E-15, 3.81168066935262242075E-15,
1492 -9.55484669882830764870E-15, -4.15056934728722208663E-14,
1493 1.54008621752140982691E-14, 3.85277838274214270114E-13,
1494 7.18012445138366623367E-13, -1.79417853150680611778E-12,
1495 -1.32158118404477131188E-11, -3.14991652796324136454E-11,
1496 1.18891471078464383424E-11, 4.94060238822496958910E-10,
1497 3.39623202570838634515E-9, 2.26666899049817806459E-8,
1498 2.04891858946906374183E-7, 2.89137052083475648297E-6,
1499 6.88975834691682398426E-5, 3.36911647825569408990E-3,
1500 8.04490411014108831608E-1};
1501
1502 x = Abs(x);
1503 auto half = xla::ScalarLike(x, 0.5);
1504 auto two = xla::ScalarLike(x, 2.0);
1505 auto thirty_two = xla::ScalarLike(x, 32.0);
1506 auto result_le_8 =
1507 EvaluateChebyshevPolynomial<double>(half * x - two, kI0eCoeffsA);
1508 auto result_gt_8 =
1509 EvaluateChebyshevPolynomial<double>(thirty_two / x - two, kI0eCoeffsB) /
1510 Sqrt(x);
1511 return Select(Le(x, xla::ScalarLike(x, 8.0)), result_le_8, result_gt_8);
1512 }
1513
BesselI0e(XlaOp x)1514 XlaOp BesselI0e(XlaOp x) {
1515 auto& b = *x.builder();
1516 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1517 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("BesselI0e", x));
1518 TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x));
1519 if (shape.element_type() == F64) {
1520 return I0eImpl64(x);
1521 }
1522 // I0eF32Impl don't have enough precision when run with bf16 intermediates
1523 // (not surprising!), so upcast to f32 in this case.
1524 return DoWithUpcastToF32(x, {BF16, F16},
1525 [](XlaOp x) { return I0eImpl32(x); });
1526 });
1527 }
1528
1529 // Computes an approximation to the modified Bessel function of the first kind,
1530 // first order.
1531 // The following implementation follows Cephes' F32 and F64 implementation of
1532 // i1e.
1533
I1eImpl32(XlaOp x)1534 static XlaOp I1eImpl32(XlaOp x) {
1535 static const std::array<float, 17> kI1eCoeffsA{
1536 9.38153738649577178388E-9f, -4.44505912879632808065E-8f,
1537 2.00329475355213526229E-7f, -8.56872026469545474066E-7f,
1538 3.47025130813767847674E-6f, -1.32731636560394358279E-5f,
1539 4.78156510755005422638E-5f, -1.61760815825896745588E-4f,
1540 5.12285956168575772895E-4f, -1.51357245063125314899E-3f,
1541 4.15642294431288815669E-3f, -1.05640848946261981558E-2f,
1542 2.47264490306265168283E-2f, -5.29459812080949914269E-2f,
1543 1.02643658689847095384E-1f, -1.76416518357834055153E-1f,
1544 2.52587186443633654823E-1f};
1545
1546 static const std::array<float, 7> kI1eCoeffsB{
1547 -3.83538038596423702205E-9f, -2.63146884688951950684E-8f,
1548 -2.51223623787020892529E-7f, -3.88256480887769039346E-6f,
1549 -1.10588938762623716291E-4f, -9.76109749136146840777E-3f,
1550 7.78576235018280120474E-1f};
1551 XlaOp z = Abs(x);
1552 auto half = xla::ScalarLike(x, 0.5);
1553 auto two = xla::ScalarLike(x, 2.0);
1554 auto thirty_two = xla::ScalarLike(x, 32.0);
1555 auto result_le_8 =
1556 z * EvaluateChebyshevPolynomial<float>(half * z - two, kI1eCoeffsA);
1557 auto result_gt_8 =
1558 EvaluateChebyshevPolynomial<float>(thirty_two / z - two, kI1eCoeffsB) /
1559 Sqrt(z);
1560 return Sign(x) *
1561 Select(Le(z, xla::ScalarLike(x, 8.0)), result_le_8, result_gt_8);
1562 }
1563
I1eImpl64(XlaOp x)1564 static XlaOp I1eImpl64(XlaOp x) {
1565 static const std::array<double, 29> kI1eCoeffsA{
1566 2.77791411276104639959E-18, -2.11142121435816608115E-17,
1567 1.55363195773620046921E-16, -1.10559694773538630805E-15,
1568 7.60068429473540693410E-15, -5.04218550472791168711E-14,
1569 3.22379336594557470981E-13, -1.98397439776494371520E-12,
1570 1.17361862988909016308E-11, -6.66348972350202774223E-11,
1571 3.62559028155211703701E-10, -1.88724975172282928790E-9,
1572 9.38153738649577178388E-9, -4.44505912879632808065E-8,
1573 2.00329475355213526229E-7, -8.56872026469545474066E-7,
1574 3.47025130813767847674E-6, -1.32731636560394358279E-5,
1575 4.78156510755005422638E-5, -1.61760815825896745588E-4,
1576 5.12285956168575772895E-4, -1.51357245063125314899E-3,
1577 4.15642294431288815669E-3, -1.05640848946261981558E-2,
1578 2.47264490306265168283E-2, -5.29459812080949914269E-2,
1579 1.02643658689847095384E-1, -1.76416518357834055153E-1,
1580 2.52587186443633654823E-1};
1581
1582 static const std::array<double, 25> kI1eCoeffsB{
1583 7.51729631084210481353E-18, 4.41434832307170791151E-18,
1584 -4.65030536848935832153E-17, -3.20952592199342395980E-17,
1585 2.96262899764595013876E-16, 3.30820231092092828324E-16,
1586 -1.88035477551078244854E-15, -3.81440307243700780478E-15,
1587 1.04202769841288027642E-14, 4.27244001671195135429E-14,
1588 -2.10154184277266431302E-14, -4.08355111109219731823E-13,
1589 -7.19855177624590851209E-13, 2.03562854414708950722E-12,
1590 1.41258074366137813316E-11, 3.25260358301548823856E-11,
1591 -1.89749581235054123450E-11, -5.58974346219658380687E-10,
1592 -3.83538038596423702205E-9, -2.63146884688951950684E-8,
1593 -2.51223623787020892529E-7, -3.88256480887769039346E-6,
1594 -1.10588938762623716291E-4, -9.76109749136146840777E-3,
1595 7.78576235018280120474E-1};
1596
1597 XlaOp z = Abs(x);
1598 auto half = xla::ScalarLike(x, 0.5);
1599 auto two = xla::ScalarLike(x, 2.0);
1600 auto thirty_two = xla::ScalarLike(x, 32.0);
1601 auto result_le_8 =
1602 z * EvaluateChebyshevPolynomial<double>(half * z - two, kI1eCoeffsA);
1603 auto result_gt_8 =
1604 EvaluateChebyshevPolynomial<double>(thirty_two / z - two, kI1eCoeffsB) /
1605 Sqrt(z);
1606 return Sign(x) *
1607 Select(Le(z, xla::ScalarLike(x, 8.0)), result_le_8, result_gt_8);
1608 }
1609
BesselI1e(XlaOp x)1610 XlaOp BesselI1e(XlaOp x) {
1611 auto& b = *x.builder();
1612 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1613 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("BesselI1e", x));
1614 TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x));
1615 if (shape.element_type() == F64) {
1616 return I1eImpl64(x);
1617 }
1618 // I1eF32Impl don't have enough precision when run with bf16 intermediates
1619 // (not surprising!), so upcast to f32 in this case.
1620 return DoWithUpcastToF32(x, {BF16, F16},
1621 [](XlaOp x) { return I1eImpl32(x); });
1622 });
1623 }
1624
1625 // I J Thompson and A R Barnett. 1986. Coulomb and Bessel functions of complex
1626 // arguments and order. J. Comput. Phys. 64, 2 (June 1986), 490-509.
1627 // DOI=http://dx.doi.org/10.1016/0021-9991(86)90046-X
LentzThompsonBarnettAlgorithm(int64 num_iterations,double small,double threshold,const ForEachIndexBodyFunction & nth_partial_numerator,const ForEachIndexBodyFunction & nth_partial_denominator,absl::Span<const XlaOp> inputs,absl::string_view name)1628 static XlaOp LentzThompsonBarnettAlgorithm(
1629 int64 num_iterations, double small, double threshold,
1630 const ForEachIndexBodyFunction& nth_partial_numerator,
1631 const ForEachIndexBodyFunction& nth_partial_denominator,
1632 absl::Span<const XlaOp> inputs, absl::string_view name) {
1633 auto& b = *inputs.front().builder();
1634 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1635 TF_RET_CHECK(num_iterations < INT32_MAX);
1636
1637 enum {
1638 // Position in the evaluation.
1639 kIterationIdx,
1640 // Whether or not we have reached the desired tolerance.
1641 kValuesUnconvergedIdx,
1642 // Ratio between nth canonical numerator and the nth-1 canonical
1643 // numerator.
1644 kCIdx,
1645 // Ratio between nth-1 canonical denominator and the nth canonical
1646 // denominator.
1647 kDIdx,
1648 // Computed approximant in the evaluation.
1649 kHIdx,
1650 // Inputs follow all of the other state.
1651 kFirstInputIdx,
1652 };
1653 auto while_cond_fn = [num_iterations](
1654 absl::Span<const XlaOp> values,
1655 XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
1656 auto iteration = values[kIterationIdx];
1657 auto iterations_remain_cond =
1658 Lt(iteration, ScalarLike(iteration, num_iterations));
1659 auto values_unconverged_cond = values[kValuesUnconvergedIdx];
1660 return And(iterations_remain_cond, values_unconverged_cond);
1661 };
1662
1663 auto while_body_fn =
1664 [small, threshold, &nth_partial_numerator, &nth_partial_denominator](
1665 absl::Span<const XlaOp> values,
1666 XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
1667 XlaOp iteration = values[kIterationIdx];
1668
1669 TF_ASSIGN_OR_RETURN(
1670 std::vector<XlaOp> partial_numerator,
1671 nth_partial_numerator(iteration, values.subspan(kFirstInputIdx),
1672 body_builder));
1673 TF_RET_CHECK(partial_numerator.size() == 1);
1674
1675 TF_ASSIGN_OR_RETURN(
1676 std::vector<XlaOp> partial_denominator,
1677 nth_partial_denominator(iteration, values.subspan(kFirstInputIdx),
1678 body_builder));
1679 TF_RET_CHECK(partial_denominator.size() == 1);
1680
1681 auto c = partial_denominator[0] + partial_numerator[0] / values[kCIdx];
1682 auto small_constant = FullLike(c, small);
1683 c = Select(Lt(Abs(c), small_constant), small_constant, c);
1684
1685 auto d = partial_denominator[0] + partial_numerator[0] * values[kDIdx];
1686 d = Select(Lt(Abs(d), small_constant), small_constant, d);
1687
1688 d = Reciprocal(d);
1689
1690 auto delta = c * d;
1691 auto h = values[kHIdx] * delta;
1692
1693 std::vector<XlaOp> updated_values(values.size());
1694 updated_values[kIterationIdx] = Add(iteration, ScalarLike(iteration, 1));
1695 updated_values[kCIdx] = c;
1696 updated_values[kDIdx] = d;
1697 updated_values[kHIdx] = h;
1698 std::copy(values.begin() + kFirstInputIdx, values.end(),
1699 updated_values.begin() + kFirstInputIdx);
1700
1701 // If any values are greater than the tolerance, we have not converged.
1702 auto tolerance_comparison =
1703 Ge(Abs(Sub(delta, FullLike(delta, 1.0))), FullLike(delta, threshold));
1704 updated_values[kValuesUnconvergedIdx] =
1705 ReduceAll(tolerance_comparison, ConstantR0<bool>(body_builder, false),
1706 CreateScalarOrComputation(PRED, body_builder));
1707 return updated_values;
1708 };
1709
1710 TF_ASSIGN_OR_RETURN(std::vector<XlaOp> partial_denominator,
1711 nth_partial_denominator(Zero(&b, U32), inputs, &b));
1712 TF_RET_CHECK(partial_denominator.size() == 1);
1713 auto h = partial_denominator[0];
1714 auto small_constant = FullLike(h, small);
1715 h = Select(Lt(Abs(h), small_constant), small_constant, h);
1716
1717 std::vector<XlaOp> values(kFirstInputIdx + inputs.size());
1718 values[kIterationIdx] = One(&b, U32);
1719 values[kValuesUnconvergedIdx] = ConstantR0<bool>(&b, true);
1720 values[kCIdx] = h;
1721 values[kDIdx] = FullLike(h, 0.0);
1722 values[kHIdx] = h;
1723 std::copy(inputs.begin(), inputs.end(), values.begin() + kFirstInputIdx);
1724 TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn,
1725 values, name, &b));
1726 return values[kHIdx];
1727 });
1728 }
1729
RegularizedIncompleteBeta(XlaOp a,XlaOp b,XlaOp x)1730 XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x) {
1731 auto& builder = *x.builder();
1732 return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1733 TF_ASSIGN_OR_RETURN(Shape shape, builder.GetShape(a));
1734 TF_ASSIGN_OR_RETURN(Shape b_shape, builder.GetShape(b));
1735 TF_ASSIGN_OR_RETURN(Shape x_shape, builder.GetShape(x));
1736 if (b_shape.element_type() != shape.element_type() ||
1737 x_shape.element_type() != shape.element_type()) {
1738 return InvalidArgument(
1739 "Operands to RegularizedIncompleteBeta must have identical types, "
1740 "got shapes %s, %s, and %s",
1741 shape.ToString(), b_shape.ToString(), x_shape.ToString());
1742 }
1743 if (!primitive_util::IsFloatingPointType(shape.element_type())) {
1744 return InvalidArgument(
1745 "Operands to RegularizedIncompleteBeta must be real-valued "
1746 "floating-point, but got %s",
1747 PrimitiveType_Name(shape.element_type()));
1748 }
1749 PrimitiveType element_type = shape.element_type();
1750 if (element_type == F16 || element_type == BF16) {
1751 element_type = F32;
1752 a = ConvertElementType(a, F32);
1753 b = ConvertElementType(b, F32);
1754 x = ConvertElementType(x, F32);
1755 }
1756
1757 // The partial numerator for the incomplete beta function is given
1758 // here: http://dlmf.nist.gov/8.17.E23 Note that there is a special
1759 // case: the partial numerator for the first iteration is one.
1760 auto NthPartialBetaincNumerator =
1761 [&](XlaOp iteration, absl::Span<const XlaOp> inputs,
1762 XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
1763 auto a = inputs[0];
1764 auto b = inputs[1];
1765 auto x = inputs[2];
1766 auto iteration_bcast = Broadcast(iteration, shape.dimensions());
1767 auto iteration_is_even =
1768 Eq(iteration_bcast % FullLike(iteration_bcast, 2),
1769 FullLike(iteration_bcast, 0));
1770 auto iteration_is_one = Eq(iteration_bcast, FullLike(iteration_bcast, 1));
1771 auto iteration_minus_one = iteration_bcast - FullLike(iteration_bcast, 1);
1772 auto m = iteration_minus_one / FullLike(iteration_minus_one, 2);
1773 m = ConvertElementType(m, element_type);
1774 auto one = FullLike(a, 1.0);
1775 auto two = FullLike(a, 2.0);
1776 // Partial numerator terms.
1777 auto even_numerator =
1778 -(a + m) * (a + b + m) * x / ((a + two * m) * (a + two * m + one));
1779 auto odd_numerator =
1780 m * (b - m) * x / ((a + two * m - one) * (a + two * m));
1781 auto one_numerator = ScalarLike(x, 1.0);
1782 auto numerator = Select(iteration_is_even, even_numerator, odd_numerator);
1783 return std::vector<XlaOp>{
1784 Select(iteration_is_one, one_numerator, numerator)};
1785 };
1786
1787 auto NthPartialBetaincDenominator =
1788 [&shape](XlaOp iteration, absl::Span<const XlaOp> inputs,
1789 XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
1790 auto x = inputs[2];
1791 auto iteration_bcast = Broadcast(iteration, shape.dimensions());
1792 return std::vector<XlaOp>{
1793 Select(Eq(iteration_bcast, ScalarLike(iteration_bcast, 0)),
1794 ScalarLike(x, 0.0), ScalarLike(x, 1.0))};
1795 };
1796
1797 // Determine if the inputs are out of range.
1798 auto result_is_nan =
1799 Or(Or(Or(Le(a, ScalarLike(a, 0.0)), Le(b, ScalarLike(b, 0.0))),
1800 Lt(x, ScalarLike(x, 0.0))),
1801 Gt(x, ScalarLike(x, 1.0)));
1802
1803 // The continued fraction will converge rapidly when x < (a+1)/(a+b+2)
1804 // as per: http://dlmf.nist.gov/8.17.E23
1805 //
1806 // Otherwise, we can rewrite using the symmetry relation as per:
1807 // http://dlmf.nist.gov/8.17.E4
1808 auto converges_rapidly =
1809 Lt(x, (a + FullLike(a, 1.0)) / (a + b + FullLike(b, 2.0)));
1810 auto a_orig = a;
1811 a = Select(converges_rapidly, a, b);
1812 b = Select(converges_rapidly, b, a_orig);
1813 x = Select(converges_rapidly, x, Sub(FullLike(x, 1.0), x));
1814
1815 XlaOp continued_fraction;
1816
1817 // Thresholds and iteration counts taken from Cephes.
1818 if (element_type == F32) {
1819 continued_fraction = LentzThompsonBarnettAlgorithm(
1820 /*num_iterations=*/200,
1821 /*small=*/std::numeric_limits<float>::epsilon() / 2.0f,
1822 /*threshold=*/std::numeric_limits<float>::epsilon() / 2.0f,
1823 /*nth_partial_numerator=*/NthPartialBetaincNumerator,
1824 /*nth_partial_denominator=*/NthPartialBetaincDenominator, {a, b, x},
1825 "Betainc");
1826 } else {
1827 TF_RET_CHECK(element_type == F64);
1828 continued_fraction = LentzThompsonBarnettAlgorithm(
1829 /*num_iterations=*/600,
1830 /*small=*/std::numeric_limits<double>::epsilon() / 2.0f,
1831 /*threshold=*/std::numeric_limits<double>::epsilon() / 2.0f,
1832 /*nth_partial_numerator=*/NthPartialBetaincNumerator,
1833 /*nth_partial_denominator=*/NthPartialBetaincDenominator, {a, b, x},
1834 "Betainc");
1835 }
1836
1837 // We want to compute the regularized complete beta function so we need to
1838 // combine the continued fraction with a few more terms as well as dividing
1839 // it by Beta(a, b). To avoid overflow, we compute in the log domain.
1840 // See http://dlmf.nist.gov/8.17.E22 for an easier to read version of this
1841 // formula.
1842 auto lbeta = Lbeta(a, b);
1843 auto result =
1844 continued_fraction * Exp(Log(x) * a + Log1p(-x) * b - lbeta) / a;
1845 result = Select(result_is_nan, NanValue(&builder, element_type), result);
1846
1847 // We have an additional fixup to do if we are taking advantage of the
1848 // symmetry relation.
1849 auto out =
1850 Select(converges_rapidly, result, Sub(FullLike(result, 1.0), result));
1851 return shape.element_type() == element_type
1852 ? out
1853 : ConvertElementType(out, shape.element_type());
1854 });
1855 }
1856
Polygamma(XlaOp n,XlaOp x)1857 XlaOp Polygamma(XlaOp n, XlaOp x) {
1858 auto& builder = *x.builder();
1859 auto doit = [](XlaOp n, XlaOp x, PrimitiveType type) -> XlaOp {
1860 XlaOp n_plus_one = n + ScalarLike(n, 1.);
1861 XlaOp sign =
1862 (ScalarLike(n, 2.) * Rem(n, ScalarLike(n, 2.)) - ScalarLike(n, 1.));
1863
1864 const double nan = std::numeric_limits<double>::quiet_NaN();
1865
1866 XlaOp output = Select(Eq(n, ScalarLike(n, 0.)), Digamma(x),
1867 sign * Exp(Lgamma(n_plus_one)) * Zeta(n_plus_one, x));
1868 // Check that n is a natural number.
1869 output = Select(Or(Ne(n, Floor(n)), Lt(n, ScalarLike(n, 0.))),
1870 ScalarLike(n, nan), output);
1871 return output;
1872 };
1873 return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1874 TF_ASSIGN_OR_RETURN(auto n_shape, builder.GetShape(n));
1875 TF_ASSIGN_OR_RETURN(auto x_shape, builder.GetShape(x));
1876 if (n_shape != x_shape) {
1877 return InvalidArgument(
1878 "Arguments to Polygamma must have equal shapes and types; "
1879 "got %s and %s",
1880 n_shape.ToString(), x_shape.ToString());
1881 }
1882 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Zeta", x));
1883 bool needs_upcast =
1884 n_shape.element_type() == F16 || x_shape.element_type() == BF16;
1885
1886 if (needs_upcast) {
1887 n = ConvertElementType(n, F32);
1888 x = ConvertElementType(x, F32);
1889 }
1890 XlaOp result = doit(n, x, n_shape.element_type());
1891 if (needs_upcast) {
1892 result = ConvertElementType(result, n_shape.element_type());
1893 }
1894 return result;
1895 });
1896 }
1897
Zeta(XlaOp x,XlaOp q)1898 XlaOp Zeta(XlaOp x, XlaOp q) {
1899 auto& builder = *x.builder();
1900 auto doit = [&builder](XlaOp x, XlaOp q, PrimitiveType type) -> XlaOp {
1901 // (2k) ! / B_{2k}, where B_{2k} are the Bernoulli numbers.
1902 // These are ordered in reverse.
1903 static const std::array<double, 12> kZetaCoeffs{
1904 -7.1661652561756670113e18,
1905 1.8152105401943546773e17,
1906 -4.5979787224074726105e15,
1907 1.1646782814350067249e14,
1908 -2.950130727918164224e12,
1909 7.47242496e10,
1910 -1.8924375803183791606e9,
1911 47900160.0,
1912 -1209600.0,
1913 30240.0,
1914 -720.0,
1915 12.0,
1916 };
1917
1918 // For speed we'll always use 9 iterations for the initial series estimate,
1919 // and a 12 term expansion for the Euler-Maclaurin formula.
1920
1921 XlaOp a = q;
1922 XlaOp neg_power = ScalarLike(a, 0.);
1923 XlaOp initial_sum = Pow(q, Neg(x));
1924 for (int i = 0; i < 9; ++i) {
1925 a = a + ScalarLike(a, 1.);
1926 neg_power = Pow(a, Neg(x));
1927 initial_sum = initial_sum + neg_power;
1928 }
1929 a = a + ScalarLike(a, 1.);
1930 neg_power = Pow(a, Neg(x));
1931 XlaOp s = initial_sum + neg_power * a / (x - ScalarLike(a, 1.));
1932 XlaOp a_inverse_square = Reciprocal(Square(a));
1933 XlaOp horner_sum = ScalarLike(a, 0.);
1934 XlaOp factor = ScalarLike(a, 1.);
1935 // Use Horner's rule for this.
1936 // Note this differs from Cephes which does a 'naive' polynomial evaluation.
1937 // Using Horner's rule allows to avoid some NaN's and Infs from happening,
1938 // resulting in more numerically stable code.
1939 for (int i = 0; i < 11; ++i) {
1940 factor =
1941 (x - ScalarLike(x, 22 - 2 * i)) * (x - ScalarLike(x, 21 - 2 * i));
1942 horner_sum = factor * a_inverse_square *
1943 (horner_sum + ScalarLike(a, 1. / kZetaCoeffs[i]));
1944 }
1945 s = s + neg_power *
1946 (ScalarLike(neg_power, 0.5) +
1947 x / a * (ScalarLike(a, 1. / kZetaCoeffs[11]) + horner_sum));
1948
1949 const double nan = std::numeric_limits<double>::quiet_NaN();
1950 const double inf = std::numeric_limits<double>::infinity();
1951 // Use the initial zeta sum without the correction term coming
1952 // from Euler-Maclaurin if it is accurate enough.
1953 XlaOp output =
1954 Select(Lt(Abs(neg_power), Abs(initial_sum) * Epsilon(&builder, type)),
1955 initial_sum, s);
1956 // This is the harmonic series.
1957 output = Select(Eq(x, ScalarLike(x, 1.)), ScalarLike(x, inf), output);
1958 // Function is not defined for x < 1.
1959 output = Select(Lt(x, ScalarLike(x, 1.)), ScalarLike(x, nan), output);
1960 // If q <= 0, then when q is an integer or x is not an integer, this is
1961 // NaN.
1962 XlaOp domain_error = And(Le(q, ScalarLike(x, 0.)), Ne(x, Floor(x)));
1963 XlaOp negative_integer_q = And(Le(q, ScalarLike(x, 0.)), Eq(q, Floor(q)));
1964 output = Select(negative_integer_q, ScalarLike(x, inf), output);
1965 output = Select(domain_error, ScalarLike(x, nan), output);
1966 return output;
1967 };
1968 return builder.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1969 TF_ASSIGN_OR_RETURN(auto x_shape, builder.GetShape(x));
1970 TF_ASSIGN_OR_RETURN(auto q_shape, builder.GetShape(q));
1971 if (x_shape != q_shape) {
1972 return InvalidArgument(
1973 "Arguments to Zeta must have equal shapes and types; got %s and %s",
1974 x_shape.ToString(), q_shape.ToString());
1975 }
1976 TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Zeta", x));
1977 bool needs_upcast =
1978 x_shape.element_type() == F16 || x_shape.element_type() == BF16;
1979
1980 if (needs_upcast) {
1981 x = ConvertElementType(x, F32);
1982 q = ConvertElementType(q, F32);
1983 }
1984 XlaOp result = doit(x, q, x_shape.element_type());
1985 if (needs_upcast) {
1986 result = ConvertElementType(result, x_shape.element_type());
1987 }
1988 return result;
1989 });
1990 }
1991
1992 } // namespace xla
1993