• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Enable the use of M_* math constants.
17 // NOTE: this must be first in the file to ensure that if cmath is transitively
18 // included by any other header it has the define set on first processing.
19 // https://docs.microsoft.com/en-us/cpp/c-runtime-library/math-constants
20 #define _USE_MATH_DEFINES
21 #include <cmath>
22 #include <numeric>
23 #include <vector>
24 
25 #include "llvm/ADT/SmallVector.h"
26 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
27 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
28 #include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
29 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
30 #include "mlir-hlo/utils/broadcast_utils.h"
31 #include "mlir-hlo/utils/hlo_utils.h"
32 #include "mlir/Dialect/SCF/SCF.h"
33 #include "mlir/Dialect/Shape/IR/Shape.h"
34 #include "mlir/Dialect/StandardOps/IR/Ops.h"
35 #include "mlir/Dialect/Tensor/IR/Tensor.h"
36 #include "mlir/IR/Attributes.h"
37 #include "mlir/IR/BuiltinTypes.h"
38 #include "mlir/IR/ImplicitLocOpBuilder.h"
39 #include "mlir/IR/MLIRContext.h"
40 #include "mlir/IR/OperationSupport.h"
41 #include "mlir/IR/PatternMatch.h"
42 #include "mlir/Transforms/DialectConversion.h"
43 
44 namespace mlir {
45 namespace chlo {
46 namespace {
47 
48 struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
49   using OpConversionPattern<ConstantLikeOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anon8f4f31520111::ConvertConstantLikeOp50   LogicalResult matchAndRewrite(
51       ConstantLikeOp op, ArrayRef<Value> operands,
52       ConversionPatternRewriter &rewriter) const override {
53     auto result_ty = op.getType().cast<ShapedType>();
54 
55     // Unranked uses are not supported.
56     if (!result_ty.hasRank()) return failure();
57 
58     // Lower to MHLO constant if statically shaped.
59     if (result_ty.hasStaticShape()) {
60       rewriter.replaceOpWithNewOp<mhlo::ConstOp>(
61           op, DenseElementsAttr::get(result_ty, op.value()));
62       return success();
63     }
64 
65     // Lower to broadcasted constant.
66     ConstantLikeOp::Adaptor transformed(operands);
67     auto loc = op.getLoc();
68     Type extent_tensor_type = shape::getExtentTensorType(op.getContext());
69     Value constant = rewriter.create<mhlo::ConstOp>(loc, op.value());
70     Value uncasted_shape = rewriter.create<shape::ShapeOfOp>(
71         loc, extent_tensor_type, transformed.operand());
72     Type shape_ty =
73         RankedTensorType::get({result_ty.getRank()}, rewriter.getIndexType());
74     Value shape =
75         rewriter.create<tensor::CastOp>(loc, shape_ty, uncasted_shape);
76     rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>(
77         op, result_ty, constant, shape, rewriter.getI64TensorAttr({}));
78     return success();
79   }
80 };
81 
82 template <typename FTy>
MaterializePolynomialApproximation(ConversionPatternRewriter & rewriter,Location loc,Value x,const std::vector<FTy> & coefficients)83 Value MaterializePolynomialApproximation(ConversionPatternRewriter &rewriter,
84                                          Location loc, Value x,
85                                          const std::vector<FTy> &coefficients) {
86   Value poly = chlo::getConstantLike(rewriter, loc, 0.0, x);
87   for (FTy c : coefficients) {
88     poly = rewriter.create<mhlo::MulOp>(loc, x.getType(), poly, x);
89     poly = rewriter.create<mhlo::AddOp>(
90         loc, x.getType(), poly, chlo::getConstantLike(rewriter, loc, c, x));
91   }
92   return poly;
93 }
94 
// Precondition is |x| >= 1. Use erf approximation, otherwise.
//
// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an
// argument and derive the final approximation for all |x| >= 1.
// This implementation is based on Cephes.
Value MaterializeErfcApproximationF64ForMagnituteGEOne(
    ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
         "expect f64 element type");
  // ~log(DBL_MAX); when z = -x^2 drops below -kMaxlog, exp(z) underflows, so
  // the result is clamped to 0 further down.
  const double kMaxlog = 7.09782712893383996843E2;
  // Rational approximation coefficients (highest degree first, consumed by
  // MaterializePolynomialApproximation): P/Q covers |x| in [1, 8),
  // R/S covers |x| >= 8.
  const std::vector<double> kErfcPCoefficients{
      2.46196981473530512524E-10, 5.64189564831068821977E-1,
      7.46321056442269912687E0,   4.86371970985681366614E1,
      1.96520832956077098242E2,   5.26445194995477358631E2,
      9.34528527171957607540E2,   1.02755188689515710272E3,
      5.57535335369399327526E2};
  const std::vector<double> kErfcQCoefficients{
      1.00000000000000000000E0, 1.32281951154744992508E1,
      8.67072140885989742329E1, 3.54937778887819891062E2,
      9.75708501743205489753E2, 1.82390916687909736289E3,
      2.24633760818710981792E3, 1.65666309194161350182E3,
      5.57535340817727675546E2};
  const std::vector<double> kErfcRCoefficients{
      5.64189583547755073984E-1, 1.27536670759978104416E0,
      5.01905042251180477414E0,  6.16021097993053585195E0,
      7.40974269950448939160E0,  2.97886665372100240670E0};
  const std::vector<double> kErfcSCoefficients{
      1.00000000000000000000E0, 2.26052863220117276590E0,
      9.39603524938001434673E0, 1.20489539808096656605E1,
      1.70814450747565897222E1, 9.60896809063285878198E0,
      3.36907645100081516050E0};

  // Let z = -x^2.
  Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
  Value z = rewriter.create<mhlo::NegOp>(loc, x_sq);

  // Materialize polynomial approximation for x in [1, 8) as
  //   erfc(x) = exp(z) P(|x|) / Q(|x|).
  Value exp_z = rewriter.create<mhlo::ExpOp>(loc, z);
  Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
  Value poly_p = MaterializePolynomialApproximation(rewriter, loc, abs_x,
                                                    kErfcPCoefficients);
  Value exp_z_mul_poly_p = rewriter.create<mhlo::MulOp>(loc, exp_z, poly_p);
  Value poly_q = MaterializePolynomialApproximation(rewriter, loc, abs_x,
                                                    kErfcQCoefficients);
  Value erfc_approx_1_8 =
      rewriter.create<mhlo::DivOp>(loc, exp_z_mul_poly_p, poly_q);

  // Materialize polynomial approximation for |x| >= 8 as
  //   erfc(x) = exp(z) R(|x|) / S(|x|).
  Value poly_r = MaterializePolynomialApproximation(rewriter, loc, abs_x,
                                                    kErfcRCoefficients);
  Value exp_z_mul_poly_r = rewriter.create<mhlo::MulOp>(loc, exp_z, poly_r);
  Value poly_s = MaterializePolynomialApproximation(rewriter, loc, abs_x,
                                                    kErfcSCoefficients);
  Value erfc_approx_8_inf =
      rewriter.create<mhlo::DivOp>(loc, exp_z_mul_poly_r, poly_s);

  // Combine polynomial approximations for x >= 1.
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value eight = chlo::getConstantLike(rewriter, loc, 8.0, x);
  Value abs_x_lt_8 = rewriter.create<mhlo::CompareOp>(loc, abs_x, eight, kLT);
  Value erfc_approx = rewriter.create<mhlo::SelectOp>(
      loc, abs_x_lt_8, erfc_approx_1_8, erfc_approx_8_inf);

  // Clamp to prevent overflow and materialize approximation for large x as
  //   erfc(x) = 0.
  Value z_lt_neg_maxlog = rewriter.create<mhlo::CompareOp>(
      loc, z, chlo::getConstantLike(rewriter, loc, -kMaxlog, x), kLT);
  Value zero = chlo::getConstantLike(rewriter, loc, 0.0, x);
  Value erfc_approx_clamped =
      rewriter.create<mhlo::SelectOp>(loc, z_lt_neg_maxlog, zero, erfc_approx);

  // Derive approximation for x <= -1 as
  //   erfc(x) = 2 - erfc(-x).
  // Reuse previously materialized approximations all of which take |x| as their
  // argument.
  Value x_lt_zero = rewriter.create<mhlo::CompareOp>(loc, x, zero, kLT);
  Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
  Value two_sub_erfc_approx_clamped =
      rewriter.create<mhlo::SubOp>(loc, two, erfc_approx_clamped);
  return rewriter.create<mhlo::SelectOp>(
      loc, x_lt_zero, two_sub_erfc_approx_clamped, erfc_approx_clamped);
}
181 
182 // Precondition is |x| <= 1. Use erfc approximation, otherwise.
183 // This implementation is based on Cephes.
MaterializeErfApproximationF64ForMagnituteLEOne(ConversionPatternRewriter & rewriter,Location loc,ValueRange args)184 Value MaterializeErfApproximationF64ForMagnituteLEOne(
185     ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
186   Value x = args.front();
187   assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
188          "expect f64 element type");
189   const std::vector<double> kErfTCoefficients{
190       9.60497373987051638749E0, 9.00260197203842689217E1,
191       2.23200534594684319226E3, 7.00332514112805075473E3,
192       5.55923013010394962768E4};
193   const std::vector<double> kErfUCoefficients{
194       1.00000000000000000000E0, 3.35617141647503099647E1,
195       5.21357949780152679795E2, 4.59432382970980127987E3,
196       2.26290000613890934246E4, 4.92673942608635921086E4};
197 
198   // Materialize polynomial approximation for |x| <= 1 as
199   //   erf(x) = x T(x^2) / U(x^2).
200   Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
201   Value poly_t = MaterializePolynomialApproximation(rewriter, loc, x_sq,
202                                                     kErfTCoefficients);
203   Value x_mul_poly_t = rewriter.create<mhlo::MulOp>(loc, x, poly_t);
204   Value poly_u = MaterializePolynomialApproximation(rewriter, loc, x_sq,
205                                                     kErfUCoefficients);
206   return rewriter.create<mhlo::DivOp>(loc, x_mul_poly_t, poly_u);
207 }
208 
209 // This implementation is based on Cephes.
MaterializeErfApproximationF64(ConversionPatternRewriter & rewriter,Location loc,ValueRange args)210 Value MaterializeErfApproximationF64(ConversionPatternRewriter &rewriter,
211                                      Location loc, ValueRange args) {
212   Value x = args.front();
213   assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
214          "expect f64 element type");
215 
216   // Rely on erf approximation for |x| < 1
217   //   erf(x) = erf_approx(x)
218   Value erf_approx =
219       MaterializeErfApproximationF64ForMagnituteLEOne(rewriter, loc, x);
220 
221   // Rely on erfc approximation for |x| >= 1 and materialize erf as
222   //   erf(x) = 1 - erfc_approx(x)
223   Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
224   Value erfc_approx =
225       MaterializeErfcApproximationF64ForMagnituteGEOne(rewriter, loc, x);
226   Value erfc_based_approx = rewriter.create<mhlo::SubOp>(loc, one, erfc_approx);
227 
228   // Materialize approximation selection based on argument.
229   Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
230   const StringAttr kLT = rewriter.getStringAttr(
231       mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
232   Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT);
233   return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, erf_approx,
234                                          erfc_based_approx);
235 }
236 
MaterializeErfcApproximationF64(ConversionPatternRewriter & rewriter,Location loc,ValueRange args)237 Value MaterializeErfcApproximationF64(ConversionPatternRewriter &rewriter,
238                                       Location loc, ValueRange args) {
239   Value x = args.front();
240   assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
241          "expect f64 element type");
242 
243   // Rely on erfc approximation for |x| >= 1
244   //   erfc(x) = erfc_approx(x)
245   Value erfc_approx =
246       MaterializeErfcApproximationF64ForMagnituteGEOne(rewriter, loc, x);
247 
248   // Rely on erf approximation for |x| < 1 and materialize erfc as
249   //   erfc(x) = 1 - erf_approx(x)
250   Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
251   Value erf_approx =
252       MaterializeErfApproximationF64ForMagnituteLEOne(rewriter, loc, x);
253   Value erf_based_approx = rewriter.create<mhlo::SubOp>(loc, one, erf_approx);
254 
255   // Materialize approximation selection based on argument.
256   Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
257   const StringAttr kLT = rewriter.getStringAttr(
258       mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
259   Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT);
260   return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, erf_based_approx,
261                                          erfc_approx);
262 }
263 
264 // Precondition is |x| >= 1. Use erf approximation, otherwise.
265 //
266 // We rely on multiple polynomial approximations for x >= 1. We pass |x| as an
267 // argument and derive the final approximation for all |x| >= 1.
268 // This implementation is based on Cephes.
MaterializeErfcApproximationF32ForMagnitudeGEOne(ConversionPatternRewriter & rewriter,Location loc,ValueRange args)269 Value MaterializeErfcApproximationF32ForMagnitudeGEOne(
270     ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
271   Value x = args.front();
272   assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
273          "expect f32 element type");
274   const double kMaxlog = 88.72283905206835;
275   const std::vector<float> kErfcPCoefficients{
276       +2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1,
277       -5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1,
278       +3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1,
279   };
280   const std::vector<float> kErfcRCoefficients{
281       -1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0,
282       +2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1,
283       -2.820767439740514E-1, +5.641895067754075E-1,
284   };
285 
286   // Let z = -x^2.
287   Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
288   Value z = rewriter.create<mhlo::NegOp>(loc, x_sq);
289 
290   // Materialize polynomial approximation for x >= 1 as
291   //   erfc(x) = exp(z) 1/x P(1/x^2)   if x in [1, 2)
292   //   erfc(x) = exp(z) 1/x R(1/x^2)   if x >= 2
293   const StringAttr kLT = rewriter.getStringAttr(
294       mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
295   Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
296   Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
297   Value reciprocal_x_sq = rewriter.create<mhlo::DivOp>(loc, one, x_sq);
298   Value exp_z = rewriter.create<mhlo::ExpOp>(loc, z);
299   Value one_div_abs_x = rewriter.create<mhlo::DivOp>(loc, one, abs_x);
300   Value exp_z_mul_one_div_abs_x =
301       rewriter.create<mhlo::MulOp>(loc, exp_z, one_div_abs_x);
302   Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
303   Value abs_x_lt_two = rewriter.create<mhlo::CompareOp>(loc, abs_x, two, kLT);
304   Value poly_p = MaterializePolynomialApproximation(
305       rewriter, loc, reciprocal_x_sq, kErfcPCoefficients);
306   Value poly_r = MaterializePolynomialApproximation(
307       rewriter, loc, reciprocal_x_sq, kErfcRCoefficients);
308   Value poly =
309       rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_two, poly_p, poly_r);
310   Value erfc_approx =
311       rewriter.create<mhlo::MulOp>(loc, exp_z_mul_one_div_abs_x, poly);
312 
313   // Clamp to prevent overflow and materialize approximation for large x as
314   //   erfc(x) = 0.
315   Value z_lt_neq_maxlog = rewriter.create<mhlo::CompareOp>(
316       loc, z, chlo::getConstantLike(rewriter, loc, -kMaxlog, x), kLT);
317   Value zero = chlo::getConstantLike(rewriter, loc, 0.0, x);
318   Value erfc_approx_clamped =
319       rewriter.create<mhlo::SelectOp>(loc, z_lt_neq_maxlog, zero, erfc_approx);
320 
321   // Derive approximation for x <= -1 as
322   //   erfc(x) = 2 - erfc(-x).
323   // Reuse previously materialized approximations all of which take |x| as their
324   // argument.
325   Value x_lt_zero = rewriter.create<mhlo::CompareOp>(loc, x, zero, kLT);
326   Value two_sub_erfc_approx =
327       rewriter.create<mhlo::SubOp>(loc, two, erfc_approx_clamped);
328   return rewriter.create<mhlo::SelectOp>(loc, x_lt_zero, two_sub_erfc_approx,
329                                          erfc_approx_clamped);
330 }
331 
332 // Precondition is |x| <= 1. Use erfc approximation, otherwise.
333 // This implementation is based on Cephes.
MaterializeErfApproximationF32ForMagnitudeLEOne(ConversionPatternRewriter & rewriter,Location loc,ValueRange args)334 Value MaterializeErfApproximationF32ForMagnitudeLEOne(
335     ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
336   Value x = args.front();
337   assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
338          "expect f32 element type");
339   const std::vector<float> kErfTCoefficients{
340       +7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3,
341       -2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1,
342       +1.128379165726710E+0,
343   };
344 
345   // Materialize polynomial approximation for |x| <= 1 as
346   //   erf(x) = x T(x^2).
347   Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
348   Value poly_t = MaterializePolynomialApproximation(rewriter, loc, x_sq,
349                                                     kErfTCoefficients);
350   return rewriter.create<mhlo::MulOp>(loc, x, poly_t);
351 }
352 
353 // This is the same approximation as used in Eigen.
MaterializeErfApproximationF32(ConversionPatternRewriter & rewriter,Location loc,ValueRange args)354 Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
355                                      Location loc, ValueRange args) {
356   Value x = args.front();
357   assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
358          "expect f32 element type");
359   const std::vector<float> kAlpha{
360       -2.72614225801306e-10f, 2.77068142495902e-08f,  -2.10102402082508e-06f,
361       -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
362       -1.60960333262415e-02f,
363   };
364   const std::vector<float> kBeta{
365       -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
366       -7.37332916720468e-03f, -1.42647390514189e-02f,
367   };
368 
369   // Clamp argument between -4 and 4.
370   Value lb = chlo::getConstantLike(rewriter, loc, -4.0, x);
371   Value ub = chlo::getConstantLike(rewriter, loc, 4.0, x);
372   x = rewriter.create<mhlo::ClampOp>(loc, x.getType(), lb, x, ub);
373   Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
374 
375   // Materialize polynomial approximation for x in [-4, 4] as
376   //   erf(x) = x * Alpha(x^2) / Beta(x^2).
377   Value alpha_poly =
378       MaterializePolynomialApproximation(rewriter, loc, x_sq, kAlpha);
379   Value beta_poly =
380       MaterializePolynomialApproximation(rewriter, loc, x_sq, kBeta);
381   Value x_mul_alpha_poly = rewriter.create<mhlo::MulOp>(loc, x, alpha_poly);
382   return rewriter.create<mhlo::DivOp>(loc, x_mul_alpha_poly, beta_poly);
383 }
384 
MaterializeErfcApproximationF32(ConversionPatternRewriter & rewriter,Location loc,ValueRange args)385 Value MaterializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
386                                       Location loc, ValueRange args) {
387   Value x = args.front();
388   assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
389          "expect f32 element type");
390 
391   // Rely on erfc approximation for |x| >= 1
392   //   erfc(x) = erfc_approx(x)
393   Value erfc_approx =
394       MaterializeErfcApproximationF32ForMagnitudeGEOne(rewriter, loc, x);
395 
396   // Rely on erf approximation for |x| < 1 and materialize erfc as
397   //   erfc(x) = 1 - erf_approx(x)
398   Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
399   Value erf_approx =
400       MaterializeErfApproximationF32ForMagnitudeLEOne(rewriter, loc, x);
401   Value erf_based_approx = rewriter.create<mhlo::SubOp>(loc, one, erf_approx);
402 
403   // Materialize approximation selection based on argument.
404   const StringAttr kLT = rewriter.getStringAttr(
405       mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
406   Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
407   Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT);
408   return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, erf_based_approx,
409                                          erfc_approx);
410 }
411 
MaterializeWithUpcast(ConversionPatternRewriter & rewriter,Location loc,ValueRange args,FloatType min_precision_ty,Value callback (ConversionPatternRewriter &,Location,ValueRange))412 Value MaterializeWithUpcast(ConversionPatternRewriter &rewriter, Location loc,
413                             ValueRange args, FloatType min_precision_ty,
414                             Value callback(ConversionPatternRewriter &,
415                                            Location, ValueRange)) {
416   auto original_ty =
417       getElementTypeOrSelf(args.front().getType()).cast<FloatType>();
418   bool needs_upcast = original_ty.getWidth() < min_precision_ty.getWidth();
419 
420   // Upcast arguments if necessary.
421   llvm::SmallVector<Value, 2> casted_args;
422   if (needs_upcast) {
423     for (Value a : args) {
424       casted_args.push_back(
425           rewriter.create<mhlo::ConvertOp>(loc, a, min_precision_ty));
426     }
427     args = casted_args;
428   }
429 
430   Value result = callback(rewriter, loc, args);
431 
432   // Cast back if necessary.
433   if (needs_upcast) {
434     result = rewriter.create<mhlo::ConvertOp>(loc, result, original_ty);
435   }
436 
437   return result;
438 }
439 
440 struct ConvertErfOp : public OpConversionPattern<ErfOp> {
441   using OpConversionPattern<ErfOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anon8f4f31520111::ConvertErfOp442   LogicalResult matchAndRewrite(
443       ErfOp op, ArrayRef<Value> operands,
444       ConversionPatternRewriter &rewriter) const override {
445     Location loc = op.getLoc();
446     ErfOp::Adaptor transformed(operands);
447     Value x = transformed.operand();
448     Type ty = x.getType().cast<ShapedType>().getElementType();
449 
450     // For now, we support only f64, f32, and f16.
451     if (!ty.isF64() && !ty.isF32() && !ty.isF16()) return failure();
452 
453     if (ty.isF64()) {
454       rewriter.replaceOp(op, MaterializeErfApproximationF64(rewriter, loc, x));
455       return success();
456     }
457 
458     rewriter.replaceOp(op, MaterializeWithUpcast(
459                                rewriter, loc, operands, rewriter.getF32Type(),
460                                &MaterializeErfApproximationF32));
461     return success();
462   }
463 };
464 
465 struct ConvertErfcOp : public OpConversionPattern<ErfcOp> {
466   using OpConversionPattern<ErfcOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anon8f4f31520111::ConvertErfcOp467   LogicalResult matchAndRewrite(
468       ErfcOp op, ArrayRef<Value> operands,
469       ConversionPatternRewriter &rewriter) const override {
470     Location loc = op.getLoc();
471     ErfcOp::Adaptor transformed(operands);
472     Value x = transformed.operand();
473     Type ty = x.getType().cast<ShapedType>().getElementType();
474 
475     // For now, we support only f64, f32, and f16.
476     if (!ty.isF64() && !ty.isF32() && !ty.isF16()) return failure();
477 
478     if (ty.isF64()) {
479       rewriter.replaceOp(op, MaterializeErfcApproximationF64(rewriter, loc, x));
480       return success();
481     }
482 
483     rewriter.replaceOp(op, MaterializeWithUpcast(
484                                rewriter, loc, operands, rewriter.getF32Type(),
485                                &MaterializeErfcApproximationF32));
486     return success();
487   }
488 };
489 
// Coefficients for the Lanczos approximation of the gamma function. The
// coefficients are uniquely determined by the choice of g and n (kLanczosGamma
// and kLanczosCoefficients.size() + 1). The coefficients below correspond to
// [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and
// [7, 9] seemed to be the least sensitive to the quality of the log function.
// In particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
// for a particularly inaccurate log function.
constexpr double kLanczosGamma = 7;  // aka g
constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
// Eight coefficients, i.e. n - 1 terms for n = 9 (see the note above); the
// k-th series term uses kLanczosCoefficients[k - 1].
constexpr std::array<double, 8> kLanczosCoefficients = {
    676.520368121885098567009190444019, -1259.13921672240287047156078755283,
    771.3234287776530788486528258894,   -176.61502916214059906584551354,
    12.507343278686904814458936853,     -0.13857109526572011689554707,
    9.984369578019570859563e-6,         1.50563273514931155834e-7};
