• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <cstdint>
#include <iterator>
#include <memory>

#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"  // from @llvm-project
#include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Region.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
#include "mlir/Transforms/LoopUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
#include "tensorflow/compiler/mlir/tfr/passes/passes.h"
37 
38 //===----------------------------------------------------------------------===//
39 // Canonicalization patterns for the scf.for and scf.if ops. They are used to
40 // optimize the control flow in the tfr function. Technically, both patterns
41 // should be upstreamed to be part of the op definition.
42 // TODO(fengliuai): sync with the llvm upstream for both patterns.
43 //
44 namespace mlir {
45 namespace TFR {
46 
47 namespace {
48 
49 class UnrollSCFForOp : public OpRewritePattern<scf::ForOp> {
50   using OpRewritePattern<scf::ForOp>::OpRewritePattern;
51 
52  public:
matchAndRewrite(scf::ForOp for_op,PatternRewriter & rewriter) const53   LogicalResult matchAndRewrite(scf::ForOp for_op,
54                                 PatternRewriter &rewriter) const override {
55     Location loc = for_op.getLoc();
56     APInt lower_bound, upper_bound, step;
57     if (!matchPattern(for_op.lowerBound(), m_ConstantInt(&lower_bound)) ||
58         !matchPattern(for_op.upperBound(), m_ConstantInt(&upper_bound)) ||
59         !matchPattern(for_op.step(), m_ConstantInt(&step))) {
60       return failure();
61     }
62     uint64_t trip_count = (upper_bound - lower_bound).sdiv(step).getZExtValue();
63     if (trip_count <= 0) return failure();
64 
65     // TODO(fengliuai): use loopUnrollByFactor once the iter_arg is supported
66 
67     Block *single_block = for_op.getBody();
68     BlockAndValueMapping mapping;
69     Value iv = for_op.getInductionVar();
70     for (auto iter_op :
71          llvm::zip(for_op.getRegionIterArgs(), for_op.initArgs())) {
72       mapping.map(std::get<0>(iter_op), std::get<1>(iter_op));
73     }
74     mapping.map(iv, for_op.lowerBound());
75     for (auto i = 0; i < trip_count; ++i) {
76       if (!iv.use_empty()) {
77         // iv' = iv + step * i;
78         Value iter = rewriter.create<ConstantIndexOp>(loc, i);
79         Value step_cst =
80             rewriter.create<ConstantIndexOp>(loc, step.getSExtValue());
81         Value stride = rewriter.create<MulIOp>(loc, step_cst, iter);
82         Value iv_unroll =
83             rewriter.create<AddIOp>(loc, mapping.lookup(iv), stride);
84         mapping.map(iv, iv_unroll);
85       }
86 
87       Operation *terminator_op;
88       for (auto it = single_block->begin(); it != single_block->end(); ++it) {
89         terminator_op = rewriter.clone(*it, mapping);
90       }
91       // Map the block arguments to the yield results.
92       for (auto iter_op : llvm::zip(for_op.getRegionIterArgs(),
93                                     terminator_op->getOperands())) {
94         mapping.map(std::get<0>(iter_op), std::get<1>(iter_op));
95       }
96       rewriter.eraseOp(terminator_op);
97     }
98     SmallVector<Value, 4> returned;
99     for (Value arg : for_op.getRegionIterArgs()) {
100       returned.push_back(mapping.lookup(arg));
101     }
102     rewriter.replaceOp(for_op, returned);
103     return success();
104   }
105 };
106 
107 // TODO(fengliuai): up stream this pattern.
108 class SimplifySCFIfOp : public OpRewritePattern<scf::IfOp> {
109   using OpRewritePattern<scf::IfOp>::OpRewritePattern;
110 
111  public:
matchAndRewrite(scf::IfOp if_op,PatternRewriter & rewriter) const112   LogicalResult matchAndRewrite(scf::IfOp if_op,
113                                 PatternRewriter &rewriter) const override {
114     // Then branch
115     if (matchPattern(if_op.condition(), m_NonZero())) {
116       return InlineRegion(if_op.getLoc(), rewriter, if_op, &if_op.thenRegion());
117     }
118 
119     // Else branch
120     if (matchPattern(if_op.condition(), m_Zero())) {
121       if (if_op.elseRegion().empty()) {
122         // Remove the op
123         rewriter.eraseOp(if_op);
124         return success();
125       } else {
126         return InlineRegion(if_op.getLoc(), rewriter, if_op,
127                             &if_op.elseRegion());
128       }
129     }
130 
131     // Not a constant condition
132     return failure();
133   }
134 
135  private:
136   LogicalResult InlineRegion(Location loc, PatternRewriter &rewriter,
137                              Operation *inline_point, Region *region) const;
138 };
139 
InlineRegion(Location loc,PatternRewriter & rewriter,Operation * inline_point,Region * region) const140 LogicalResult SimplifySCFIfOp::InlineRegion(Location loc,
141                                             PatternRewriter &rewriter,
142                                             Operation *inline_point,
143                                             Region *region) const {
144   InlinerInterface interface(loc.getContext());
145   if (failed(inlineRegion(interface, region, inline_point, {},
146                           inline_point->getResults(), loc,
147                           /*shouldCloneInlinedRegion=*/true))) {
148     return failure();
149   }
150 
151   // If the inlining was successful then erase the scf.if op.
152   rewriter.eraseOp(inline_point);
153   return success();
154 }
155 
156 }  // namespace
157 
populateCanonicalizationPatterns(FuncOp func,OwningRewritePatternList & patterns)158 void populateCanonicalizationPatterns(FuncOp func,
159                                       OwningRewritePatternList &patterns) {
160   MLIRContext *context = func.getContext();
161   mlir::Dialect *tf = context->getLoadedDialect<mlir::TF::TensorFlowDialect>();
162   // Load all official canonicalization patterns. Here we skip the
163   // canonicalization of the ops in the tf dialect, because they couldn't
164   // propagate the attributes correctly. These optimization will be played by
165   // bridge.
166   func->walk([&](Operation *op) {
167     if (op->getDialect() != tf) {
168       op->getAbstractOperation()->getCanonicalizationPatterns(patterns,
169                                                               context);
170     }
171   });
172   patterns.insert<UnrollSCFForOp, SimplifySCFIfOp>(context);
173 }
174 
175 }  // namespace TFR
176 }  // namespace mlir
177