/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <iterator>
#include <memory>

#include "llvm/Support/raw_ostream.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"  // from @llvm-project
#include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Region.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
#include "mlir/Transforms/LoopUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
#include "tensorflow/compiler/mlir/tfr/passes/passes.h"
//===----------------------------------------------------------------------===//
// Canonicalization patterns for the scf.for and scf.if ops. They are used to
// optimize the control flow in TFR functions. Technically, both patterns
// should be upstreamed to be part of the op definitions.
// TODO(fengliuai): sync with upstream LLVM for both patterns.
//
namespace mlir {
namespace TFR {

namespace {

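// Unrolls an scf.for op whose bounds and step are compile-time constants by
// cloning its body once per iteration. The induction variable and iter_args
// are substituted through a value mapping, and the loop is replaced by the
// values yielded from the last cloned iteration.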
class UnrollSCFForOp : public OpRewritePattern<scf::ForOp> {
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

 public:
  LogicalResult matchAndRewrite(scf::ForOp for_op,
                                PatternRewriter &rewriter) const override {
    Location loc = for_op.getLoc();
    // The loop can only be fully unrolled when its bounds and step are
    // compile-time constants.
    APInt lower_bound, upper_bound, step;
    if (!matchPattern(for_op.lowerBound(), m_ConstantInt(&lower_bound)) ||
        !matchPattern(for_op.upperBound(), m_ConstantInt(&upper_bound)) ||
        !matchPattern(for_op.step(), m_ConstantInt(&step))) {
      return failure();
    }
    // scf.for iterates from lower_bound (inclusive) to upper_bound (exclusive)
    // in increments of step, so the trip count is the ceiling of the quotient.
    int64_t trip_count =
        (upper_bound - lower_bound + step - 1).sdiv(step).getSExtValue();
    if (trip_count <= 0) return failure();

    // TODO(fengliuai): use loopUnrollByFactor once iter_args are supported.

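    // Clone the single-block loop body once per iteration, right before the
    // scf.for op. The value mapping tracks the induction variable value and
    // the iter_args flowing from one cloned iteration into the next.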
    Block *single_block = for_op.getBody();
    BlockAndValueMapping mapping;
    Value iv = for_op.getInductionVar();
    for (auto iter_op :
         llvm::zip(for_op.getRegionIterArgs(), for_op.initArgs())) {
      mapping.map(std::get<0>(iter_op), std::get<1>(iter_op));
    }
    mapping.map(iv, for_op.lowerBound());
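    // Emit trip_count copies of the loop body. For iteration i the induction
    // variable value is lower_bound + step * i.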
    for (int64_t i = 0; i < trip_count; ++i) {
      if (!iv.use_empty()) {
        // iv_i = lower_bound + step * i;
        Value iter = rewriter.create<ConstantIndexOp>(loc, i);
        Value step_cst =
            rewriter.create<ConstantIndexOp>(loc, step.getSExtValue());
        Value stride = rewriter.create<MulIOp>(loc, step_cst, iter);
        Value iv_unroll =
            rewriter.create<AddIOp>(loc, for_op.lowerBound(), stride);
        mapping.map(iv, iv_unroll);
      }

      // Clone every op in the body, including the scf.yield terminator, which
      // is always the last op of the single block.
      Operation *terminator_op = nullptr;
      for (auto it = single_block->begin(); it != single_block->end(); ++it) {
        terminator_op = rewriter.clone(*it, mapping);
      }
      // Map the block arguments to the yield results so the next cloned
      // iteration (and the final replacement) sees the updated iter_args.
      for (auto iter_op : llvm::zip(for_op.getRegionIterArgs(),
                                    terminator_op->getOperands())) {
        mapping.map(std::get<0>(iter_op), std::get<1>(iter_op));
      }
      rewriter.eraseOp(terminator_op);
    }
    SmallVector<Value, 4> returned;
    for (Value arg : for_op.getRegionIterArgs()) {
      returned.push_back(mapping.lookup(arg));
    }
    rewriter.replaceOp(for_op, returned);
    return success();
  }
};

// TODO(fengliuai): upstream this pattern.
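// Folds an scf.if op whose condition is a compile-time constant by inlining
// the taken branch at the position of the scf.if op and erasing the op.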
class SimplifySCFIfOp : public OpRewritePattern<scf::IfOp> {
  using OpRewritePattern<scf::IfOp>::OpRewritePattern;

 public:
  LogicalResult matchAndRewrite(scf::IfOp if_op,
                                PatternRewriter &rewriter) const override {
    // Constant true condition: inline the then branch.
    if (matchPattern(if_op.condition(), m_NonZero())) {
      return InlineRegion(if_op.getLoc(), rewriter, if_op, &if_op.thenRegion());
    }

    // Constant false condition: inline the else branch if present.
    if (matchPattern(if_op.condition(), m_Zero())) {
      if (if_op.elseRegion().empty()) {
        // An scf.if without an else region cannot produce results, so the op
        // can simply be removed.
        rewriter.eraseOp(if_op);
        return success();
      } else {
        return InlineRegion(if_op.getLoc(), rewriter, if_op,
                            &if_op.elseRegion());
      }
    }

    // The condition is not a constant; nothing to simplify.
    return failure();
  }

 private:
  LogicalResult InlineRegion(Location loc, PatternRewriter &rewriter,
                             Operation *inline_point, Region *region) const;
};

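// Inlines `region` at the position of `inline_point`, replaces the op's
// results with the values yielded by the region, and erases the op.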
LogicalResult SimplifySCFIfOp::InlineRegion(Location loc,
                                            PatternRewriter &rewriter,
                                            Operation *inline_point,
                                            Region *region) const {
  InlinerInterface interface(loc.getContext());
  if (failed(inlineRegion(interface, region, inline_point, {},
                          inline_point->getResults(), loc,
                          /*shouldCloneInlinedRegion=*/true))) {
    return failure();
  }

  // If the inlining was successful then erase the scf.if op.
  rewriter.eraseOp(inline_point);
  return success();
}

}  // namespace

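// Collects the canonicalization patterns of every registered op used in
// `func`, except for ops in the tf dialect, and adds the scf.for unrolling and
// scf.if simplification patterns defined above.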
void populateCanonicalizationPatterns(FuncOp func,
                                      OwningRewritePatternList &patterns) {
  MLIRContext *context = func.getContext();
  mlir::Dialect *tf = context->getLoadedDialect<mlir::TF::TensorFlowDialect>();
  // Load all official canonicalization patterns. Ops in the tf dialect are
  // skipped because their canonicalization patterns do not propagate the
  // attributes correctly; those optimizations are performed by the bridge
  // instead.
  func->walk([&](Operation *op) {
    if (op->getDialect() != tf) {
      op->getAbstractOperation()->getCanonicalizationPatterns(patterns,
                                                              context);
    }
  });
  patterns.insert<UnrollSCFForOp, SimplifySCFIfOp>(context);
}

}  // namespace TFR
}  // namespace mlir