504 
// Compute the Lgamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
//   lgamma(z + 1) = (log(2) + log(pi)) / 2
//                     + (z + 1/2) * log(t(z))
//                     - t(z) + log(a(z))
//   with   t(z) = z + kLanczosGamma + 1/2
//          a(z) = kBaseLanczosCoeff
//                   + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
                        ValueRange args) {
  // If the input is less than 0.5 use Euler's reflection formula.
  //   gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
  // Let z be
  //   z = -x      if x < 1/2
  //   z = x - 1   otherwise
  Value x = args.front();
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value half = getConstantLike(rewriter, loc, 0.5, x);
  Value need_to_reflect = rewriter.create<mhlo::CompareOp>(loc, x, half, kLT);
  Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
  Value one = getConstantLike(rewriter, loc, 1, x);
  Value x_sub_one = rewriter.create<mhlo::SubOp>(loc, x, one);
  Value z =
      rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, neg_x, x_sub_one);

  // Materialize
  //   a(z) = kBaseLanczosCoeff
  //            + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
  Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
  for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
    Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x);
    Value one_based_index = getConstantLike(rewriter, loc, i + 1, x);
    Value quotient = rewriter.create<mhlo::DivOp>(
        loc, coeff, rewriter.create<mhlo::AddOp>(loc, z, one_based_index));
    a = rewriter.create<mhlo::AddOp>(loc, a, quotient);
  }

  // To improve accuracy on platforms with less-precise log implementations,
  // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
  // device.
  // Materialize as
  //   log(t) = log(kLanczosGamma + 1/2 + z)
  //          = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
  Value lanczos_plus_half =
      getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
  Value t = rewriter.create<mhlo::AddOp>(loc, lanczos_plus_half, z);
  Value log_term =
      getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
  Value log1p_term = rewriter.create<mhlo::Log1pOp>(
      loc, rewriter.create<mhlo::DivOp>(loc, z, lanczos_plus_half));
  Value log_t = rewriter.create<mhlo::AddOp>(loc, log_term, log1p_term);

  // Note that t(z) may be large and we need to be careful not to overflow to
  // infinity in the relevant term
  //   r = (z + 1/2) * log(t(z)) - t(z).
  // Therefore, we compute this as
  //   r = (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
  Value t_div_log_t = rewriter.create<mhlo::DivOp>(loc, t, log_t);
  Value sum = rewriter.create<mhlo::SubOp>(
      loc, rewriter.create<mhlo::AddOp>(loc, z, half), t_div_log_t);
  Value r = rewriter.create<mhlo::MulOp>(loc, sum, log_t);

  // Compute the final result (modulo reflection) as
  //   lgamma(z + 1) = (log(2) + log(pi)) / 2 + r + log(a(z)).
  Value log_a = rewriter.create<mhlo::LogOp>(loc, a);
  Value lgamma = rewriter.create<mhlo::AddOp>(
      loc,
      rewriter.create<mhlo::AddOp>(
          loc,
          getConstantLike(rewriter, loc, (std::log(2) + std::log(M_PI)) / 2, x),
          r),
      log_a);

  // Compute the reflected value for x < 0.5 as
  //   lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
  //
  // The abs is needed because lgamma is the log of the absolute value of the
  // gamma function.
  //
  // We have to be careful when computing the final term above. gamma(x) goes
  // to +/-inf at every integer x < 0, and this is controlled by the sin(pi * x)
  // term. The slope is large, so precision is particularly important.
  //
  // Because abs(sin(pi * x)) has period of 1 we can equivalently use
  // abs(sin(pi * frac(x))) where frac(x) is the fractional part of x. This is
  // more numerically accurate: It doesn't overflow to inf like pi * x would and
  // if x is an integer it evaluates to exactly 0 which is important because we
  // then take the log of this value, and log(0) is inf.
  //
  // We don't have a frac(x) primitive in HLO and computing it is tricky, but
  // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for our
  // purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
  //
  // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
  // to 1. To remedy this, we can use the fact that sin(pi * x) in the domain
  // [0, 1] is symmetric across the line Y=0.5.
  //

  // Convert values of abs_frac > 0.5 to (1 - abs_frac) to improve precision of
  // pi * abs_frac for values of abs_frac close to 1.
  Value abs = rewriter.create<mhlo::AbsOp>(loc, x);
  Value abs_frac = rewriter.create<mhlo::SubOp>(
      loc, abs, rewriter.create<mhlo::FloorOp>(loc, abs));
  Value reduce_abs_frac =
      rewriter.create<mhlo::CompareOp>(loc, half, abs_frac, kLT);
  abs_frac = rewriter.create<mhlo::SelectOp>(
      loc, reduce_abs_frac, rewriter.create<mhlo::SubOp>(loc, one, abs_frac),
      abs_frac);

  // Materialize reflection.
  Value reflection_denom = rewriter.create<mhlo::LogOp>(
      loc,
      rewriter.create<mhlo::SinOp>(
          loc, rewriter.create<mhlo::MulOp>(
                   loc, getConstantLike(rewriter, loc, M_PI, x), abs_frac)));
  Value lgamma_reflection = rewriter.create<mhlo::SubOp>(
      loc,
      rewriter.create<mhlo::SubOp>(
          loc, getConstantLike(rewriter, loc, std::log(M_PI), x),
          reflection_denom),
      lgamma);

  // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf,
  // then it "wins" and the result is +/-inf.
  Value finite_reflection_denom =
      rewriter.create<mhlo::IsFiniteOp>(loc, reflection_denom);
  Value neg_reflection_denom =
      rewriter.create<mhlo::NegOp>(loc, reflection_denom);
  lgamma_reflection = rewriter.create<mhlo::SelectOp>(
      loc, finite_reflection_denom, lgamma_reflection, neg_reflection_denom);

  // Select whether or not to rely on the reflection.
  lgamma = rewriter.create<mhlo::SelectOp>(loc, need_to_reflect,
                                           lgamma_reflection, lgamma);

  // Materialize +/-inf behavior as
  //   lgamma(+/-inf) = +inf.
  Value x_is_inf = rewriter.create<chlo::IsInfOp>(loc, x);
  return rewriter.create<mhlo::SelectOp>(
      loc, x_is_inf,
      chlo::getConstantLikeInfValue(rewriter, loc, x, /*negative=*/false),
      lgamma);
}
650 
651 // Express `cosh` as
652 //   cosh(x) = (e^x + e^-x) / 2
653 //           = e^(x + log(1/2)) + e^(-x + log(1/2))
654 //
655 // The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not.
656 //
657 // This incorrectly overflows to inf for two f32 input values, namely
658 // +/-89.4159851, due to rounding error when computing x +/- log(1/2).  The
659 // correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
660 // we deem this acceptable.
MaterializeCoshApproximation(ConversionPatternRewriter & rewriter,Location loc,ValueRange operands)661 Value MaterializeCoshApproximation(ConversionPatternRewriter &rewriter,
662                                    Location loc, ValueRange operands) {
663   CoshOp::Adaptor transformed(operands);
664   Value x = transformed.operand();
665 
666   Value log_one_half =
667       rewriter.create<mhlo::LogOp>(loc, getConstantLike(rewriter, loc, 0.5, x));
668   Value exp_add = rewriter.create<mhlo::ExpOp>(
669       loc, rewriter.create<mhlo::AddOp>(loc, x, log_one_half));
670   Value exp_sub = rewriter.create<mhlo::ExpOp>(
671       loc, rewriter.create<mhlo::SubOp>(loc, log_one_half, x));
672   return rewriter.create<mhlo::AddOp>(loc, exp_add, exp_sub);
673 }
674 
675 struct ConvertCoshOp : public OpConversionPattern<CoshOp> {
676   using OpConversionPattern<CoshOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anon8f4f31520111::ConvertCoshOp677   LogicalResult matchAndRewrite(
678       CoshOp op, ArrayRef<Value> operands,
679       ConversionPatternRewriter &rewriter) const override {
680     CoshOp::Adaptor transformed(operands);
681     Value x = transformed.operand();
682     if (x.getType().cast<ShapedType>().getElementType().isa<ComplexType>()) {
683       // TODO(hinsu): Support operands with complex element types by always
684       // using the formula for large x. The compare op is not legal for complex
685       // numbers.
686       return failure();
687     }
688     rewriter.replaceOp(op,
689                        MaterializeWithUpcast(rewriter, op.getLoc(), operands,
690                                              rewriter.getF32Type(),
691                                              &MaterializeCoshApproximation));
692     return success();
693   }
694 };
695 
696 // Compute the Digamma function using Lanczos' approximation from "A Precision
697 // Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
698 // series B. Vol. 1:
699 //   digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z)
700 //   with   t(z) = z + kLanczosGamma + 1/2
701 //          a(z) = kBaseLanczosCoeff
//                   + sum(k = 1, n, kLanczosCoefficients[k-1] / (z + k))
//          a'(z) = - sum(k = 1, n, kLanczosCoefficients[k-1] / (z + k) / (z + k))
Value MaterializeDigamma(ConversionPatternRewriter &rewriter, Location loc,
                         ValueRange args) {
  // If the input is less than 0.5 use Euler's reflection formula.
  //   digamma(x) = digamma(1 - x) - pi * cot(pi * x)
  // Let z be
  //   z = -x      if x < 1/2
  //   z = x - 1   otherwise
  Value x = args.front();
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value half = getConstantLike(rewriter, loc, 0.5, x);
  Value need_to_reflect = rewriter.create<mhlo::CompareOp>(loc, x, half, kLT);
  Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
  Value one = getConstantLike(rewriter, loc, 1, x);
  Value x_sub_one = rewriter.create<mhlo::SubOp>(loc, x, one);
  Value z =
      rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, neg_x, x_sub_one);

  // Materialize
  //   a(z) = kBaseLanczosCoeff
  //            + sum(k = 1, n, kLanczosCoefficients[k-1] / (z + k))
  //   a'(z) = - sum(k = 1, n, kLanczosCoefficients[k-1] / (z + k) / (z + k))
  Value zero = getConstantLike(rewriter, loc, 0.0, x);
  Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
  Value a_prime = zero;
  for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
    Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x);
    Value one_based_index = getConstantLike(rewriter, loc, i + 1, x);
    Value z_term = rewriter.create<mhlo::AddOp>(loc, z, one_based_index);
    // a'(z) accumulates -coeff / (z + k)^2 for each coefficient.
    a_prime = rewriter.create<mhlo::SubOp>(
        loc, a_prime,
        rewriter.create<mhlo::DivOp>(
            loc, coeff, rewriter.create<mhlo::MulOp>(loc, z_term, z_term)));
    a = rewriter.create<mhlo::AddOp>(
        loc, a, rewriter.create<mhlo::DivOp>(loc, coeff, z_term));
  }

  // To improve accuracy on platforms with less-precise log implementations,
  // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
  // device.
  // Materialize as
  //   log(t) = log(kLanczosGamma + 1/2 + z)
  //          = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
  Value lanczos_plus_half =
      getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
  Value t = rewriter.create<mhlo::AddOp>(loc, lanczos_plus_half, z);
  Value log_term =
      getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
  Value log1p_term = rewriter.create<mhlo::Log1pOp>(
      loc, rewriter.create<mhlo::DivOp>(loc, z, lanczos_plus_half));
  Value log_t = rewriter.create<mhlo::AddOp>(loc, log_term, log1p_term);

  // Materialize the final result (modulo reflection) as
  //   digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z).
  Value a_prime_div_a = rewriter.create<mhlo::DivOp>(loc, a_prime, a);
  Value lanczos_gamma_div_t = rewriter.create<mhlo::DivOp>(
      loc, getConstantLike(rewriter, loc, kLanczosGamma, x), t);
  Value digamma = rewriter.create<mhlo::SubOp>(
      loc, rewriter.create<mhlo::AddOp>(loc, log_t, a_prime_div_a),
      lanczos_gamma_div_t);

  // We need to be careful how we compute cot(pi * input) below: For
  // near-integral arguments, pi * input can lose precision.
  //
  // Input is already known to be less than 0.5 (otherwise we don't have to
  // reflect). We shift values smaller than -0.5 into the range [-0.5, 0.5] to
  // increase precision of pi * x and the resulting cotangent.
  Value reduced_x = rewriter.create<mhlo::AddOp>(
      loc, x,
      rewriter.create<mhlo::AbsOp>(
          loc, rewriter.create<mhlo::FloorOp>(
                   loc, rewriter.create<mhlo::AddOp>(
                            loc, x, getConstantLike(rewriter, loc, 0.5, x)))));

  // Materialize reflection for inputs less than 0.5 as
  //   digamma(x) = digamma(1 - x) - pi * cot(pi * x)
  //              = digamma(1 - x) - pi * cos(pi * x) / sin(pi * x)
  Value pi = getConstantLike(rewriter, loc, M_PI, x);
  Value pi_mul_reduced_x = rewriter.create<mhlo::MulOp>(loc, pi, reduced_x);
  Value cos = rewriter.create<mhlo::CosOp>(loc, pi_mul_reduced_x);
  Value sin = rewriter.create<mhlo::SinOp>(loc, pi_mul_reduced_x);
  Value reflection = rewriter.create<mhlo::SubOp>(
      loc, digamma,
      rewriter.create<mhlo::DivOp>(
          loc, rewriter.create<mhlo::MulOp>(loc, pi, cos), sin));

  // Select whether or not to rely on the reflection.
  digamma = rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, reflection,
                                            digamma);

  // Digamma has poles at negative integers and zero; return nan for those.
  const StringAttr kLE = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LE));
  Value is_le_zero = rewriter.create<mhlo::CompareOp>(loc, x, zero, kLE);
  const StringAttr kEQ = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
  // x is integral iff x == floor(x).
  Value is_int = rewriter.create<mhlo::CompareOp>(
      loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kEQ);
  Value is_pole = rewriter.create<mhlo::AndOp>(loc, is_le_zero, is_int);
  return rewriter.create<mhlo::SelectOp>(
      loc, is_pole,
      getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(),
                      x),
      digamma);
}
809 
// Materializes an approximation of the Hurwitz zeta function zeta(x, q),
// following the Cephes approach: a truncated power series as the initial
// estimate plus an Euler-Maclaurin correction term, with special-case
// handling for poles and the undefined region x < 1.
Value MaterializeZeta(ConversionPatternRewriter &rewriter, Location loc,
                      ValueRange args) {
  assert(args.size() == 2);
  Value x = args[0];
  Value q = args[1];
  // Euler-Maclaurin expansion coefficients (Cephes zeta coefficients); used
  // below via their reciprocals 1 / kZetaCoeffs[i].
  static const std::array<double, 12> kZetaCoeffs{
      -7.1661652561756670113e18,
      1.8152105401943546773e17,
      -4.5979787224074726105e15,
      1.1646782814350067249e14,
      -2.950130727918164224e12,
      7.47242496e10,
      -1.8924375803183791606e9,
      47900160.0,
      -1209600.0,
      30240.0,
      -720.0,
      12.0,
  };

  // For speed we'll always use 9 iterations for the initial series estimate,
  // and a 12 term expansion for the Euler-Maclaurin formula.
  Value a = q;
  Value zero = chlo::getConstantLike(rewriter, loc, 0.0, a);
  Value neg_power = zero;
  Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
  // Initial series: sum over k of (q + k)^(-x), unrolled to 10 terms.
  Value initial_sum = rewriter.create<mhlo::PowOp>(loc, q, neg_x);
  Value one = chlo::getConstantLike(rewriter, loc, 1.0, a);
  for (int i = 0; i < 9; ++i) {
    a = rewriter.create<mhlo::AddOp>(loc, a, one);
    neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x);
    initial_sum = rewriter.create<mhlo::AddOp>(loc, initial_sum, neg_power);
  }
  a = rewriter.create<mhlo::AddOp>(loc, a, one);
  neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x);
  Value one_like_x = chlo::getConstantLike(rewriter, loc, 1.0, x);
  Value x_minus_one = rewriter.create<mhlo::SubOp>(loc, x, one_like_x);
  // Integral tail correction: a^(-x) * a / (x - 1).
  Value neg_power_mul_a = rewriter.create<mhlo::MulOp>(loc, neg_power, a);
  Value neg_power_mul_a_div_x_minus_one =
      rewriter.create<mhlo::DivOp>(loc, neg_power_mul_a, x_minus_one);
  Value s = rewriter.create<mhlo::AddOp>(loc, initial_sum,
                                         neg_power_mul_a_div_x_minus_one);
  Value a_inverse_square = rewriter.create<mhlo::DivOp>(
      loc, one, rewriter.create<mhlo::MulOp>(loc, a, a));

  Value horner_sum = zero;
  Value factor = one;
  // Use Horner's rule for this.
  // Note this differs from Cephes which does a 'naive' polynomial evaluation.
  // Using Horner's rule allows to avoid some NaN's and Infs from happening,
  // resulting in more numerically stable code.
  for (int i = 0; i < 11; ++i) {
    // Each step multiplies in the rising factors (x - (22 - 2i))(x - (21 - 2i))
    // of the Euler-Maclaurin term.
    Value factor_lhs = rewriter.create<mhlo::SubOp>(
        loc, x, chlo::getConstantLike(rewriter, loc, 22 - 2 * i, x));
    Value factor_rhs = rewriter.create<mhlo::SubOp>(
        loc, x, chlo::getConstantLike(rewriter, loc, 21 - 2 * i, x));
    factor = rewriter.create<mhlo::MulOp>(loc, factor_lhs, factor_rhs);
    horner_sum = rewriter.create<mhlo::MulOp>(
        loc, factor,
        rewriter.create<mhlo::MulOp>(
            loc, a_inverse_square,
            rewriter.create<mhlo::AddOp>(
                loc, horner_sum,
                chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i], a))));
  }
  // Add the Euler-Maclaurin correction:
  //   a^(-x) * (1/2 + (x/a) * (1/kZetaCoeffs[11] + horner_sum)).
  Value zero_point_five_like_neg_power =
      chlo::getConstantLike(rewriter, loc, .5, neg_power);
  Value x_div_a = rewriter.create<mhlo::DivOp>(loc, x, a);
  s = rewriter.create<mhlo::AddOp>(
      loc, s,
      rewriter.create<mhlo::MulOp>(
          loc, neg_power,
          rewriter.create<mhlo::AddOp>(
              loc, zero_point_five_like_neg_power,
              rewriter.create<mhlo::MulOp>(
                  loc, x_div_a,
                  rewriter.create<mhlo::AddOp>(
                      loc,
                      chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[11],
                                            a),
                      horner_sum)))));

  // Use the initial zeta sum without the correction term coming
  // from Euler-Maclaurin if it is accurate enough.
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value abs_neg_power = rewriter.create<mhlo::AbsOp>(loc, neg_power);
  Value abs_initial_sum = rewriter.create<mhlo::AbsOp>(loc, initial_sum);
  Value output = rewriter.create<mhlo::SelectOp>(
      loc,
      rewriter.create<mhlo::CompareOp>(
          loc, abs_neg_power,
          rewriter.create<mhlo::MulOp>(
              loc, abs_initial_sum,
              chlo::getConstantLikeSmallestFiniteValue(rewriter, loc, a)),
          kLT),
      initial_sum, s);

  // Function is not defined for x < 1.
  Value nan = chlo::getConstantLike(
      rewriter, loc, std::numeric_limits<double>::quiet_NaN(), x);
  output = rewriter.create<mhlo::SelectOp>(
      loc, rewriter.create<mhlo::CompareOp>(loc, x, one_like_x, kLT), nan,
      output);

  // For q <= 0, x must be an integer.
  const StringAttr kLE = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LE));
  const StringAttr kNE = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE));
  Value q_le_zero = rewriter.create<mhlo::CompareOp>(loc, q, zero, kLE);
  Value x_not_int = rewriter.create<mhlo::CompareOp>(
      loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kNE);
  Value x_domain_error =
      rewriter.create<mhlo::AndOp>(loc, q_le_zero, x_not_int);
  output = rewriter.create<mhlo::SelectOp>(loc, x_domain_error, nan, output);

  // For all integer q <= 0, zeta has a pole. The limit is only defined as
  // +inf if x is an even integer.
  const StringAttr kEQ = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
  Value inf = chlo::getConstantLike(rewriter, loc,
                                    std::numeric_limits<double>::infinity(), x);
  Value q_is_int = rewriter.create<mhlo::CompareOp>(
      loc, q, rewriter.create<mhlo::FloorOp>(loc, q), kEQ);
  Value at_pole = rewriter.create<mhlo::AndOp>(loc, q_le_zero, q_is_int);
  Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
  Value x_is_int = rewriter.create<mhlo::CompareOp>(
      loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kEQ);
  Value x_is_even = rewriter.create<mhlo::CompareOp>(
      loc, rewriter.create<mhlo::RemOp>(loc, x, two), zero, kEQ);
  Value x_is_even_int = rewriter.create<mhlo::AndOp>(loc, x_is_int, x_is_even);
  output = rewriter.create<mhlo::SelectOp>(
      loc, at_pole,
      rewriter.create<mhlo::SelectOp>(loc, x_is_even_int, inf, nan), output);

  // For x = 1, this is the harmonic series and diverges.
  output = rewriter.create<mhlo::SelectOp>(
      loc, rewriter.create<mhlo::CompareOp>(loc, x, one, kEQ), inf, output);

  return output;
}
952 
// Materializes polygamma(n, x) for natural n as
//   polygamma(n, x) = (-1)^(n+1) * n! * zeta(n+1, x)
// where n! is computed as exp(lgamma(n + 1)). The n = 0 case falls back to
// digamma, and non-natural n (negative or non-integral) yields NaN.
Value MaterializePolygamma(ConversionPatternRewriter &rewriter, Location loc,
                           ValueRange args) {
  PolygammaOp::Adaptor transformed(args);
  Value n = transformed.n();
  Value x = transformed.x();

  // Handle integer n > 0.
  Value one = getConstantLike(rewriter, loc, 1.0, x);
  Value two = getConstantLike(rewriter, loc, 2.0, x);
  // sign = 2 * (n mod 2) - 1, i.e. -1 for even n and +1 for odd n, which
  // equals (-1)^(n+1).
  Value sign = rewriter.create<mhlo::SubOp>(
      loc,
      rewriter.create<mhlo::MulOp>(loc, two,
                                   rewriter.create<mhlo::RemOp>(loc, n, two)),
      one);
  Value n_plus_one = rewriter.create<mhlo::AddOp>(loc, n, one);
  // n! = exp(lgamma(n + 1)).
  Value exp_lgamma_np1 = rewriter.create<mhlo::ExpOp>(
      loc, rewriter.create<chlo::LgammaOp>(loc, n_plus_one));
  Value zeta = rewriter.create<chlo::ZetaOp>(loc, n_plus_one, x);
  Value result = rewriter.create<mhlo::MulOp>(
      loc, rewriter.create<mhlo::MulOp>(loc, sign, exp_lgamma_np1), zeta);

  // Handle n = 0: polygamma(0, x) = digamma(x).
  const StringAttr kEQ = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
  Value zero = getConstantLike(rewriter, loc, 0.0, x);
  Value n_eq_zero = rewriter.create<mhlo::CompareOp>(loc, n, zero, kEQ);
  result = rewriter.create<mhlo::SelectOp>(
      loc, n_eq_zero, rewriter.create<chlo::DigammaOp>(loc, x), result);

  // Check that n is a natural number. Return nan, otherwise.
  const StringAttr kNE = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE));
  // n is non-integral iff n != floor(n).
  Value non_int = rewriter.create<mhlo::CompareOp>(
      loc, n, rewriter.create<mhlo::FloorOp>(loc, n), kNE);
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value negative = rewriter.create<mhlo::CompareOp>(loc, n, zero, kLT);
  Value non_natural = rewriter.create<mhlo::OrOp>(loc, non_int, negative);
  return rewriter.create<mhlo::SelectOp>(
      loc, non_natural,
      getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(),
                      x),
      result);
}
997 
998 struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
999   using OpConversionPattern<LgammaOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anon8f4f31520111::ConvertLgammaOp1000   LogicalResult matchAndRewrite(
1001       LgammaOp op, ArrayRef<Value> operands,
1002       ConversionPatternRewriter &rewriter) const override {
1003     FloatType min_precision_ty = rewriter.getF32Type();
1004     rewriter.replaceOp(
1005         op, MaterializeWithUpcast(rewriter, op.getLoc(), operands,
1006                                   min_precision_ty, &MaterializeLgamma));
1007     return success();
1008   }
1009 };
1010 
1011 struct ConvertDigammaOp : public OpConversionPattern<DigammaOp> {
1012   using OpConversionPattern<DigammaOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anon8f4f31520111::ConvertDigammaOp1013   LogicalResult matchAndRewrite(
1014       DigammaOp op, ArrayRef<Value> operands,
1015       ConversionPatternRewriter &rewriter) const override {
1016     FloatType min_precision_ty = rewriter.getF32Type();
1017     rewriter.replaceOp(
1018         op, MaterializeWithUpcast(rewriter, op.getLoc(), operands,
1019                                   min_precision_ty, &MaterializeDigamma));
1020     return success();
1021   }
1022 };
1023 
// Materializes the next representable floating-point value after `x` in the
// direction of `y` by stepping the integer bit pattern of `x` by +/-1,
// with explicit handling for NaN inputs, x == y, and signed zeros.
Value MaterializeNextAfter(ConversionPatternRewriter &rewriter, Location loc,
                           ValueRange operands) {
  NextAfterOp::Adaptor transformed(operands);
  Value x = transformed.x();
  Value y = transformed.y();
  auto result_ty = x.getType().cast<ShapedType>();
  auto bitwidth = result_ty.getElementType().getIntOrFloatBitWidth();
  ImplicitLocOpBuilder b(loc, rewriter);
  // Reinterpret the floats as same-width integers so the representation can
  // be incremented or decremented directly.
  auto int_ty = result_ty.clone(b.getIntegerType(bitwidth));
  auto x_as_int = b.create<mhlo::BitcastConvertOp>(int_ty, x);
  auto y_as_int = b.create<mhlo::BitcastConvertOp>(int_ty, y);

  // The result is NaN if either "x" or "y" are NaN (detected via x != x).
  const StringAttr kNE = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE));
  auto x_is_nan = b.create<mhlo::CompareOp>(x, x, kNE);
  auto y_is_nan = b.create<mhlo::CompareOp>(y, y, kNE);
  auto nan_input = b.create<mhlo::OrOp>(x_is_nan, y_is_nan);
  auto result_for_nan = getConstantLike(
      rewriter, loc, std::numeric_limits<double>::quiet_NaN(), x);
  auto result_for_nan_as_int =
      b.create<mhlo::BitcastConvertOp>(int_ty, result_for_nan);

  // The sign bit is the MSB.
  const int64_t sign_bit = int64_t{1} << (bitwidth - 1);
  // Discard the sign bit to make the result non-negative.
  auto sign_mask = getConstantLike(rewriter, loc, sign_bit, x_as_int);
  auto negated_sign_mask = getConstantLike(rewriter, loc, ~sign_bit, x_as_int);
  auto x_abs = b.create<mhlo::AndOp>(x_as_int, negated_sign_mask);
  auto y_abs = b.create<mhlo::AndOp>(y_as_int, negated_sign_mask);

  // When both "x" and "y" are equal, the result is "y".
  const StringAttr kEQ = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
  auto x_and_y_are_equal = b.create<mhlo::CompareOp>(x, y, kEQ);
  auto result_for_equal = y_as_int;

  // When both "x" and "y" are 0, the result is "y". This is a separate case
  // from above because "x" and "y" might have a different sign.
  auto zero = getConstantLike(rewriter, loc, 0, x_as_int);
  auto x_is_zero = b.create<mhlo::CompareOp>(x_abs, zero, kEQ);
  auto y_is_zero = b.create<mhlo::CompareOp>(y_abs, zero, kEQ);
  auto result_for_both_zero = y_as_int;

  auto x_sign = b.create<mhlo::AndOp>(x_as_int, sign_mask);
  auto y_sign = b.create<mhlo::AndOp>(y_as_int, sign_mask);

  // If from == 0 && to != 0, we need to return the smallest subnormal number
  // signed like "to".
  auto one = getConstantLike(rewriter, loc, 1, x_as_int);
  auto result_for_x_zero_y_non_zero = b.create<mhlo::OrOp>(y_sign, one);

  // If the sign of "x" and "y" disagree:
  // - we need to make the magnitude of "from" smaller so that it is closer to
  //   zero.
  //
  // Otherwise the signs agree:
  // - "x" with a magnitude larger than "y" means we need to make the magnitude
  //   smaller.
  // - "x" with a magnitude smaller than "y" means we need to make the magnitude
  //   larger.
  auto signs_disagree = b.create<mhlo::CompareOp>(x_sign, y_sign, kNE);
  const StringAttr kGT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::GT));
  auto x_magnitude_larger_than_y = b.create<mhlo::CompareOp>(x_abs, y_abs, kGT);
  auto result_has_smaller_magnitude =
      b.create<mhlo::OrOp>(x_magnitude_larger_than_y, signs_disagree);
  auto minus_one = getConstantLike(rewriter, loc, -1, x_as_int);
  auto magnitude_adjustment =
      b.create<mhlo::SelectOp>(result_has_smaller_magnitude, minus_one, one);
  Value result = b.create<mhlo::AddOp>(x_as_int, magnitude_adjustment);
  // Handle from == +-0.
  result = b.create<mhlo::SelectOp>(
      x_is_zero,
      b.create<mhlo::SelectOp>(y_is_zero, result_for_both_zero,
                               result_for_x_zero_y_non_zero),
      result);
  // Handle from == to.
  result =
      b.create<mhlo::SelectOp>(x_and_y_are_equal, result_for_equal, result);
  // Handle isnan(x) || isnan(y).
  result = b.create<mhlo::SelectOp>(nan_input, result_for_nan_as_int, result);

  // Cast back to the original type.
  return b.create<mhlo::BitcastConvertOp>(result_ty, result);
}
1110 
1111 struct ConvertNextAfterOp : public OpConversionPattern<NextAfterOp> {
1112   using OpConversionPattern<NextAfterOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anon8f4f31520111::ConvertNextAfterOp1113   LogicalResult matchAndRewrite(
1114       NextAfterOp op, ArrayRef<Value> operands,
1115       ConversionPatternRewriter &rewriter) const override {
1116     rewriter.replaceOp(op,
1117                        MaterializeNextAfter(rewriter, op.getLoc(), operands));
1118     return success();
1119   }
1120 };
1121 
1122 struct ConvertPolygammaOp : public OpConversionPattern<PolygammaOp> {
1123   using OpConversionPattern<PolygammaOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anon8f4f31520111::ConvertPolygammaOp1124   LogicalResult matchAndRewrite(
1125       PolygammaOp op, ArrayRef<Value> operands,
1126       ConversionPatternRewriter &rewriter) const override {
1127     Location loc = op.getLoc();
1128     FloatType min_precision_ty = rewriter.getF32Type();
1129     rewriter.replaceOp(
1130         op, MaterializeWithUpcast(rewriter, loc, operands, min_precision_ty,
1131                                   &MaterializePolygamma));
1132     return success();
1133   }
1134 };
1135 
// Express `sinh` for large |x| as
//   sinh(x) = e^(x + log(1/2)) - e^(-x + log(1/2)),
// which avoids overflowing when e^x is inf but e^x / 2 is not.
Value MaterializeSinhApproximationForLargeX(ConversionPatternRewriter &rewriter,
                                            Location loc, ValueRange operands) {
  SinhOp::Adaptor transformed(operands);
  Value x = transformed.operand();
  auto result_ty = x.getType().cast<ShapedType>();

  // TODO(b/190374484): Use mhlo::ConstantLikeOp when it supports complex types.
  // Until then, materialize a scalar 2 and broadcast it by hand to the
  // (possibly dynamic) shape of x.
  Value two = rewriter.create<mhlo::ConstOp>(
      loc, hlo::GetScalarOfType(getElementTypeOrSelf(x.getType()), 2));
  Type extent_tensor_type = shape::getExtentTensorType(x.getContext());
  Value uncasted_shape =
      rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, x);
  // Cast the unranked extent tensor to the statically-ranked shape tensor
  // expected by DynamicBroadcastInDimOp.
  Type shape_ty =
      RankedTensorType::get({result_ty.getRank()}, rewriter.getIndexType());
  Value shape = rewriter.create<tensor::CastOp>(loc, shape_ty, uncasted_shape);
  Value two_with_x_shape = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
      loc, result_ty, two, shape, rewriter.getI64TensorAttr({}));

  // log(1/2) = -log(2).
  Value log_two = rewriter.create<mhlo::LogOp>(loc, two_with_x_shape);
  Value log_one_half = rewriter.create<mhlo::NegOp>(loc, log_two);
  Value exp_add = rewriter.create<mhlo::ExpOp>(
      loc, rewriter.create<mhlo::AddOp>(loc, x, log_one_half));
  Value exp_sub = rewriter.create<mhlo::ExpOp>(
      loc, rewriter.create<mhlo::SubOp>(loc, log_one_half, x));
  return rewriter.create<mhlo::SubOp>(loc, exp_add, exp_sub);
}
1162 
1163 // Express `sinh` as
1164 //   sinh(x) = (e^x - e^-x) / 2                     if |x| < 1
1165 //           = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
MaterializeSinhApproximation(ConversionPatternRewriter & rewriter,Location loc,ValueRange operands)1166 Value MaterializeSinhApproximation(ConversionPatternRewriter &rewriter,
1167                                    Location loc, ValueRange operands) {
1168   Value large_sinh_result =
1169       MaterializeSinhApproximationForLargeX(rewriter, loc, operands);
1170 
1171   SinhOp::Adaptor transformed(operands);
1172   Value x = transformed.operand();
1173   const StringAttr kLT = rewriter.getStringAttr(
1174       mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
1175   Value exp_x = rewriter.create<mhlo::ExpOp>(loc, x);
1176   Value exp_neg_x =
1177       rewriter.create<mhlo::ExpOp>(loc, rewriter.create<mhlo::NegOp>(loc, x));
1178   Value exp_difference = rewriter.create<mhlo::SubOp>(loc, exp_x, exp_neg_x);
1179   Value two = getConstantLike(rewriter, loc, 2.0, x);
1180   Value small_sinh_result =
1181       rewriter.create<mhlo::DivOp>(loc, exp_difference, two);
1182 
1183   Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
1184   Value one = getConstantLike(rewriter, loc, 1.0, x);
1185   Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT);
1186   return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, small_sinh_result,
1187                                          large_sinh_result);
1188 }
1189 
1190 struct ConvertSinhOp : public OpConversionPattern<SinhOp> {
1191   using OpConversionPattern<SinhOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anon8f4f31520111::ConvertSinhOp1192   LogicalResult matchAndRewrite(
1193       SinhOp op, ArrayRef<Value> operands,
1194       ConversionPatternRewriter &rewriter) const override {
1195     SinhOp::Adaptor transformed(operands);
1196     Value x = transformed.operand();
1197     if (x.getType().cast<ShapedType>().getElementType().isa<ComplexType>()) {
1198       rewriter.replaceOp(op, MaterializeSinhApproximationForLargeX(
1199                                  rewriter, op.getLoc(), operands));
1200       return success();
1201     }
1202     rewriter.replaceOp(op,
1203                        MaterializeWithUpcast(rewriter, op.getLoc(), operands,
1204                                              rewriter.getF32Type(),
1205                                              &MaterializeSinhApproximation));
1206     return success();
1207   }
1208 };
1209 
1210 struct ConvertZetaOp : public OpConversionPattern<ZetaOp> {
1211   using OpConversionPattern<ZetaOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anon8f4f31520111::ConvertZetaOp1212   LogicalResult matchAndRewrite(
1213       ZetaOp op, ArrayRef<Value> operands,
1214       ConversionPatternRewriter &rewriter) const override {
1215     Location loc = op.getLoc();
1216     FloatType min_precision_ty = rewriter.getF32Type();
1217     rewriter.replaceOp(
1218         op, MaterializeWithUpcast(rewriter, loc, operands, min_precision_ty,
1219                                   &MaterializeZeta));
1220     return success();
1221   }
1222 };
1223 
// Converts chlo.broadcast_select to mhlo.select by explicitly broadcasting
// all three operands to the result shape with mhlo.dynamic_broadcast_in_dim,
// guarded by a shape.assuming region that asserts the operand shapes are
// broadcastable.
struct ConvertSelectOp : public OpConversionPattern<BroadcastSelectOp> {
  using OpConversionPattern<BroadcastSelectOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      BroadcastSelectOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Only support ranked operands.
    typename BroadcastSelectOp::Adaptor transformed(operands);
    Value pred = transformed.pred();
    Value on_true = transformed.on_true();
    Value on_false = transformed.on_false();
    auto pred_type = pred.getType().dyn_cast<RankedTensorType>();
    auto on_true_type = on_true.getType().dyn_cast<RankedTensorType>();
    auto on_false_type = on_false.getType().dyn_cast<RankedTensorType>();
    auto result_type = op.getResult().getType().dyn_cast<RankedTensorType>();
    if (!pred_type || !on_true_type || !on_false_type || !result_type) {
      return failure();
    }

