/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
// This file implements the lowering for trigonometric standard ops to
// approximations.
18
19 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
20 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
21 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
22 #include "mlir/Dialect/Math/IR/Math.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h"
24 #include "mlir/IR/BuiltinOps.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27
28 namespace mlir {
29 namespace mhlo {
30 namespace {
31
32 template <typename OpTy>
33 class ApproximateOnExtendedF32Lowering : public OpRewritePattern<OpTy> {
34 public:
ApproximateOnExtendedF32Lowering(MLIRContext * ctx)35 explicit ApproximateOnExtendedF32Lowering(MLIRContext *ctx)
36 : OpRewritePattern<OpTy>(ctx, /*benefit=*/100) {}
37
38 virtual Value emitApproximation(ValueRange, Location,
39 PatternRewriter &) const = 0;
40
matchAndRewrite(OpTy op,PatternRewriter & rewriter) const41 LogicalResult matchAndRewrite(OpTy op,
42 PatternRewriter &rewriter) const override {
43 Location loc = op.getLoc();
44 auto raw_args = op.getOperation()->getOperands();
45
46 // Supports only f16 and f32 for now.
47 if (!op.getType().isF16() && !op.getType().isF32()) return failure();
48
49 // Extend operands to f32 if needed and possible.
50 SmallVector<Value, 2> f32_args;
51 f32_args.reserve(raw_args.size());
52 for (Value arg : raw_args) {
53 // Similar to XLA, do not rewrite f64 as precision might matter.
54 Type arg_ty = arg.getType();
55 if (arg_ty.isF64()) return failure();
56
57 if (arg_ty.isF16())
58 arg = rewriter.create<FPExtOp>(loc, arg, rewriter.getF32Type());
59
60 // If we still do not have f32, fail.
61 if (!arg.getType().isF32()) return failure();
62
63 f32_args.push_back(arg);
64 }
65
66 Value result = emitApproximation(f32_args, loc, rewriter);
67 assert(result.getType().isF32() && "Expect f32 intermediate result.");
68
69 // Truncate back if needed.
70 if (op.getType().isF16())
71 result = rewriter.create<FPTruncOp>(loc, result, rewriter.getF16Type());
72
73 rewriter.replaceOp(op, {result});
74 return success();
75 }
76 };
77
78 // This approximation resembles Eigen and realizes a constant approximation for
79 // the +/-1 limits on top.
80 // https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Core/MathFunctionsImpl.h
81 class ApproximateTanhLowering
82 : public ApproximateOnExtendedF32Lowering<math::TanhOp> {
83 public:
ApproximateTanhLowering(MLIRContext * ctx)84 explicit ApproximateTanhLowering(MLIRContext *ctx)
85 : ApproximateOnExtendedF32Lowering<math::TanhOp>(ctx) {}
86
87 // Emits the fast tanh approximation that is also used by XLA.
emitApproximation(ValueRange args,Location loc,PatternRewriter & rewriter) const88 Value emitApproximation(ValueRange args, Location loc,
89 PatternRewriter &rewriter) const override {
90 Value input = args.front();
91 assert(input.getType().isF32());
92 static constexpr std::array<float, 7> numerator_coeffs{
93 -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
94 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f,
95 4.89352455891786e-03f};
96 static constexpr std::array<float, 4> denominator_coeffs{
97 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
98 4.89352518554385e-03f};
99
100 // Materialize polynomial approximation.
101 Value input_squared = rewriter.create<MulFOp>(loc, input, input);
102 Value numerator = rewriter.create<ConstantOp>(
103 loc, rewriter.getF32FloatAttr(numerator_coeffs[0]));
104 for (int i = 1; i < numerator_coeffs.size(); i++) {
105 numerator = rewriter.create<AddFOp>(
106 loc, rewriter.create<MulFOp>(loc, input_squared, numerator),
107 rewriter.create<ConstantOp>(
108 loc, rewriter.getF32FloatAttr(numerator_coeffs[i])));
109 }
110 numerator = rewriter.create<MulFOp>(loc, input, numerator);
111 Value denominator = rewriter.create<ConstantOp>(
112 loc, rewriter.getF32FloatAttr(denominator_coeffs[0]));
113 for (int i = 1; i < denominator_coeffs.size(); i++) {
114 denominator = rewriter.create<AddFOp>(
115 loc, rewriter.create<MulFOp>(loc, input_squared, denominator),
116 rewriter.create<ConstantOp>(
117 loc, rewriter.getF32FloatAttr(denominator_coeffs[i])));
118 }
119 Value approx = rewriter.create<DivFOp>(loc, numerator, denominator);
120
121 // For small values of |x|, we can approximate tanh(x) = x. For extremely
122 // small values of x (|x| < 1e-37), the other approximation would evaluate
123 // tanh(x) = 0.
124 constexpr float kUseIdentityApprox = 0.0004;
125 Value abs_input = rewriter.create<AbsFOp>(loc, input);
126 Value use_identity_approx = rewriter.create<CmpFOp>(
127 loc, CmpFPredicate::OLT, abs_input,
128 rewriter.create<ConstantOp>(
129 loc, rewriter.getF32FloatAttr(kUseIdentityApprox)));
130 approx = rewriter.create<SelectOp>(loc, use_identity_approx, input, approx);
131
132 // For very small/large values, use a constant approximation -1/1.
133 Value too_large_input = rewriter.create<CmpFOp>(
134 loc, CmpFPredicate::UGT, input,
135 rewriter.create<ConstantOp>(
136 loc, rewriter.getF32FloatAttr(7.90531110763549805f)));
137 Value too_small_input = rewriter.create<CmpFOp>(
138 loc, CmpFPredicate::ULT, input,
139 rewriter.create<ConstantOp>(
140 loc, rewriter.getF32FloatAttr(-7.90531110763549805f)));
141 Value input_is_nan =
142 rewriter.create<CmpFOp>(loc, CmpFPredicate::UNE, input, input);
143 approx = rewriter.create<SelectOp>(
144 loc, too_large_input,
145 rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0)),
146 approx);
147 approx = rewriter.create<SelectOp>(
148 loc, too_small_input,
149 rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(-1.0)),
150 approx);
151 approx = rewriter.create<SelectOp>(loc, input_is_nan, input, approx);
152
153 return approx;
154 }
155 };
156
157 struct LegalizeTrigonometricToApproximationPass
158 : public LegalizeTanhToApproximationPassBase<
159 LegalizeTrigonometricToApproximationPass> {
160 /// Perform the lowering of standard dialect operations to approximations.
runOnFunctionmlir::mhlo::__anonaf04e55a0111::LegalizeTrigonometricToApproximationPass161 void runOnFunction() override {
162 OwningRewritePatternList patterns(&getContext());
163 PopulateTrigonometricToApproximationPatterns(&getContext(), &patterns);
164 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
165 }
166 };
167
168 } // anonymous namespace
169
170 std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
createLegalizeTrigonometricToApproximationPass()171 createLegalizeTrigonometricToApproximationPass() {
172 return std::make_unique<LegalizeTrigonometricToApproximationPass>();
173 }
174
PopulateTrigonometricToApproximationPatterns(mlir::MLIRContext * context,OwningRewritePatternList * patterns)175 void PopulateTrigonometricToApproximationPatterns(
176 mlir::MLIRContext *context, OwningRewritePatternList *patterns) {
177 // clang-format off
178 patterns->insert<ApproximateTanhLowering>(context);
179 // clang-format on
180 }
181
182 } // namespace mhlo
183 } // namespace mlir
184