• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//===- ExpandTanh.cpp - Code to perform expanding tanh op -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the expansion of the tanh op into exp-based
// floating-point arithmetic.
//
//===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 
22 using namespace mlir;
23 
24 namespace {
25 
26 /// Expands tanh op into
27 ///   1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
28 ///   2) exp^{2x}-1 / exp^{2x}+1  , if x < 0
29 struct TanhOpConverter : public OpRewritePattern<TanhOp> {
30 public:
31   using OpRewritePattern::OpRewritePattern;
32 
matchAndRewrite__anon071702000111::TanhOpConverter33   LogicalResult matchAndRewrite(TanhOp op,
34                                 PatternRewriter &rewriter) const final {
35     auto floatType = op.operand().getType();
36     Location loc = op.getLoc();
37     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
38     auto floatTwo = rewriter.getFloatAttr(floatType, 2.0);
39     Value one = rewriter.create<ConstantOp>(loc, floatOne);
40     Value two = rewriter.create<ConstantOp>(loc, floatTwo);
41     Value doubledX = rewriter.create<MulFOp>(loc, op.operand(), two);
42 
43     // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
44     Value negDoubledX = rewriter.create<NegFOp>(loc, doubledX);
45     Value exp2x = rewriter.create<ExpOp>(loc, negDoubledX);
46     Value dividend = rewriter.create<SubFOp>(loc, one, exp2x);
47     Value divisor = rewriter.create<AddFOp>(loc, one, exp2x);
48     Value positiveRes = rewriter.create<DivFOp>(loc, dividend, divisor);
49 
50     // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
51     exp2x = rewriter.create<ExpOp>(loc, doubledX);
52     dividend = rewriter.create<SubFOp>(loc, exp2x, one);
53     divisor = rewriter.create<AddFOp>(loc, exp2x, one);
54     Value negativeRes = rewriter.create<DivFOp>(loc, dividend, divisor);
55 
56     // tanh(x) = x >= 0 ? positiveRes : negativeRes
57     auto floatZero = rewriter.getFloatAttr(floatType, 0.0);
58     Value zero = rewriter.create<ConstantOp>(loc, floatZero);
59     Value cmpRes =
60         rewriter.create<CmpFOp>(loc, CmpFPredicate::OGE, op.operand(), zero);
61     rewriter.replaceOpWithNewOp<SelectOp>(op, cmpRes, positiveRes, negativeRes);
62     return success();
63   }
64 };
65 } // namespace
66 
populateExpandTanhPattern(OwningRewritePatternList & patterns,MLIRContext * ctx)67 void mlir::populateExpandTanhPattern(OwningRewritePatternList &patterns,
68                                      MLIRContext *ctx) {
69   patterns.insert<TanhOpConverter>(ctx);
70 }
71