    auto loc = op.getLoc();

    // Materialize the operand shapes. The result rank is the maximum of the
    // operand ranks (numpy-style prefix-padding broadcast).
    Value pred_shape = rewriter.createOrFold<shape::ShapeOfOp>(loc, pred);
    Value on_true_shape = rewriter.createOrFold<shape::ShapeOfOp>(loc, on_true);
    Value on_false_shape =
        rewriter.createOrFold<shape::ShapeOfOp>(loc, on_false);
    int64_t result_rank = std::max(
        {pred_type.getRank(), on_true_type.getRank(), on_false_type.getRank()});

    // Constrain the three shapes to be broadcastable, and emit all of the
    // broadcast/select code into an assuming region reliant on the constraint.
    Value broadcastable_cstr =
        rewriter.createOrFold<shape::CstrBroadcastableOp>(
            loc, ValueRange{pred_shape, on_true_shape, on_false_shape});
    auto assuming_op = rewriter.create<shape::AssumingOp>(
        loc, ArrayRef<Type>{result_type}, broadcastable_cstr);

    // Redirect the builder into the assuming region's body; the guard restores
    // the previous insertion point when this function returns.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.createBlock(&assuming_op.doRegion());

    // Compute the broadcasted result extents and cast the extent tensor to
    // the statically known result rank.
    Value result_extents = rewriter.createOrFold<shape::BroadcastOp>(
        loc, shape::getExtentTensorType(op.getContext()),
        ValueRange{pred_shape, on_true_shape, on_false_shape},
        /*error=*/nullptr);
    auto shape_type =
        RankedTensorType::get({result_rank}, rewriter.getIndexType());
    result_extents =
        rewriter.createOrFold<tensor::CastOp>(loc, shape_type, result_extents);

