/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Enable the use of M_* math constants.
// NOTE: this must be first in the file to ensure that if cmath is transitively
// included by any other header it has the define set on first processing.
// https://docs.microsoft.com/en-us/cpp/c-runtime-library/math-constants
#define _USE_MATH_DEFINES
#include <array>
#include <cmath>
#include <limits>
#include <numeric>
#include <vector>

#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir-hlo/utils/broadcast_utils.h"
#include "mlir-hlo/utils/hlo_utils.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace chlo {
namespace {

struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
  using OpConversionPattern<ConstantLikeOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ConstantLikeOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto result_ty = op.getType().cast<ShapedType>();

    // Unranked uses are not supported.
    if (!result_ty.hasRank()) return failure();

    // Lower to MHLO constant if statically shaped.
    if (result_ty.hasStaticShape()) {
      rewriter.replaceOpWithNewOp<mhlo::ConstOp>(
          op, DenseElementsAttr::get(result_ty, op.value()));
      return success();
    }

    // Lower to broadcasted constant.
    ConstantLikeOp::Adaptor transformed(operands);
    auto loc = op.getLoc();
    Type extent_tensor_type = shape::getExtentTensorType(op.getContext());
    Value constant = rewriter.create<mhlo::ConstOp>(loc, op.value());
    Value uncasted_shape = rewriter.create<shape::ShapeOfOp>(
        loc, extent_tensor_type, transformed.operand());
    Type shape_ty =
        RankedTensorType::get({result_ty.getRank()}, rewriter.getIndexType());
    Value shape =
        rewriter.create<tensor::CastOp>(loc, shape_ty, uncasted_shape);
    rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>(
        op, result_ty, constant, shape, rewriter.getI64TensorAttr({}));
    return success();
  }
};

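// Materializes a polynomial in `x` with the given monomial `coefficients`
// (highest degree first) using Horner's scheme: for coefficients {a, b, c}
// this builds ((0 * x + a) * x + b) * x + c, i.e. a * x^2 + b * x + c, at one
// mul and one add per coefficient.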
template <typename FTy>
Value MaterializePolynomialApproximation(
    ConversionPatternRewriter &rewriter, Location loc, Value x,
    const std::vector<FTy> &coefficients) {
  Value poly = chlo::getConstantLike(rewriter, loc, 0.0, x);
  for (FTy c : coefficients) {
    poly = rewriter.create<mhlo::MulOp>(loc, x.getType(), poly, x);
    poly = rewriter.create<mhlo::AddOp>(
        loc, x.getType(), poly, chlo::getConstantLike(rewriter, loc, c, x));
  }
  return poly;
}

// Precondition is |x| >= 1. Otherwise, use the erf approximation.
//
// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an
// argument and derive the final approximation for all |x| >= 1.
// This implementation is based on Cephes.
Value MaterializeErfcApproximationF64ForMagnitudeGEOne(
    ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
         "expect f64 element type");
  const double kMaxlog = 7.09782712893383996843E2;
  const std::vector<double> kErfcPCoefficients{
      2.46196981473530512524E-10, 5.64189564831068821977E-1,
      7.46321056442269912687E0,   4.86371970985681366614E1,
      1.96520832956077098242E2,   5.26445194995477358631E2,
      9.34528527171957607540E2,   1.02755188689515710272E3,
      5.57535335369399327526E2};
  const std::vector<double> kErfcQCoefficients{
      1.00000000000000000000E0, 1.32281951154744992508E1,
      8.67072140885989742329E1, 3.54937778887819891062E2,
      9.75708501743205489753E2, 1.82390916687909736289E3,
      2.24633760818710981792E3, 1.65666309194161350182E3,
      5.57535340817727675546E2};
  const std::vector<double> kErfcRCoefficients{
      5.64189583547755073984E-1, 1.27536670759978104416E0,
      5.01905042251180477414E0,  6.16021097993053585195E0,
      7.40974269950448939160E0,  2.97886665372100240670E0};
  const std::vector<double> kErfcSCoefficients{
      1.00000000000000000000E0, 2.26052863220117276590E0,
      9.39603524938001434673E0, 1.20489539808096656605E1,
      1.70814450747565897222E1, 9.60896809063285878198E0,
      3.36907645100081516050E0};

  // Let z = -x^2.
  Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
  Value z = rewriter.create<mhlo::NegOp>(loc, x_sq);

  // Materialize polynomial approximation for x in [1, 8) as
  //   erfc(x) = exp(z) P(|x|) / Q(|x|).
  Value exp_z = rewriter.create<mhlo::ExpOp>(loc, z);
  Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
  Value poly_p = MaterializePolynomialApproximation(rewriter, loc, abs_x,
                                                    kErfcPCoefficients);
  Value exp_z_mul_poly_p = rewriter.create<mhlo::MulOp>(loc, exp_z, poly_p);
  Value poly_q = MaterializePolynomialApproximation(rewriter, loc, abs_x,
                                                    kErfcQCoefficients);
  Value erfc_approx_1_8 =
      rewriter.create<mhlo::DivOp>(loc, exp_z_mul_poly_p, poly_q);

  // Materialize polynomial approximation for x >= 8 as
  //   erfc(x) = exp(z) R(|x|) / S(|x|).
  Value poly_r = MaterializePolynomialApproximation(rewriter, loc, abs_x,
                                                    kErfcRCoefficients);
  Value exp_z_mul_poly_r = rewriter.create<mhlo::MulOp>(loc, exp_z, poly_r);
  Value poly_s = MaterializePolynomialApproximation(rewriter, loc, abs_x,
                                                    kErfcSCoefficients);
  Value erfc_approx_8_inf =
      rewriter.create<mhlo::DivOp>(loc, exp_z_mul_poly_r, poly_s);

  // Combine polynomial approximations for x >= 1.
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value eight = chlo::getConstantLike(rewriter, loc, 8.0, x);
  Value abs_x_lt_8 = rewriter.create<mhlo::CompareOp>(loc, abs_x, eight, kLT);
  Value erfc_approx = rewriter.create<mhlo::SelectOp>(
      loc, abs_x_lt_8, erfc_approx_1_8, erfc_approx_8_inf);

  // Clamp to prevent overflow and materialize approximation for large x as
  //   erfc(x) = 0.
  Value z_lt_neg_maxlog = rewriter.create<mhlo::CompareOp>(
      loc, z, chlo::getConstantLike(rewriter, loc, -kMaxlog, x), kLT);
  Value zero = chlo::getConstantLike(rewriter, loc, 0.0, x);
  Value erfc_approx_clamped =
      rewriter.create<mhlo::SelectOp>(loc, z_lt_neg_maxlog, zero, erfc_approx);

  // Derive approximation for x <= -1 as
  //   erfc(x) = 2 - erfc(-x).
  // Reuse previously materialized approximations, all of which take |x| as
  // their argument.
  Value x_lt_zero = rewriter.create<mhlo::CompareOp>(loc, x, zero, kLT);
  Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
  Value two_sub_erfc_approx_clamped =
      rewriter.create<mhlo::SubOp>(loc, two, erfc_approx_clamped);
  return rewriter.create<mhlo::SelectOp>(
      loc, x_lt_zero, two_sub_erfc_approx_clamped, erfc_approx_clamped);
}

// Precondition is |x| <= 1. Otherwise, use the erfc approximation.
// This implementation is based on Cephes.
Value MaterializeErfApproximationF64ForMagnitudeLEOne(
    ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
         "expect f64 element type");
  const std::vector<double> kErfTCoefficients{
      9.60497373987051638749E0, 9.00260197203842689217E1,
      2.23200534594684319226E3, 7.00332514112805075473E3,
      5.55923013010394962768E4};
  const std::vector<double> kErfUCoefficients{
      1.00000000000000000000E0, 3.35617141647503099647E1,
      5.21357949780152679795E2, 4.59432382970980127987E3,
      2.26290000613890934246E4, 4.92673942608635921086E4};

  // Materialize polynomial approximation for |x| <= 1 as
  //   erf(x) = x T(x^2) / U(x^2).
  Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
  Value poly_t = MaterializePolynomialApproximation(rewriter, loc, x_sq,
                                                    kErfTCoefficients);
  Value x_mul_poly_t = rewriter.create<mhlo::MulOp>(loc, x, poly_t);
  Value poly_u = MaterializePolynomialApproximation(rewriter, loc, x_sq,
                                                    kErfUCoefficients);
  return rewriter.create<mhlo::DivOp>(loc, x_mul_poly_t, poly_u);
}

// This implementation is based on Cephes.
Value MaterializeErfApproximationF64(ConversionPatternRewriter &rewriter,
                                     Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
         "expect f64 element type");

  // Rely on erf approximation for |x| < 1
  //   erf(x) = erf_approx(x)
  Value erf_approx =
      MaterializeErfApproximationF64ForMagnitudeLEOne(rewriter, loc, x);

  // Rely on erfc approximation for |x| >= 1 and materialize erf as
  //   erf(x) = 1 - erfc_approx(x)
  Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
  Value erfc_approx =
      MaterializeErfcApproximationF64ForMagnitudeGEOne(rewriter, loc, x);
  Value erfc_based_approx = rewriter.create<mhlo::SubOp>(loc, one, erfc_approx);

  // Materialize approximation selection based on argument.
  Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT);
  return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, erf_approx,
                                         erfc_based_approx);
}

Value MaterializeErfcApproximationF64(ConversionPatternRewriter &rewriter,
                                      Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
         "expect f64 element type");

  // Rely on erfc approximation for |x| >= 1
  //   erfc(x) = erfc_approx(x)
  Value erfc_approx =
      MaterializeErfcApproximationF64ForMagnitudeGEOne(rewriter, loc, x);

  // Rely on erf approximation for |x| < 1 and materialize erfc as
  //   erfc(x) = 1 - erf_approx(x)
  Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
  Value erf_approx =
      MaterializeErfApproximationF64ForMagnitudeLEOne(rewriter, loc, x);
  Value erf_based_approx = rewriter.create<mhlo::SubOp>(loc, one, erf_approx);

  // Materialize approximation selection based on argument.
  Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT);
  return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, erf_based_approx,
                                         erfc_approx);
}

// Precondition is |x| >= 1. Otherwise, use the erf approximation.
//
// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an
// argument and derive the final approximation for all |x| >= 1.
// This implementation is based on Cephes.
Value MaterializeErfcApproximationF32ForMagnitudeGEOne(
    ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
         "expect f32 element type");
  const double kMaxlog = 88.72283905206835;
  const std::vector<float> kErfcPCoefficients{
      +2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1,
      -5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1,
      +3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1,
  };
  const std::vector<float> kErfcRCoefficients{
      -1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0,
      +2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1,
      -2.820767439740514E-1, +5.641895067754075E-1,
  };

  // Let z = -x^2.
  Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
  Value z = rewriter.create<mhlo::NegOp>(loc, x_sq);

  // Materialize polynomial approximation for x >= 1 as
  //   erfc(x) = exp(z) 1/x P(1/x^2)   if x in [1, 2)
  //   erfc(x) = exp(z) 1/x R(1/x^2)   if x >= 2
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
  Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
  Value reciprocal_x_sq = rewriter.create<mhlo::DivOp>(loc, one, x_sq);
  Value exp_z = rewriter.create<mhlo::ExpOp>(loc, z);
  Value one_div_abs_x = rewriter.create<mhlo::DivOp>(loc, one, abs_x);
  Value exp_z_mul_one_div_abs_x =
      rewriter.create<mhlo::MulOp>(loc, exp_z, one_div_abs_x);
  Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
  Value abs_x_lt_two = rewriter.create<mhlo::CompareOp>(loc, abs_x, two, kLT);
  Value poly_p = MaterializePolynomialApproximation(
      rewriter, loc, reciprocal_x_sq, kErfcPCoefficients);
  Value poly_r = MaterializePolynomialApproximation(
      rewriter, loc, reciprocal_x_sq, kErfcRCoefficients);
  Value poly =
      rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_two, poly_p, poly_r);
  Value erfc_approx =
      rewriter.create<mhlo::MulOp>(loc, exp_z_mul_one_div_abs_x, poly);

  // Clamp to prevent overflow and materialize approximation for large x as
  //   erfc(x) = 0.
  Value z_lt_neg_maxlog = rewriter.create<mhlo::CompareOp>(
      loc, z, chlo::getConstantLike(rewriter, loc, -kMaxlog, x), kLT);
  Value zero = chlo::getConstantLike(rewriter, loc, 0.0, x);
  Value erfc_approx_clamped =
      rewriter.create<mhlo::SelectOp>(loc, z_lt_neg_maxlog, zero, erfc_approx);

  // Derive approximation for x <= -1 as
  //   erfc(x) = 2 - erfc(-x).
  // Reuse previously materialized approximations, all of which take |x| as
  // their argument.
  Value x_lt_zero = rewriter.create<mhlo::CompareOp>(loc, x, zero, kLT);
  Value two_sub_erfc_approx =
      rewriter.create<mhlo::SubOp>(loc, two, erfc_approx_clamped);
  return rewriter.create<mhlo::SelectOp>(loc, x_lt_zero, two_sub_erfc_approx,
                                         erfc_approx_clamped);
}

// Precondition is |x| <= 1. Otherwise, use the erfc approximation.
// This implementation is based on Cephes.
Value MaterializeErfApproximationF32ForMagnitudeLEOne(
    ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
         "expect f32 element type");
  const std::vector<float> kErfTCoefficients{
      +7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3,
      -2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1,
      +1.128379165726710E+0,
  };

  // Materialize polynomial approximation for |x| <= 1 as
  //   erf(x) = x T(x^2).
  Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
  Value poly_t = MaterializePolynomialApproximation(rewriter, loc, x_sq,
                                                    kErfTCoefficients);
  return rewriter.create<mhlo::MulOp>(loc, x, poly_t);
}

// This is the same approximation as used in Eigen.
Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
                                     Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
         "expect f32 element type");
  const std::vector<float> kAlpha{
      -2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
      -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
      -1.60960333262415e-02f,
  };
  const std::vector<float> kBeta{
      -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
      -7.37332916720468e-03f, -1.42647390514189e-02f,
  };

  // Clamp argument between -4 and 4.
  Value lb = chlo::getConstantLike(rewriter, loc, -4.0, x);
  Value ub = chlo::getConstantLike(rewriter, loc, 4.0, x);
  x = rewriter.create<mhlo::ClampOp>(loc, x.getType(), lb, x, ub);
  Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);

  // Materialize polynomial approximation for x in [-4, 4] as
  //   erf(x) = x * Alpha(x^2) / Beta(x^2).
  Value alpha_poly =
      MaterializePolynomialApproximation(rewriter, loc, x_sq, kAlpha);
  Value beta_poly =
      MaterializePolynomialApproximation(rewriter, loc, x_sq, kBeta);
  Value x_mul_alpha_poly = rewriter.create<mhlo::MulOp>(loc, x, alpha_poly);
  return rewriter.create<mhlo::DivOp>(loc, x_mul_alpha_poly, beta_poly);
}

Value MaterializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
                                      Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
         "expect f32 element type");

  // Rely on erfc approximation for |x| >= 1
  //   erfc(x) = erfc_approx(x)
  Value erfc_approx =
      MaterializeErfcApproximationF32ForMagnitudeGEOne(rewriter, loc, x);

  // Rely on erf approximation for |x| < 1 and materialize erfc as
  //   erfc(x) = 1 - erf_approx(x)
  Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
  Value erf_approx =
      MaterializeErfApproximationF32ForMagnitudeLEOne(rewriter, loc, x);
  Value erf_based_approx = rewriter.create<mhlo::SubOp>(loc, one, erf_approx);

  // Materialize approximation selection based on argument.
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
  Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT);
  return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, erf_based_approx,
                                         erfc_approx);
}

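// Materializes `callback(args)` at a floating-point precision of at least
// `min_precision_ty`: narrower arguments are converted up front and the
// result is converted back to the original element type afterwards.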
Value MaterializeWithUpcast(ConversionPatternRewriter &rewriter, Location loc,
                            ValueRange args, FloatType min_precision_ty,
                            Value callback(ConversionPatternRewriter &,
                                           Location, ValueRange)) {
  auto original_ty =
      getElementTypeOrSelf(args.front().getType()).cast<FloatType>();
  bool needs_upcast = original_ty.getWidth() < min_precision_ty.getWidth();

  // Upcast arguments if necessary.
  llvm::SmallVector<Value, 2> casted_args;
  if (needs_upcast) {
    for (Value a : args) {
      casted_args.push_back(
          rewriter.create<mhlo::ConvertOp>(loc, a, min_precision_ty));
    }
    args = casted_args;
  }

  Value result = callback(rewriter, loc, args);

  // Cast back if necessary.
  if (needs_upcast) {
    result = rewriter.create<mhlo::ConvertOp>(loc, result, original_ty);
  }

  return result;
}

struct ConvertErfOp : public OpConversionPattern<ErfOp> {
  using OpConversionPattern<ErfOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ErfOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    ErfOp::Adaptor transformed(operands);
    Value x = transformed.operand();
    Type ty = x.getType().cast<ShapedType>().getElementType();

    // For now, we support only f64, f32, and f16.
    if (!ty.isF64() && !ty.isF32() && !ty.isF16()) return failure();

    if (ty.isF64()) {
      rewriter.replaceOp(op, MaterializeErfApproximationF64(rewriter, loc, x));
      return success();
    }

    rewriter.replaceOp(op, MaterializeWithUpcast(
                               rewriter, loc, operands, rewriter.getF32Type(),
                               &MaterializeErfApproximationF32));
    return success();
  }
};

struct ConvertErfcOp : public OpConversionPattern<ErfcOp> {
  using OpConversionPattern<ErfcOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ErfcOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    ErfcOp::Adaptor transformed(operands);
    Value x = transformed.operand();
    Type ty = x.getType().cast<ShapedType>().getElementType();

    // For now, we support only f64, f32, and f16.
    if (!ty.isF64() && !ty.isF32() && !ty.isF16()) return failure();

    if (ty.isF64()) {
      rewriter.replaceOp(op,
                         MaterializeErfcApproximationF64(rewriter, loc, x));
      return success();
    }

    rewriter.replaceOp(op, MaterializeWithUpcast(
                               rewriter, loc, operands, rewriter.getF32Type(),
                               &MaterializeErfcApproximationF32));
    return success();
  }
};

// Coefficients for the Lanczos approximation of the gamma function. The
// coefficients are uniquely determined by the choice of g and n (kLanczosGamma
// and kLanczosCoefficients.size() + 1). The coefficients below correspond to
// [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and
// [7, 9] seemed to be the least sensitive to the quality of the log function.
// In particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
// for a particularly inaccurate log function.
constexpr double kLanczosGamma = 7;  // aka g
constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
constexpr std::array<double, 8> kLanczosCoefficients = {
    676.520368121885098567009190444019, -1259.13921672240287047156078755283,
    771.3234287776530788486528258894,   -176.61502916214059906584551354,
    12.507343278686904814458936853,     -0.13857109526572011689554707,
    9.984369578019570859563e-6,         1.50563273514931155834e-7};

// Compute the Lgamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
//   lgamma(z + 1) = (log(2) + log(pi)) / 2
//                 + (z + 1/2) * log(t(z))
//                 - t(z) + log(a(z))
// with t(z) = z + kLanczosGamma + 1/2
//      a(z) = kBaseLanczosCoeff
//           + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
                        ValueRange args) {
  // If the input is less than 0.5 use Euler's reflection formula.
  //   gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
  // Let z be
  //   z = -x      if x < 1/2
  //   z = x - 1   otherwise
  Value x = args.front();
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value half = getConstantLike(rewriter, loc, 0.5, x);
  Value need_to_reflect = rewriter.create<mhlo::CompareOp>(loc, x, half, kLT);
  Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
  Value one = getConstantLike(rewriter, loc, 1, x);
  Value x_sub_one = rewriter.create<mhlo::SubOp>(loc, x, one);
  Value z =
      rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, neg_x, x_sub_one);

  // Materialize
  //   a(z) = kBaseLanczosCoeff
  //        + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
  Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
  for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
    Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x);
    Value one_based_index = getConstantLike(rewriter, loc, i + 1, x);
    Value quotient = rewriter.create<mhlo::DivOp>(
        loc, coeff, rewriter.create<mhlo::AddOp>(loc, z, one_based_index));
    a = rewriter.create<mhlo::AddOp>(loc, a, quotient);
  }

  // To improve accuracy on platforms with less-precise log implementations,
  // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
  // device.
  // Materialize as
  //   log(t) = log(kLanczosGamma + 1/2 + z)
  //          = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
  Value lanczos_plus_half =
      getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
  Value t = rewriter.create<mhlo::AddOp>(loc, lanczos_plus_half, z);
  Value log_term =
      getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
  Value log1p_term = rewriter.create<mhlo::Log1pOp>(
      loc, rewriter.create<mhlo::DivOp>(loc, z, lanczos_plus_half));
  Value log_t = rewriter.create<mhlo::AddOp>(loc, log_term, log1p_term);

  // Note that t(z) may be large and we need to be careful not to overflow to
  // infinity in the relevant term
  //   r = (z + 1/2) * log(t(z)) - t(z).
  // Therefore, we compute this as
  //   r = (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
  Value t_div_log_t = rewriter.create<mhlo::DivOp>(loc, t, log_t);
  Value sum = rewriter.create<mhlo::SubOp>(
      loc, rewriter.create<mhlo::AddOp>(loc, z, half), t_div_log_t);
  Value r = rewriter.create<mhlo::MulOp>(loc, sum, log_t);

  // Compute the final result (modulo reflection) as
  //   lgamma(z + 1) = (log(2) + log(pi)) / 2 + r + log(a(z)).
  Value log_a = rewriter.create<mhlo::LogOp>(loc, a);
  Value lgamma = rewriter.create<mhlo::AddOp>(
      loc,
      rewriter.create<mhlo::AddOp>(
          loc,
          getConstantLike(rewriter, loc, (std::log(2) + std::log(M_PI)) / 2,
                          x),
          r),
      log_a);

  // Compute the reflected value for x < 0.5 as
  //   lgamma(x) = log(pi) - lgamma(1 - x) - log(abs(sin(pi * x))).
  //
  // The abs is needed because lgamma is the log of the absolute value of the
  // gamma function.
  //
  // We have to be careful when computing the final term above. gamma(x) goes
  // to +/-inf at every integer x < 0, and this is controlled by the
  // sin(pi * x) term. The slope is large, so precision is particularly
  // important.
  //
  // Because abs(sin(pi * x)) has a period of 1, we can equivalently use
  // abs(sin(pi * frac(x))) where frac(x) is the fractional part of x. This is
  // more numerically accurate: It doesn't overflow to inf like pi * x would,
  // and if x is an integer it evaluates to exactly 0, which is important
  // because we then take the log of this value, and log(0) is -inf.
  //
  // We don't have a frac(x) primitive in HLO and computing it is tricky, but
  // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for our
  // purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
  //
  // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
  // to 1. To remedy this, we can use the fact that sin(pi * x) in the domain
  // [0, 1] is symmetric across the line Y = 0.5.

  // Convert values of abs_frac > 0.5 to (1 - abs_frac) to improve precision of
  // pi * abs_frac for values of abs_frac close to 1.
  Value abs = rewriter.create<mhlo::AbsOp>(loc, x);
  Value abs_frac = rewriter.create<mhlo::SubOp>(
      loc, abs, rewriter.create<mhlo::FloorOp>(loc, abs));
  Value reduce_abs_frac =
      rewriter.create<mhlo::CompareOp>(loc, half, abs_frac, kLT);
  abs_frac = rewriter.create<mhlo::SelectOp>(
      loc, reduce_abs_frac, rewriter.create<mhlo::SubOp>(loc, one, abs_frac),
      abs_frac);

  // Materialize reflection.
  Value reflection_denom = rewriter.create<mhlo::LogOp>(
      loc,
      rewriter.create<mhlo::SinOp>(
          loc, rewriter.create<mhlo::MulOp>(
                   loc, getConstantLike(rewriter, loc, M_PI, x), abs_frac)));
  Value lgamma_reflection = rewriter.create<mhlo::SubOp>(
      loc,
      rewriter.create<mhlo::SubOp>(
          loc, getConstantLike(rewriter, loc, std::log(M_PI), x),
          reflection_denom),
      lgamma);

  // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf,
  // then it "wins" and the result is +/-inf.
  Value finite_reflection_denom =
      rewriter.create<mhlo::IsFiniteOp>(loc, reflection_denom);
  Value neg_reflection_denom =
      rewriter.create<mhlo::NegOp>(loc, reflection_denom);
  lgamma_reflection = rewriter.create<mhlo::SelectOp>(
      loc, finite_reflection_denom, lgamma_reflection, neg_reflection_denom);

  // Select whether or not to rely on the reflection.
  lgamma = rewriter.create<mhlo::SelectOp>(loc, need_to_reflect,
                                           lgamma_reflection, lgamma);

  // Materialize +/-inf behavior as
  //   lgamma(+/-inf) = +inf.
  Value x_is_inf = rewriter.create<chlo::IsInfOp>(loc, x);
  return rewriter.create<mhlo::SelectOp>(
      loc, x_is_inf,
      chlo::getConstantLikeInfValue(rewriter, loc, x, /*negative=*/false),
      lgamma);
}

// Express `cosh` as
//   cosh(x) = (e^x + e^-x) / 2
//           = e^(x + log(1/2)) + e^(-x + log(1/2))
//
// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not.
//
// This incorrectly overflows to inf for two f32 input values, namely
// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
// we deem this acceptable.
Value MaterializeCoshApproximation(ConversionPatternRewriter &rewriter,
                                   Location loc, ValueRange operands) {
  CoshOp::Adaptor transformed(operands);
  Value x = transformed.operand();

  Value log_one_half = rewriter.create<mhlo::LogOp>(
      loc, getConstantLike(rewriter, loc, 0.5, x));
  Value exp_add = rewriter.create<mhlo::ExpOp>(
      loc, rewriter.create<mhlo::AddOp>(loc, x, log_one_half));
  Value exp_sub = rewriter.create<mhlo::ExpOp>(
      loc, rewriter.create<mhlo::SubOp>(loc, log_one_half, x));
  return rewriter.create<mhlo::AddOp>(loc, exp_add, exp_sub);
}

struct ConvertCoshOp : public OpConversionPattern<CoshOp> {
  using OpConversionPattern<CoshOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      CoshOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    CoshOp::Adaptor transformed(operands);
    Value x = transformed.operand();
    if (x.getType().cast<ShapedType>().getElementType().isa<ComplexType>()) {
      // TODO(hinsu): Support operands with complex element types by always
      // using the formula for large x. The compare op is not legal for complex
      // numbers.
      return failure();
    }
    rewriter.replaceOp(op,
                       MaterializeWithUpcast(rewriter, op.getLoc(), operands,
                                             rewriter.getF32Type(),
                                             &MaterializeCoshApproximation));
    return success();
  }
};

// Compute the Digamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
//   digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z)
// with t(z) = z + kLanczosGamma + 1/2
//      a(z) = kBaseLanczosCoeff
//           + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
//      a'(z) = -sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k) / (z + k))
Value MaterializeDigamma(ConversionPatternRewriter &rewriter, Location loc,
                         ValueRange args) {
  // If the input is less than 0.5 use Euler's reflection formula.
  //   digamma(x) = digamma(1 - x) - pi * cot(pi * x)
  // Let z be
  //   z = -x      if x < 1/2
  //   z = x - 1   otherwise
  Value x = args.front();
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value half = getConstantLike(rewriter, loc, 0.5, x);
  Value need_to_reflect = rewriter.create<mhlo::CompareOp>(loc, x, half, kLT);
  Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
  Value one = getConstantLike(rewriter, loc, 1, x);
  Value x_sub_one = rewriter.create<mhlo::SubOp>(loc, x, one);
  Value z =
      rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, neg_x, x_sub_one);

  // Materialize
  //   a(z) = kBaseLanczosCoeff
  //        + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
  //   a'(z) = -sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k) / (z + k))
  Value zero = getConstantLike(rewriter, loc, 0.0, x);
  Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
  Value a_prime = zero;
  for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
    Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x);
    Value one_based_index = getConstantLike(rewriter, loc, i + 1, x);
    Value z_term = rewriter.create<mhlo::AddOp>(loc, z, one_based_index);
    a_prime = rewriter.create<mhlo::SubOp>(
        loc, a_prime,
        rewriter.create<mhlo::DivOp>(
            loc, coeff, rewriter.create<mhlo::MulOp>(loc, z_term, z_term)));
    a = rewriter.create<mhlo::AddOp>(
        loc, a, rewriter.create<mhlo::DivOp>(loc, coeff, z_term));
  }

  // To improve accuracy on platforms with less-precise log implementations,
  // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
  // device.
  // Materialize as
  //   log(t) = log(kLanczosGamma + 1/2 + z)
  //          = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
  Value lanczos_plus_half =
      getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
  Value t = rewriter.create<mhlo::AddOp>(loc, lanczos_plus_half, z);
  Value log_term =
      getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
  Value log1p_term = rewriter.create<mhlo::Log1pOp>(
      loc, rewriter.create<mhlo::DivOp>(loc, z, lanczos_plus_half));
  Value log_t = rewriter.create<mhlo::AddOp>(loc, log_term, log1p_term);

  // Materialize the final result (modulo reflection) as
  //   digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z).
  Value a_prime_div_a = rewriter.create<mhlo::DivOp>(loc, a_prime, a);
  Value lanczos_gamma_div_t = rewriter.create<mhlo::DivOp>(
      loc, getConstantLike(rewriter, loc, kLanczosGamma, x), t);
  Value digamma = rewriter.create<mhlo::SubOp>(
      loc, rewriter.create<mhlo::AddOp>(loc, log_t, a_prime_div_a),
      lanczos_gamma_div_t);

  // We need to be careful how we compute cot(pi * input) below: For
  // near-integral arguments, pi * input can lose precision.
  //
  // Input is already known to be less than 0.5 (otherwise we don't have to
  // reflect). We shift values smaller than -0.5 into the range [-0.5, 0.5] to
  // increase precision of pi * x and the resulting cotangent.
  Value reduced_x = rewriter.create<mhlo::AddOp>(
      loc, x,
      rewriter.create<mhlo::AbsOp>(
          loc, rewriter.create<mhlo::FloorOp>(
                   loc, rewriter.create<mhlo::AddOp>(
                            loc, x,
                            getConstantLike(rewriter, loc, 0.5, x)))));

  // Materialize reflection for inputs less than 0.5 as
  //   digamma(x) = digamma(1 - x) - pi * cot(pi * x)
  //              = digamma(1 - x) - pi * cos(pi * x) / sin(pi * x)
  Value pi = getConstantLike(rewriter, loc, M_PI, x);
  Value pi_mul_reduced_x = rewriter.create<mhlo::MulOp>(loc, pi, reduced_x);
  Value cos = rewriter.create<mhlo::CosOp>(loc, pi_mul_reduced_x);
  Value sin = rewriter.create<mhlo::SinOp>(loc, pi_mul_reduced_x);
  Value reflection = rewriter.create<mhlo::SubOp>(
      loc, digamma,
      rewriter.create<mhlo::DivOp>(
          loc, rewriter.create<mhlo::MulOp>(loc, pi, cos), sin));

  // Select whether or not to rely on the reflection.
  digamma = rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, reflection,
                                            digamma);

  // Digamma has poles at negative integers and zero; return NaN for those.
  const StringAttr kLE = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LE));
  Value is_le_zero = rewriter.create<mhlo::CompareOp>(loc, x, zero, kLE);
  const StringAttr kEQ = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
  Value is_int = rewriter.create<mhlo::CompareOp>(
      loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kEQ);
  Value is_pole = rewriter.create<mhlo::AndOp>(loc, is_le_zero, is_int);
  return rewriter.create<mhlo::SelectOp>(
      loc, is_pole,
      getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(),
                      x),
      digamma);
}

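// Materializes an approximation of the Hurwitz zeta function
//   zeta(x, q) = sum(k = 0, inf, (k + q)^(-x)).
// A truncated series estimate is refined with an Euler-Maclaurin correction
// term; see the step-by-step comments below.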
Value MaterializeZeta(ConversionPatternRewriter &rewriter, Location loc,
                      ValueRange args) {
  assert(args.size() == 2);
  Value x = args[0];
  Value q = args[1];
  static const std::array<double, 12> kZetaCoeffs{
      -7.1661652561756670113e18,
      1.8152105401943546773e17,
      -4.5979787224074726105e15,
      1.1646782814350067249e14,
      -2.950130727918164224e12,
      7.47242496e10,
      -1.8924375803183791606e9,
      47900160.0,
      -1209600.0,
      30240.0,
      -720.0,
      12.0,
  };

  // For speed we'll always use 9 iterations for the initial series estimate,
  // and a 12-term expansion for the Euler-Maclaurin formula.
  Value a = q;
  Value zero = chlo::getConstantLike(rewriter, loc, 0.0, a);
  Value neg_power = zero;
  Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
  Value initial_sum = rewriter.create<mhlo::PowOp>(loc, q, neg_x);
  Value one = chlo::getConstantLike(rewriter, loc, 1.0, a);
  for (int i = 0; i < 9; ++i) {
    a = rewriter.create<mhlo::AddOp>(loc, a, one);
    neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x);
    initial_sum = rewriter.create<mhlo::AddOp>(loc, initial_sum, neg_power);
  }
  a = rewriter.create<mhlo::AddOp>(loc, a, one);
  neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x);
  Value one_like_x = chlo::getConstantLike(rewriter, loc, 1.0, x);
  Value x_minus_one = rewriter.create<mhlo::SubOp>(loc, x, one_like_x);
  Value neg_power_mul_a = rewriter.create<mhlo::MulOp>(loc, neg_power, a);
  Value neg_power_mul_a_div_x_minus_one =
      rewriter.create<mhlo::DivOp>(loc, neg_power_mul_a, x_minus_one);
  Value s = rewriter.create<mhlo::AddOp>(loc, initial_sum,
                                         neg_power_mul_a_div_x_minus_one);
  Value a_inverse_square = rewriter.create<mhlo::DivOp>(
      loc, one, rewriter.create<mhlo::MulOp>(loc, a, a));

  Value horner_sum = zero;
  Value factor = one;
  // Use Horner's rule for this.
  // Note this differs from Cephes, which does a 'naive' polynomial evaluation.
  // Using Horner's rule avoids some NaNs and Infs, resulting in more
  // numerically stable code.
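  // Each iteration i updates
  //   horner_sum = (horner_sum + 1 / kZetaCoeffs[i])
  //                * (x - (22 - 2i)) * (x - (21 - 2i)) / a^2,
  // i.e. it builds the nested (Horner) form of the correction series.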
  for (int i = 0; i < 11; ++i) {
    Value factor_lhs = rewriter.create<mhlo::SubOp>(
        loc, x, chlo::getConstantLike(rewriter, loc, 22 - 2 * i, x));
    Value factor_rhs = rewriter.create<mhlo::SubOp>(
        loc, x, chlo::getConstantLike(rewriter, loc, 21 - 2 * i, x));
    factor = rewriter.create<mhlo::MulOp>(loc, factor_lhs, factor_rhs);
    horner_sum = rewriter.create<mhlo::MulOp>(
        loc, factor,
        rewriter.create<mhlo::MulOp>(
            loc, a_inverse_square,
            rewriter.create<mhlo::AddOp>(
                loc, horner_sum,
                chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i],
                                      a))));
  }
  Value zero_point_five_like_neg_power =
      chlo::getConstantLike(rewriter, loc, .5, neg_power);
  Value x_div_a = rewriter.create<mhlo::DivOp>(loc, x, a);
  s = rewriter.create<mhlo::AddOp>(
      loc, s,
      rewriter.create<mhlo::MulOp>(
          loc, neg_power,
          rewriter.create<mhlo::AddOp>(
              loc, zero_point_five_like_neg_power,
              rewriter.create<mhlo::MulOp>(
                  loc, x_div_a,
                  rewriter.create<mhlo::AddOp>(
                      loc,
                      chlo::getConstantLike(rewriter, loc,
                                            1. / kZetaCoeffs[11], a),
                      horner_sum)))));

  // Use the initial zeta sum without the correction term coming
  // from Euler-Maclaurin if it is accurate enough.
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value abs_neg_power = rewriter.create<mhlo::AbsOp>(loc, neg_power);
  Value abs_initial_sum = rewriter.create<mhlo::AbsOp>(loc, initial_sum);
  Value output = rewriter.create<mhlo::SelectOp>(
      loc,
      rewriter.create<mhlo::CompareOp>(
          loc, abs_neg_power,
          rewriter.create<mhlo::MulOp>(
              loc, abs_initial_sum,
              chlo::getConstantLikeSmallestFiniteValue(rewriter, loc, a)),
          kLT),
      initial_sum, s);

  // The function is not defined for x < 1.
  Value nan = chlo::getConstantLike(
      rewriter, loc, std::numeric_limits<double>::quiet_NaN(), x);
  output = rewriter.create<mhlo::SelectOp>(
      loc, rewriter.create<mhlo::CompareOp>(loc, x, one_like_x, kLT), nan,
      output);

  // For q <= 0, x must be an integer.
  const StringAttr kLE = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LE));
  const StringAttr kNE = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE));
  Value q_le_zero = rewriter.create<mhlo::CompareOp>(loc, q, zero, kLE);
  Value x_not_int = rewriter.create<mhlo::CompareOp>(
      loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kNE);
  Value x_domain_error =
      rewriter.create<mhlo::AndOp>(loc, q_le_zero, x_not_int);
  output = rewriter.create<mhlo::SelectOp>(loc, x_domain_error, nan, output);

  // For all integer q <= 0, zeta has a pole. The limit is only defined as
  // +inf if x is an even integer.
  const StringAttr kEQ = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
  Value inf = chlo::getConstantLike(
      rewriter, loc, std::numeric_limits<double>::infinity(), x);
  Value q_is_int = rewriter.create<mhlo::CompareOp>(
      loc, q, rewriter.create<mhlo::FloorOp>(loc, q), kEQ);
  Value at_pole = rewriter.create<mhlo::AndOp>(loc, q_le_zero, q_is_int);
  Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
  Value x_is_int = rewriter.create<mhlo::CompareOp>(
      loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kEQ);
  Value x_is_even = rewriter.create<mhlo::CompareOp>(
      loc, rewriter.create<mhlo::RemOp>(loc, x, two), zero, kEQ);
  Value x_is_even_int = rewriter.create<mhlo::AndOp>(loc, x_is_int, x_is_even);
  output = rewriter.create<mhlo::SelectOp>(
      loc, at_pole,
      rewriter.create<mhlo::SelectOp>(loc, x_is_even_int, inf, nan), output);

  // For x = 1, this is the harmonic series and diverges.
  output = rewriter.create<mhlo::SelectOp>(
      loc, rewriter.create<mhlo::CompareOp>(loc, x, one, kEQ), inf, output);

  return output;
}

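// Materializes polygamma(n, x) via the identity
//   polygamma(n, x) = (-1)^(n + 1) * n! * zeta(n + 1, x),
// where the sign is computed as 2 * rem(n, 2) - 1 and n! as
// exp(lgamma(n + 1)). n = 0 falls back to digamma, and non-natural n yields
// NaN.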
Value MaterializePolygamma(ConversionPatternRewriter &rewriter, Location loc,
                           ValueRange args) {
  PolygammaOp::Adaptor transformed(args);
  Value n = transformed.n();
  Value x = transformed.x();

  // Handle integer n > 0.
  Value one = getConstantLike(rewriter, loc, 1.0, x);
  Value two = getConstantLike(rewriter, loc, 2.0, x);
  Value sign = rewriter.create<mhlo::SubOp>(
      loc,
      rewriter.create<mhlo::MulOp>(loc, two,
                                   rewriter.create<mhlo::RemOp>(loc, n, two)),
      one);
  Value n_plus_one = rewriter.create<mhlo::AddOp>(loc, n, one);
  Value exp_lgamma_np1 = rewriter.create<mhlo::ExpOp>(
      loc, rewriter.create<chlo::LgammaOp>(loc, n_plus_one));
  Value zeta = rewriter.create<chlo::ZetaOp>(loc, n_plus_one, x);
  Value result = rewriter.create<mhlo::MulOp>(
      loc, rewriter.create<mhlo::MulOp>(loc, sign, exp_lgamma_np1), zeta);

  // Handle n = 0.
  const StringAttr kEQ = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
  Value zero = getConstantLike(rewriter, loc, 0.0, x);
  Value n_eq_zero = rewriter.create<mhlo::CompareOp>(loc, n, zero, kEQ);
  result = rewriter.create<mhlo::SelectOp>(
      loc, n_eq_zero, rewriter.create<chlo::DigammaOp>(loc, x), result);

  // Check that n is a natural number; return NaN otherwise.
  const StringAttr kNE = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE));
  Value non_int = rewriter.create<mhlo::CompareOp>(
      loc, n, rewriter.create<mhlo::FloorOp>(loc, n), kNE);
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value negative = rewriter.create<mhlo::CompareOp>(loc, n, zero, kLT);
  Value non_natural = rewriter.create<mhlo::OrOp>(loc, non_int, negative);
  return rewriter.create<mhlo::SelectOp>(
      loc, non_natural,
      getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(),
                      x),
      result);
}

struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
  using OpConversionPattern<LgammaOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      LgammaOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    FloatType min_precision_ty = rewriter.getF32Type();
    rewriter.replaceOp(
        op, MaterializeWithUpcast(rewriter, op.getLoc(), operands,
                                  min_precision_ty, &MaterializeLgamma));
    return success();
  }
};

struct ConvertDigammaOp : public OpConversionPattern<DigammaOp> {
  using OpConversionPattern<DigammaOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      DigammaOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    FloatType min_precision_ty = rewriter.getF32Type();
    rewriter.replaceOp(
        op, MaterializeWithUpcast(rewriter, op.getLoc(), operands,
                                  min_precision_ty, &MaterializeDigamma));
    return success();
  }
};

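// Materializes the next representable value of `x` in the direction of `y`
// (std::nextafter semantics) by adjusting the integer bit pattern of `x` by
// one unit in the last place, with NaNs, equal inputs, and signed zeros
// handled as special cases below.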
Value MaterializeNextAfter(ConversionPatternRewriter &rewriter, Location loc,
                           ValueRange operands) {
  NextAfterOp::Adaptor transformed(operands);
  Value x = transformed.x();
  Value y = transformed.y();
  auto result_ty = x.getType().cast<ShapedType>();
  auto bitwidth = result_ty.getElementType().getIntOrFloatBitWidth();
  ImplicitLocOpBuilder b(loc, rewriter);
  auto int_ty = result_ty.clone(b.getIntegerType(bitwidth));
  auto x_as_int = b.create<mhlo::BitcastConvertOp>(int_ty, x);
  auto y_as_int = b.create<mhlo::BitcastConvertOp>(int_ty, y);

  // The result is NaN if either "x" or "y" is NaN.
  const StringAttr kNE = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE));
  auto x_is_nan = b.create<mhlo::CompareOp>(x, x, kNE);
  auto y_is_nan = b.create<mhlo::CompareOp>(y, y, kNE);
  auto nan_input = b.create<mhlo::OrOp>(x_is_nan, y_is_nan);
  auto result_for_nan = getConstantLike(
      rewriter, loc, std::numeric_limits<double>::quiet_NaN(), x);
  auto result_for_nan_as_int =
      b.create<mhlo::BitcastConvertOp>(int_ty, result_for_nan);

  // The sign bit is the MSB.
  const int64_t sign_bit = int64_t{1} << (bitwidth - 1);
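  // For example, for f32 inputs bitwidth is 32 and sign_bit is 0x80000000.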
  // Discard the sign bit to make the result non-negative.
  auto sign_mask = getConstantLike(rewriter, loc, sign_bit, x_as_int);
  auto negated_sign_mask = getConstantLike(rewriter, loc, ~sign_bit, x_as_int);
  auto x_abs = b.create<mhlo::AndOp>(x_as_int, negated_sign_mask);
  auto y_abs = b.create<mhlo::AndOp>(y_as_int, negated_sign_mask);

  // When both "x" and "y" are equal, the result is "y".
  const StringAttr kEQ = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
  auto x_and_y_are_equal = b.create<mhlo::CompareOp>(x, y, kEQ);
  auto result_for_equal = y_as_int;

  // When both "x" and "y" are 0, the result is "y". This is a separate case
  // from above because "x" and "y" might have different signs.
  auto zero = getConstantLike(rewriter, loc, 0, x_as_int);
  auto x_is_zero = b.create<mhlo::CompareOp>(x_abs, zero, kEQ);
  auto y_is_zero = b.create<mhlo::CompareOp>(y_abs, zero, kEQ);
  auto result_for_both_zero = y_as_int;

  auto x_sign = b.create<mhlo::AndOp>(x_as_int, sign_mask);
  auto y_sign = b.create<mhlo::AndOp>(y_as_int, sign_mask);

  // If from == 0 && to != 0, we need to return the smallest subnormal number
  // signed like "to".
  auto one = getConstantLike(rewriter, loc, 1, x_as_int);
  auto result_for_x_zero_y_non_zero = b.create<mhlo::OrOp>(y_sign, one);

  // If the sign of "x" and "y" disagree:
  // - we need to make the magnitude of "from" smaller so that it is closer to
  //   zero.
  //
  // Otherwise the signs agree:
  // - "x" with a magnitude larger than "y" means we need to make the magnitude
  //   smaller.
  // - "x" with a magnitude smaller than "y" means we need to make the
  //   magnitude larger.
  auto signs_disagree = b.create<mhlo::CompareOp>(x_sign, y_sign, kNE);
  const StringAttr kGT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::GT));
  auto x_magnitude_larger_than_y =
      b.create<mhlo::CompareOp>(x_abs, y_abs, kGT);
  auto result_has_smaller_magnitude =
      b.create<mhlo::OrOp>(x_magnitude_larger_than_y, signs_disagree);
  auto minus_one = getConstantLike(rewriter, loc, -1, x_as_int);
  auto magnitude_adjustment =
      b.create<mhlo::SelectOp>(result_has_smaller_magnitude, minus_one, one);
  Value result = b.create<mhlo::AddOp>(x_as_int, magnitude_adjustment);
  // Handle from == +-0.
  result = b.create<mhlo::SelectOp>(
      x_is_zero,
      b.create<mhlo::SelectOp>(y_is_zero, result_for_both_zero,
                               result_for_x_zero_y_non_zero),
      result);
  // Handle from == to.
  result =
      b.create<mhlo::SelectOp>(x_and_y_are_equal, result_for_equal, result);
  // Handle isnan(x) || isnan(y).
  result = b.create<mhlo::SelectOp>(nan_input, result_for_nan_as_int, result);

  // Cast back to the original type.
  return b.create<mhlo::BitcastConvertOp>(result_ty, result);
}

struct ConvertNextAfterOp : public OpConversionPattern<NextAfterOp> {
  using OpConversionPattern<NextAfterOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      NextAfterOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op,
                       MaterializeNextAfter(rewriter, op.getLoc(), operands));
    return success();
  }
};

struct ConvertPolygammaOp : public OpConversionPattern<PolygammaOp> {
  using OpConversionPattern<PolygammaOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      PolygammaOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    FloatType min_precision_ty = rewriter.getF32Type();
    rewriter.replaceOp(
        op, MaterializeWithUpcast(rewriter, loc, operands, min_precision_ty,
                                  &MaterializePolygamma));
    return success();
  }
};

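// Materializes sinh for large |x| as
//   sinh(x) = e^(x + log(1/2)) - e^(-x + log(1/2)),
// which avoids overflowing in e^x when (e^x)/2 would still be representable.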
Value MaterializeSinhApproximationForLargeX(
    ConversionPatternRewriter &rewriter, Location loc, ValueRange operands) {
  SinhOp::Adaptor transformed(operands);
  Value x = transformed.operand();
  auto result_ty = x.getType().cast<ShapedType>();

  // TODO(b/190374484): Use mhlo::ConstantLikeOp when it supports complex
  // types.
  Value two = rewriter.create<mhlo::ConstOp>(
      loc, hlo::GetScalarOfType(getElementTypeOrSelf(x.getType()), 2));
  Type extent_tensor_type = shape::getExtentTensorType(x.getContext());
  Value uncasted_shape =
      rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, x);
  Type shape_ty =
      RankedTensorType::get({result_ty.getRank()}, rewriter.getIndexType());
  Value shape = rewriter.create<tensor::CastOp>(loc, shape_ty, uncasted_shape);
  Value two_with_x_shape = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
      loc, result_ty, two, shape, rewriter.getI64TensorAttr({}));

  Value log_two = rewriter.create<mhlo::LogOp>(loc, two_with_x_shape);
  Value log_one_half = rewriter.create<mhlo::NegOp>(loc, log_two);
  Value exp_add = rewriter.create<mhlo::ExpOp>(
      loc, rewriter.create<mhlo::AddOp>(loc, x, log_one_half));
  Value exp_sub = rewriter.create<mhlo::ExpOp>(
      loc, rewriter.create<mhlo::SubOp>(loc, log_one_half, x));
  return rewriter.create<mhlo::SubOp>(loc, exp_add, exp_sub);
}


// Express `sinh` as
//   sinh(x) = (e^x - e^-x) / 2                        if |x| < 1
//           = e^(x + log(1/2)) - e^(-x + log(1/2))    otherwise.
Value MaterializeSinhApproximation(ConversionPatternRewriter &rewriter,
                                   Location loc, ValueRange operands) {
  Value large_sinh_result =
      MaterializeSinhApproximationForLargeX(rewriter, loc, operands);

  SinhOp::Adaptor transformed(operands);
  Value x = transformed.operand();
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value exp_x = rewriter.create<mhlo::ExpOp>(loc, x);
  Value exp_neg_x =
      rewriter.create<mhlo::ExpOp>(loc, rewriter.create<mhlo::NegOp>(loc, x));
  Value exp_difference = rewriter.create<mhlo::SubOp>(loc, exp_x, exp_neg_x);
  Value two = getConstantLike(rewriter, loc, 2.0, x);
  Value small_sinh_result =
      rewriter.create<mhlo::DivOp>(loc, exp_difference, two);

  Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
  Value one = getConstantLike(rewriter, loc, 1.0, x);
  Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT);
  return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, small_sinh_result,
                                         large_sinh_result);
}

1189
1190 struct ConvertSinhOp : public OpConversionPattern<SinhOp> {
1191 using OpConversionPattern<SinhOp>::OpConversionPattern;
matchAndRewritemlir::chlo::__anon8f4f31520111::ConvertSinhOp1192 LogicalResult matchAndRewrite(
1193 SinhOp op, ArrayRef<Value> operands,
1194 ConversionPatternRewriter &rewriter) const override {
1195 SinhOp::Adaptor transformed(operands);
1196 Value x = transformed.operand();
1197 if (x.getType().cast<ShapedType>().getElementType().isa<ComplexType>()) {
1198 rewriter.replaceOp(op, MaterializeSinhApproximationForLargeX(
1199 rewriter, op.getLoc(), operands));
1200 return success();
1201 }
1202 rewriter.replaceOp(op,
1203 MaterializeWithUpcast(rewriter, op.getLoc(), operands,
1204 rewriter.getF32Type(),
1205 &MaterializeSinhApproximation));
1206 return success();
1207 }
1208 };
1209
struct ConvertZetaOp : public OpConversionPattern<ZetaOp> {
  using OpConversionPattern<ZetaOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ZetaOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    FloatType min_precision_ty = rewriter.getF32Type();
    rewriter.replaceOp(
        op, MaterializeWithUpcast(rewriter, loc, operands, min_precision_ty,
                                  &MaterializeZeta));
    return success();
  }
};

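// Lowers chlo.broadcast_select by broadcasting all ranked operands to the
// common result extents. Schematically (shapes illustrative):
//   %r = "chlo.broadcast_select"(%pred, %on_true, %on_false)
// becomes a shape.assuming region that broadcasts %pred, %on_true, and
// %on_false with mhlo.dynamic_broadcast_in_dim and applies mhlo.select.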
struct ConvertSelectOp : public OpConversionPattern<BroadcastSelectOp> {
  using OpConversionPattern<BroadcastSelectOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      BroadcastSelectOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Only support ranked operands.
    typename BroadcastSelectOp::Adaptor transformed(operands);
    Value pred = transformed.pred();
    Value on_true = transformed.on_true();
    Value on_false = transformed.on_false();
    auto pred_type = pred.getType().dyn_cast<RankedTensorType>();
    auto on_true_type = on_true.getType().dyn_cast<RankedTensorType>();
    auto on_false_type = on_false.getType().dyn_cast<RankedTensorType>();
    auto result_type = op.getResult().getType().dyn_cast<RankedTensorType>();
    if (!pred_type || !on_true_type || !on_false_type || !result_type) {
      return failure();
    }

    auto loc = op.getLoc();

    Value pred_shape = rewriter.createOrFold<shape::ShapeOfOp>(loc, pred);
    Value on_true_shape = rewriter.createOrFold<shape::ShapeOfOp>(loc, on_true);
    Value on_false_shape =
        rewriter.createOrFold<shape::ShapeOfOp>(loc, on_false);
    int64_t result_rank = std::max(
        {pred_type.getRank(), on_true_type.getRank(), on_false_type.getRank()});

    Value broadcastable_cstr =
        rewriter.createOrFold<shape::CstrBroadcastableOp>(
            loc, ValueRange{pred_shape, on_true_shape, on_false_shape});
    auto assuming_op = rewriter.create<shape::AssumingOp>(
        loc, ArrayRef<Type>{result_type}, broadcastable_cstr);

    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.createBlock(&assuming_op.doRegion());

    Value result_extents = rewriter.createOrFold<shape::BroadcastOp>(
        loc, shape::getExtentTensorType(op.getContext()),
        ValueRange{pred_shape, on_true_shape, on_false_shape},
        /*error=*/nullptr);
    auto shape_type =
        RankedTensorType::get({result_rank}, rewriter.getIndexType());
    result_extents =
        rewriter.createOrFold<tensor::CastOp>(loc, shape_type, result_extents);

    Value broadcasted_pred = pred;
    // Pred has an implicit broadcast for scalars, so use that when convenient.
    if (pred_type.getRank() > 0) {
      auto pred_broadcast_dimensions = llvm::to_vector<4>(
          llvm::seq<int64_t>(result_rank - pred_type.getRank(), result_rank));
      broadcasted_pred = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
          loc,
          RankedTensorType::get(result_type.getShape(),
                                pred_type.getElementType()),
          pred, result_extents,
          rewriter.getI64TensorAttr(pred_broadcast_dimensions));
    }
    auto on_true_broadcast_dimensions = llvm::to_vector<4>(
        llvm::seq<int64_t>(result_rank - on_true_type.getRank(), result_rank));
    Value broadcasted_on_true = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
        loc,
        RankedTensorType::get(result_type.getShape(),
                              on_true_type.getElementType()),
        on_true, result_extents,
        rewriter.getI64TensorAttr(on_true_broadcast_dimensions));
    auto on_false_broadcast_dimensions = llvm::to_vector<4>(
        llvm::seq<int64_t>(result_rank - on_false_type.getRank(), result_rank));
    Value broadcasted_on_false = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
        loc,
        RankedTensorType::get(result_type.getShape(),
                              on_false_type.getElementType()),
        on_false, result_extents,
        rewriter.getI64TensorAttr(on_false_broadcast_dimensions));

    // And generate the final non-broadcasted ternary op.
    Value final_result = rewriter.create<mhlo::SelectOp>(
        loc, result_type, broadcasted_pred, broadcasted_on_true,
        broadcasted_on_false);
    rewriter.create<shape::AssumingYieldOp>(loc, final_result);
    rewriter.replaceOp(op, {assuming_op.getResult(0)});
    return success();
  }
};

// Converts binary ops that are statically determined not to broadcast
// directly into the corresponding non-broadcasting mhlo op.
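// For example (types illustrative), with equal static shapes
//   %0 = "chlo.broadcast_add"(%lhs, %rhs)
//       : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// rewrites directly to an mhlo.add on tensor<4xf32>, with no shape
// computation emitted.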
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertTrivialNonBroadcastBinaryOp
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ChloOpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Only rewrite for statically determinable non-broadcasting cases.
    typename ChloOpTy::Adaptor transformed(operands);
    auto lhs_type =
        transformed.lhs().getType().template dyn_cast<RankedTensorType>();
    auto rhs_type =
        transformed.rhs().getType().template dyn_cast<RankedTensorType>();
    if (!lhs_type || !rhs_type) return failure();

    // Requires rank broadcast.
    if (lhs_type.getRank() != rhs_type.getRank()) return failure();
    // Any dynamic dimension may require broadcasting and requires more
    // analysis.
    if (!lhs_type.hasStaticShape() || !rhs_type.hasStaticShape())
      return failure();

    for (auto extents : llvm::zip(lhs_type.getShape(), rhs_type.getShape())) {
      auto lhs_extent = std::get<0>(extents);
      auto rhs_extent = std::get<1>(extents);
      if (lhs_extent != rhs_extent) {
        return failure();
      }
    }

    rewriter.replaceOp(op, {Adaptor::CreateOp(op, op.getResult().getType(),
                                              operands, rewriter)});
    return success();
  }
};

// Converts a binary op with ranked broadcasting operands into explicit
// broadcasts followed by an invocation of the corresponding mhlo
// non-broadcasting op.
// Note that dynamic broadcasting supported by this pattern is only valid for
// "numpy" broadcasting semantics as defined here:
// https://docs.scipy.org/doc/numpy/reference/ufuncs.html
// Specifically, this includes the following cases:
//   - Same-rank broadcast (operands have the same static rank).
//   - Different-rank broadcast, either without a broadcast_dims attribute or
//     with the broadcast_dims attribute set to map to a prefix padding.
//   - Legal combinations of degenerate (1-dim) implicit broadcasting.
// The restriction on broadcast_dims derives from the definition of the
// `shape.broadcast` op, which only supports prefix-padding.
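// Schematically, for ranked dynamic operands this emits:
//   %cstr = shape.cstr_broadcastable %lhs_shape, %rhs_shape
//   %result = shape.assuming %cstr {
//     ... mhlo.dynamic_broadcast_in_dim of each operand to the common
//         result extents, then the non-broadcasting mhlo op ...
//   }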
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertRankedDynamicBroadcastBinaryOp
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ChloOpTy op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Only support ranked operands.
    typename ChloOpTy::Adaptor transformed(operands);
    Value lhs = transformed.lhs();
    Value rhs = transformed.rhs();
    auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
    auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
    auto result_type =
        op.getResult().getType().template dyn_cast<RankedTensorType>();
    if (!lhs_type || !rhs_type || !result_type) return failure();

    // Check for "numpy"-style rank broadcast.
    auto broadcast_dimensions = op.broadcast_dimensions();
    if (broadcast_dimensions &&
        !hlo::IsLegalNumpyRankedBroadcast(lhs, rhs, *broadcast_dimensions)) {
      // Note: It is unclear whether the general specification of explicit
      // broadcast_dimensions on binary ops is a feature we want to carry
      // forward. While it can technically be implemented for ranked-dynamic,
      // it is incompatible with unranked inputs. If this warning is emitted
      // in real programs, it is an indication that the feature should be
      // implemented versus just falling back on the more standard definition
      // of numpy-like prefix-padding.
      op.emitWarning() << "unsupported non prefix-padded dynamic rank "
                       << "broadcast_dimensions = " << *broadcast_dimensions;
      return failure();
    }

    // Compute result shape.
    auto loc = op.getLoc();

    // Insert a constraint on the shapes being broadcastable and insert all
    // future code into an assuming block reliant on the constraint.
    Value lhs_shape = rewriter.create<shape::ShapeOfOp>(loc, lhs);
    Value rhs_shape = rewriter.create<shape::ShapeOfOp>(loc, rhs);
    auto broadcastable_cstr =
        rewriter.create<shape::CstrBroadcastableOp>(loc, lhs_shape, rhs_shape);
    auto assuming_op = rewriter.create<shape::AssumingOp>(
        loc, ArrayRef<Type>{result_type}, broadcastable_cstr.result());

    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.createBlock(&assuming_op.doRegion());

    int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
    Value result_extents =
        hlo::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs,
                                                               rewriter);

    // Note that we unconditionally emit DynamicBroadcastInDim ops and let
    // downstream canonicalizations fold them away if possible. This is
    // because, in the dynamic case, there are many corner cases regarding
    // when it is safe to omit, and some of them require analysis to prove
    // properly.
    auto lhs_broadcast_dimensions = llvm::to_vector<4>(
        llvm::seq<int64_t>(result_rank - lhs_type.getRank(), result_rank));
    Value broadcasted_lhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
        loc,
        RankedTensorType::get(result_type.getShape(),
                              lhs_type.getElementType()),
        lhs, result_extents,
        rewriter.getI64TensorAttr(lhs_broadcast_dimensions));
    auto rhs_broadcast_dimensions = llvm::to_vector<4>(
        llvm::seq<int64_t>(result_rank - rhs_type.getRank(), result_rank));
    Value broadcasted_rhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
        loc,
        RankedTensorType::get(result_type.getShape(),
                              rhs_type.getElementType()),
        rhs, result_extents,
        rewriter.getI64TensorAttr(rhs_broadcast_dimensions));

    // And generate the final non-broadcasted binary op.
    Value final_result = Adaptor::CreateOp(
        op, result_type, {broadcasted_lhs, broadcasted_rhs}, rewriter);
    rewriter.create<shape::AssumingYieldOp>(loc, final_result);
    rewriter.replaceOp(op, {assuming_op.getResult(0)});
    return success();
  }
};

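// Lowers chlo.dynamic_reshape by guarding the reshape with an
// mhlo.cstr_reshapable check and computing the concrete output shape with
// mhlo.compute_reshape_shape inside the resulting shape.assuming region.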
class ConvertDynamicReshapeOp
    : public OpRewritePattern<chlo::DynamicReshapeOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(chlo::DynamicReshapeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto tensor = op.operand();
    auto shape = op.output_shape();

    auto shape_ty = shape.getType().cast<ShapedType>();
    auto result_ty = op.getType().cast<ShapedType>();

    Value input_shape = rewriter.create<shape::ShapeOfOp>(loc, tensor);
    Value num_els = rewriter.create<shape::NumElementsOp>(loc, input_shape);
    Value cstr = rewriter.create<mhlo::CstrReshapableOp>(loc, num_els, shape);
    rewriter.replaceOpWithNewOp<shape::AssumingOp>(
        op, cstr, [&](OpBuilder &b, Location l) {
          Value computed_shape = b.create<mhlo::ComputeReshapeShapeOp>(
              l, shape_ty, num_els, shape);
          SmallVector<Value> result;
          result.push_back(b.create<mhlo::DynamicReshapeOp>(
              l, result_ty, tensor, computed_shape));
          return result;
        });

    return success();
  }
};

#include "generated_chlo_legalize_to_hlo.inc"
}  // namespace

void PopulateChloBroadcastingPatterns(MLIRContext *context,
                                      OwningRewritePatternList *patterns) {
  // Instantiate conversion templates for conforming binary elementwise ops
  // that do not have different dtypes between operands and results and do
  // not have special attributes that need to be preserved.
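  // The trivial non-broadcasting pattern is registered with a higher benefit
  // (10 vs. 5) so that it wins whenever shapes are statically known to match.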
  PopulateForBroadcastingBinaryOp<ConvertTrivialNonBroadcastBinaryOp>(
      context, patterns, 10);
  PopulateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
      context, patterns, 5);
  patterns
      ->insert<ConvertConstantLikeOp, ConvertDynamicReshapeOp, ConvertSelectOp>(
          context);
}

void PopulateDecomposeChloPatterns(MLIRContext *context,
                                   OwningRewritePatternList *patterns) {
  populateWithGenerated(*patterns);

  // Other patterns.
  // clang-format off
  patterns->insert<ConvertCoshOp,
                   ConvertDigammaOp,
                   ConvertErfOp,
                   ConvertErfcOp,
                   ConvertLgammaOp,
                   ConvertNextAfterOp,
                   ConvertPolygammaOp,
                   ConvertSinhOp,
                   ConvertZetaOp>(context);
  // clang-format on
}

}  // namespace chlo
}  // namespace mlir