• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements the lowering for trigonometric standard ops to
17 // approximations.
18 
19 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
20 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
21 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
22 #include "mlir/Dialect/Math/IR/Math.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h"
24 #include "mlir/IR/BuiltinOps.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 
28 namespace mlir {
29 namespace mhlo {
30 namespace {
31 
32 template <typename OpTy>
33 class ApproximateOnExtendedF32Lowering : public OpRewritePattern<OpTy> {
34  public:
ApproximateOnExtendedF32Lowering(MLIRContext * ctx)35   explicit ApproximateOnExtendedF32Lowering(MLIRContext *ctx)
36       : OpRewritePattern<OpTy>(ctx, /*benefit=*/100) {}
37 
38   virtual Value emitApproximation(ValueRange, Location,
39                                   PatternRewriter &) const = 0;
40 
matchAndRewrite(OpTy op,PatternRewriter & rewriter) const41   LogicalResult matchAndRewrite(OpTy op,
42                                 PatternRewriter &rewriter) const override {
43     Location loc = op.getLoc();
44     auto raw_args = op.getOperation()->getOperands();
45 
46     // Supports only f16 and f32 for now.
47     if (!op.getType().isF16() && !op.getType().isF32()) return failure();
48 
49     // Extend operands to f32 if needed and possible.
50     SmallVector<Value, 2> f32_args;
51     f32_args.reserve(raw_args.size());
52     for (Value arg : raw_args) {
53       // Similar to XLA, do not rewrite f64 as precision might matter.
54       Type arg_ty = arg.getType();
55       if (arg_ty.isF64()) return failure();
56 
57       if (arg_ty.isF16())
58         arg = rewriter.create<FPExtOp>(loc, arg, rewriter.getF32Type());
59 
60       // If we still do not have f32, fail.
61       if (!arg.getType().isF32()) return failure();
62 
63       f32_args.push_back(arg);
64     }
65 
66     Value result = emitApproximation(f32_args, loc, rewriter);
67     assert(result.getType().isF32() && "Expect f32 intermediate result.");
68 
69     // Truncate back if needed.
70     if (op.getType().isF16())
71       result = rewriter.create<FPTruncOp>(loc, result, rewriter.getF16Type());
72 
73     rewriter.replaceOp(op, {result});
74     return success();
75   }
76 };
77 
78 // This approximation resembles Eigen and realizes a constant approximation for
79 // the +/-1 limits on top.
80 // https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Core/MathFunctionsImpl.h
81 class ApproximateTanhLowering
82     : public ApproximateOnExtendedF32Lowering<math::TanhOp> {
83  public:
ApproximateTanhLowering(MLIRContext * ctx)84   explicit ApproximateTanhLowering(MLIRContext *ctx)
85       : ApproximateOnExtendedF32Lowering<math::TanhOp>(ctx) {}
86 
87   // Emits the fast tanh approximation that is also used by XLA.
emitApproximation(ValueRange args,Location loc,PatternRewriter & rewriter) const88   Value emitApproximation(ValueRange args, Location loc,
89                           PatternRewriter &rewriter) const override {
90     Value input = args.front();
91     assert(input.getType().isF32());
92     static constexpr std::array<float, 7> numerator_coeffs{
93         -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
94         5.12229709037114e-08f,  1.48572235717979e-05f, 6.37261928875436e-04f,
95         4.89352455891786e-03f};
96     static constexpr std::array<float, 4> denominator_coeffs{
97         1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
98         4.89352518554385e-03f};
99 
100     // Materialize polynomial approximation.
101     Value input_squared = rewriter.create<MulFOp>(loc, input, input);
102     Value numerator = rewriter.create<ConstantOp>(
103         loc, rewriter.getF32FloatAttr(numerator_coeffs[0]));
104     for (int i = 1; i < numerator_coeffs.size(); i++) {
105       numerator = rewriter.create<AddFOp>(
106           loc, rewriter.create<MulFOp>(loc, input_squared, numerator),
107           rewriter.create<ConstantOp>(
108               loc, rewriter.getF32FloatAttr(numerator_coeffs[i])));
109     }
110     numerator = rewriter.create<MulFOp>(loc, input, numerator);
111     Value denominator = rewriter.create<ConstantOp>(
112         loc, rewriter.getF32FloatAttr(denominator_coeffs[0]));
113     for (int i = 1; i < denominator_coeffs.size(); i++) {
114       denominator = rewriter.create<AddFOp>(
115           loc, rewriter.create<MulFOp>(loc, input_squared, denominator),
116           rewriter.create<ConstantOp>(
117               loc, rewriter.getF32FloatAttr(denominator_coeffs[i])));
118     }
119     Value approx = rewriter.create<DivFOp>(loc, numerator, denominator);
120 
121     // For small values of |x|, we can approximate tanh(x) = x. For extremely
122     // small values of x (|x| < 1e-37), the other approximation would evaluate
123     // tanh(x) = 0.
124     constexpr float kUseIdentityApprox = 0.0004;
125     Value abs_input = rewriter.create<AbsFOp>(loc, input);
126     Value use_identity_approx = rewriter.create<CmpFOp>(
127         loc, CmpFPredicate::OLT, abs_input,
128         rewriter.create<ConstantOp>(
129             loc, rewriter.getF32FloatAttr(kUseIdentityApprox)));
130     approx = rewriter.create<SelectOp>(loc, use_identity_approx, input, approx);
131 
132     // For very small/large values, use a constant approximation -1/1.
133     Value too_large_input = rewriter.create<CmpFOp>(
134         loc, CmpFPredicate::UGT, input,
135         rewriter.create<ConstantOp>(
136             loc, rewriter.getF32FloatAttr(7.90531110763549805f)));
137     Value too_small_input = rewriter.create<CmpFOp>(
138         loc, CmpFPredicate::ULT, input,
139         rewriter.create<ConstantOp>(
140             loc, rewriter.getF32FloatAttr(-7.90531110763549805f)));
141     Value input_is_nan =
142         rewriter.create<CmpFOp>(loc, CmpFPredicate::UNE, input, input);
143     approx = rewriter.create<SelectOp>(
144         loc, too_large_input,
145         rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0)),
146         approx);
147     approx = rewriter.create<SelectOp>(
148         loc, too_small_input,
149         rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(-1.0)),
150         approx);
151     approx = rewriter.create<SelectOp>(loc, input_is_nan, input, approx);
152 
153     return approx;
154   }
155 };
156 
157 struct LegalizeTrigonometricToApproximationPass
158     : public LegalizeTanhToApproximationPassBase<
159           LegalizeTrigonometricToApproximationPass> {
160   /// Perform the lowering of standard dialect operations to approximations.
runOnFunctionmlir::mhlo::__anonaf04e55a0111::LegalizeTrigonometricToApproximationPass161   void runOnFunction() override {
162     OwningRewritePatternList patterns(&getContext());
163     PopulateTrigonometricToApproximationPatterns(&getContext(), &patterns);
164     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
165   }
166 };
167 
168 }  // anonymous namespace
169 
170 std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
createLegalizeTrigonometricToApproximationPass()171 createLegalizeTrigonometricToApproximationPass() {
172   return std::make_unique<LegalizeTrigonometricToApproximationPass>();
173 }
174 
PopulateTrigonometricToApproximationPatterns(mlir::MLIRContext * context,OwningRewritePatternList * patterns)175 void PopulateTrigonometricToApproximationPatterns(
176     mlir::MLIRContext *context, OwningRewritePatternList *patterns) {
177   // clang-format off
178   patterns->insert<ApproximateTanhLowering>(context);
179   // clang-format on
180 }
181 
182 }  // namespace mhlo
183 }  // namespace mlir
184