    Value broadcasted_pred = pred;
    // Pred has an implicit broadcast for scalars, so use that when convenient.
    if (pred_type.getRank() > 0) {
      // Map each pred dimension to the trailing dimensions of the result
      // (prefix padding).
      auto pred_broadcast_dimensions = llvm::to_vector<4>(
          llvm::seq<int64_t>(result_rank - pred_type.getRank(), result_rank));
      broadcasted_pred = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
          loc,
          RankedTensorType::get(result_type.getShape(),
                                pred_type.getElementType()),
          pred, result_extents,
          rewriter.getI64TensorAttr(pred_broadcast_dimensions));
    }
    // Broadcast on_true and on_false to the result shape the same way.
    auto on_true_broadcast_dimensions = llvm::to_vector<4>(
        llvm::seq<int64_t>(result_rank - on_true_type.getRank(), result_rank));
    Value broadcasted_on_true = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
        loc,
        RankedTensorType::get(result_type.getShape(),
                              on_true_type.getElementType()),
        on_true, result_extents,
        rewriter.getI64TensorAttr(on_true_broadcast_dimensions));
    auto on_false_broadcast_dimensions = llvm::to_vector<4>(
        llvm::seq<int64_t>(result_rank - on_false_type.getRank(), result_rank));
    Value broadcasted_on_false = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
        loc,
        RankedTensorType::get(result_type.getShape(),
                              on_false_type.getElementType()),
        on_false, result_extents,
        rewriter.getI64TensorAttr(on_false_broadcast_dimensions));

    // And generate the final non-broadcasted ternary op.
    Value final_result = rewriter.create<mhlo::SelectOp>(
        loc, result_type, broadcasted_pred, broadcasted_on_true,
        broadcasted_on_false);
    rewriter.create<shape::AssumingYieldOp>(loc, final_result);
    rewriter.replaceOp(op, {assuming_op.getResult(0)});
    return success();
  }
};
1307 
1308 // Converts binary ops that statically are determined to not broadcast directly
1309 // to the corresponding mhlo non-broadcasting op.
1310 template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
1311 struct ConvertTrivialNonBroadcastBinaryOp
1312     : public OpConversionPattern<ChloOpTy> {
1313   using OpConversionPattern<ChloOpTy>::OpConversionPattern;
matchAndRewritemlir::chlo::__anon8f4f31520111::ConvertTrivialNonBroadcastBinaryOp1314   LogicalResult matchAndRewrite(
1315       ChloOpTy op, ArrayRef<Value> operands,
1316       ConversionPatternRewriter &rewriter) const override {
1317     // Only rewrite for statically determinable non-broadcasting cases.
1318     typename ChloOpTy::Adaptor transformed(operands);
1319     auto lhs_type =
1320         transformed.lhs().getType().template dyn_cast<RankedTensorType>();
1321     auto rhs_type =
1322         transformed.rhs().getType().template dyn_cast<RankedTensorType>();
1323     if (!lhs_type || !rhs_type) return failure();
1324 
1325     // Requires rank broadcast.
1326     if (lhs_type.getRank() != rhs_type.getRank()) return failure();
1327     // Any dynamic dimension may require broadcasting and requires more
1328     // analysis.
1329     if (!lhs_type.hasStaticShape() || !rhs_type.hasStaticShape())
1330       return failure();
1331 
1332     for (auto extents : llvm::zip(lhs_type.getShape(), rhs_type.getShape())) {
1333       auto lhs_extent = std::get<0>(extents);
1334       auto rhs_extent = std::get<1>(extents);
1335       if (lhs_extent != rhs_extent) {
1336         return failure();
1337       }
1338     }
1339 
1340     rewriter.replaceOp(op, {Adaptor::CreateOp(op, op.getResult().getType(),
1341                                               operands, rewriter)});
1342     return success();
1343   }
1344 };
1345 
// Converts a binary op with ranked broadcasting operands to explicitly
// broadcast and invoke the corresponding mhlo non-broadcasting op.
// Note that dynamic broadcasting supported by this pattern is only valid for
// "numpy" broadcasting semantics as defined here:
//   https://docs.scipy.org/doc/numpy/reference/ufuncs.html
// Specifically, this includes the following cases:
//   - Same rank broadcast (operands have the same static rank).
//   - Different-rank broadcast, either without a broadcast_dims attribute or
//     with the broadcast_dims attribute set to map to a prefix padding.
//   - Legal combinations of degenerate (1-dim) implicit broadcasting.
// The restriction on broadcast_dims derives from the definition of the
// `shape.broadcast` op, which only supports prefix-padding.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertRankedDynamicBroadcastBinaryOp
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ChloOpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Only support ranked operands.
    typename ChloOpTy::Adaptor transformed(operands);
    Value lhs = transformed.lhs();
    Value rhs = transformed.rhs();
    auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
    auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
    auto result_type =
        op.getResult().getType().template dyn_cast<RankedTensorType>();
    if (!lhs_type || !rhs_type || !result_type) return failure();

    // Check for "numpy"-style rank broadcast.
    auto broadcast_dimensions = op.broadcast_dimensions();
    if (broadcast_dimensions &&
        !hlo::IsLegalNumpyRankedBroadcast(lhs, rhs, *broadcast_dimensions)) {
      // Note: It is unclear whether the general specification of explicit
      // broadcast_dimensions on binary ops is a feature we want to carry
      // forward. While it can technically be implemented for ranked-dynamic,
      // it is incompatible with unranked inputs. If this warning is emitted
      // in real programs, it is an indication that the feature should be
      // implemented versus just falling back on the more standard definition
      // of numpy-like prefix-padding.
      op.emitWarning() << "unsupported non prefix-padded dynamic rank "
                       << "broadcast_dimensions = " << *broadcast_dimensions;
      return failure();
    }

    // Compute result shape.
    auto loc = op.getLoc();

    // Insert a constraint on the shapes being broadcastable and insert all
    // future code into an assuming block reliant on the constraint.
    Value lhs_shape = rewriter.create<shape::ShapeOfOp>(loc, lhs);
    Value rhs_shape = rewriter.create<shape::ShapeOfOp>(loc, rhs);
    auto broadcastable_cstr =
        rewriter.create<shape::CstrBroadcastableOp>(loc, lhs_shape, rhs_shape);
    auto assuming_op = rewriter.create<shape::AssumingOp>(
        loc, ArrayRef<Type>{result_type}, broadcastable_cstr.result());

    // Redirect the builder into the assuming region's body; the guard restores
    // the previous insertion point when this function returns.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.createBlock(&assuming_op.doRegion());

    // The result rank is the maximum of the operand ranks (prefix padding).
    int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
    Value result_extents =
        hlo::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs,
                                                               rewriter);

    // Note that we unconditionally emit DynamicBroadcastInDim ops and let
    // downstream canonicalizations fold them away if possible. This is
    // because, in the dynamic case, there are many corner cases regarding
    // when it is safe to omit, and some of them require analysis to prove
    // properly.
    auto lhs_broadcast_dimensions = llvm::to_vector<4>(
        llvm::seq<int64_t>(result_rank - lhs_type.getRank(), result_rank));
    Value broadcasted_lhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
        loc,
        RankedTensorType::get(result_type.getShape(),
                              lhs_type.getElementType()),
        lhs, result_extents,
        rewriter.getI64TensorAttr(lhs_broadcast_dimensions));
    auto rhs_broadcast_dimensions = llvm::to_vector<4>(
        llvm::seq<int64_t>(result_rank - rhs_type.getRank(), result_rank));
    Value broadcasted_rhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
        loc,
        RankedTensorType::get(result_type.getShape(),
                              rhs_type.getElementType()),
        rhs, result_extents,
        rewriter.getI64TensorAttr(rhs_broadcast_dimensions));

    // And generate the final non-broadcasted binary op.
    Value final_result = Adaptor::CreateOp(
        op, result_type, {broadcasted_lhs, broadcasted_rhs}, rewriter);
    rewriter.create<shape::AssumingYieldOp>(loc, final_result);
    rewriter.replaceOp(op, {assuming_op.getResult(0)});
    return success();
  }
};
1441 
1442 class ConvertDynamicReshapeOp
1443     : public OpRewritePattern<chlo::DynamicReshapeOp> {
1444  public:
1445   using OpRewritePattern::OpRewritePattern;
1446 
matchAndRewrite(chlo::DynamicReshapeOp op,PatternRewriter & rewriter) const1447   LogicalResult matchAndRewrite(chlo::DynamicReshapeOp op,
1448                                 PatternRewriter &rewriter) const override {
1449     auto loc = op.getLoc();
1450     auto tensor = op.operand();
1451     auto shape = op.output_shape();
1452 
1453     auto shape_ty = shape.getType().cast<ShapedType>();
1454     auto result_ty = op.getType().cast<ShapedType>();
1455 
1456     Value input_shape = rewriter.create<shape::ShapeOfOp>(loc, tensor);
1457     Value num_els = rewriter.create<shape::NumElementsOp>(loc, input_shape);
1458     Value cstr = rewriter.create<mhlo::CstrReshapableOp>(loc, num_els, shape);
1459     rewriter.replaceOpWithNewOp<shape::AssumingOp>(
1460         op, cstr, [&](OpBuilder &b, Location l) {
1461           Value computed_shape = b.create<mhlo::ComputeReshapeShapeOp>(
1462               l, shape_ty, num_els, shape);
1463           SmallVector<Value> result;
1464           result.push_back(b.create<mhlo::DynamicReshapeOp>(
1465               l, result_ty, tensor, computed_shape));
1466           return result;
1467         });
1468 
1469     return success();
1470   }
1471 };
1472 
1473 #include "generated_chlo_legalize_to_hlo.inc"
1474 }  // namespace
1475 
PopulateChloBroadcastingPatterns(MLIRContext * context,OwningRewritePatternList * patterns)1476 void PopulateChloBroadcastingPatterns(MLIRContext *context,
1477                                       OwningRewritePatternList *patterns) {
1478   // Instantiate conversion templates for conforming binary elementwise ops
1479   // that do not have different dtypes between operands and results and do
1480   // not have special attributes that need to be preserved.
1481   PopulateForBroadcastingBinaryOp<ConvertTrivialNonBroadcastBinaryOp>(
1482       context, patterns, 10);
1483   PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
1484       context, patterns, 5);
1485   patterns
1486       ->insert<ConvertConstantLikeOp, ConvertDynamicReshapeOp, ConvertSelectOp>(
1487           context);
1488 }
1489 
PopulateDecomposeChloPatterns(MLIRContext * context,OwningRewritePatternList * patterns)1490 void PopulateDecomposeChloPatterns(MLIRContext *context,
1491                                    OwningRewritePatternList *patterns) {
1492   populateWithGenerated(*patterns);
1493 
1494   // Other patterns.
1495   // clang-format off
1496   patterns->insert<ConvertCoshOp,
1497                    ConvertDigammaOp,
1498                    ConvertErfOp,
1499                    ConvertErfcOp,
1500                    ConvertLgammaOp,
1501                    ConvertNextAfterOp,
1502                    ConvertPolygammaOp,
1503                    ConvertSinhOp,
1504                    ConvertZetaOp>(context);
1505   // clang-format on
1506 }
1507 
1508 }  // namespace chlo
1509 }  // namespace mlir
1510