1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project
17 #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
19 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
20 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
21 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
22
23 namespace mlir {
24 namespace TF {
25
26 namespace {
27
28 /// Move the ops of `source_block` into `destination_block`, keeping the later's
29 /// block arguments' type as `block_arguments_type`.
moveBlock(Block * source_block,Block * destination_block,TypeRange block_arguments_type,PatternRewriter & rewriter)30 static void moveBlock(Block* source_block, Block* destination_block,
31 TypeRange block_arguments_type,
32 PatternRewriter& rewriter) {
33 // If `destination_block` isn't empty, erase its terminator to ensure that it
34 // never contains two terminator-like ops after merging.
35 if (!destination_block->empty())
36 rewriter.eraseOp(destination_block->getTerminator());
37
38 destination_block->addArguments(
39 block_arguments_type,
40 SmallVector<Location>(block_arguments_type.size(),
41 source_block->getParent()->getLoc()));
42 rewriter.mergeBlocks(source_block, destination_block,
43 destination_block->getArguments());
44 }
45
46 /// Convert the `tf.IfRegion` op to the `scf.if` op.
47 class ConvertIfRegionOp : public OpRewritePattern<IfRegionOp> {
48 public:
49 using OpRewritePattern<IfRegionOp>::OpRewritePattern;
50
matchAndRewrite(IfRegionOp op,PatternRewriter & rewriter) const51 LogicalResult matchAndRewrite(IfRegionOp op,
52 PatternRewriter& rewriter) const override {
53 // Creates the `then` or `else` region of the `scf.if` op. Note that
54 // `tf_then_or_else_region` is the `then` or `else` region of the
55 // `tf.IfRegion` op and `scf_then_or_else_region` is the `then` or `else`
56 // region of the new `scf.if` op. Further, `tf_if_region_return_type` is the
57 // list of return types of the `tf.IfRegion` op.
58 auto createScfThenOrElse = [](Region& tf_then_or_else_region,
59 Region& scf_then_or_else_region,
60 TypeRange tf_if_region_return_type,
61 PatternRewriter& rewriter) {
62 // Move the first block of `tf_then_or_else_region` into the first block
63 // of `scf_then_or_else_region` and do not add any arguments to the block.
64 moveBlock(&tf_then_or_else_region.front(),
65 &scf_then_or_else_region.front(), TypeRange(), rewriter);
66
67 // Replace the current terminator (a `tf.Yield` op) with an `scf.yield`
68 // op. The input of the `scf.yield` op is a list of results of `tf.Cast`
69 // ops, each of which casts an operand of the current terminator to the
70 // corresponding result type of the `tf.IfRegion` op.
71 Operation* current_terminator =
72 scf_then_or_else_region.front().getTerminator();
73 rewriter.setInsertionPoint(current_terminator);
74 SmallVector<Value, 4> scf_yield_input;
75 for (auto it : llvm::zip(tf_if_region_return_type,
76 current_terminator->getOperands())) {
77 scf_yield_input.push_back(rewriter.create<CastOp>(
78 current_terminator->getLoc(), std::get<0>(it), std::get<1>(it)));
79 }
80
81 rewriter.replaceOpWithNewOp<scf::YieldOp>(current_terminator,
82 scf_yield_input);
83 };
84
85 Location loc = op.getLoc();
86
87 // The condition of an `scf.if` op is a 1-bit signless integer. Whereas, the
88 // condition of the `tf.IfRegion` op is a 0-D tensor of 1-bit signless
89 // integers. Thus, we use the `tensor.extract` op to compute the condition
90 // of `scf.if` from that of `tf.IfRegion`.
91 auto scf_if_condition = rewriter.create<tensor::ExtractOp>(loc, op.cond());
92
93 TypeRange tf_if_region_return_type = op.getResultTypes();
94
95 // Create the `scf.if` op.
96 auto scf_if_op =
97 rewriter.create<scf::IfOp>(loc, tf_if_region_return_type,
98 scf_if_condition, /*withElseRegion=*/true);
99
100 Region& then_region = op.then_branch();
101 Region& else_region = op.else_branch();
102
103 // Create the `then` and `else` regions of the `scf.if` op.
104 createScfThenOrElse(then_region, scf_if_op.getThenRegion(),
105 tf_if_region_return_type, rewriter);
106 createScfThenOrElse(else_region, scf_if_op.getElseRegion(),
107 tf_if_region_return_type, rewriter);
108
109 // Replace the `tf.IfRegion` op with the results of the `scf.if` op.
110 rewriter.replaceOp(op, scf_if_op.getResults());
111 return success();
112 }
113 };
114
115 /// Convert the `tf.WhileRegion` op to the `scf.while` op.
116 class ConvertWhileRegionOp : public OpRewritePattern<WhileRegionOp> {
117 public:
118 using OpRewritePattern<WhileRegionOp>::OpRewritePattern;
119
matchAndRewrite(WhileRegionOp op,PatternRewriter & rewriter) const120 LogicalResult matchAndRewrite(WhileRegionOp op,
121 PatternRewriter& rewriter) const override {
122 // Creates the `before` or `after` region of the `scf.while` op. Note that
123 // `tf_cond_or_body_region` is the `cond` or `body` region of the
124 // `tf.WhileRegion` op. `scf_before_or_after_region` is the `before` or
125 // `after` region of the new `scf.while` op. `scf_block_arguments_type` is
126 // the type of arguments that need to be in the first block of
127 // `scf_before_or_after_region`.
128 auto createScfCondOrBody =
129 [](Region& tf_cond_or_body_region, Region& scf_before_or_after_region,
130 TypeRange scf_block_arguments_type, PatternRewriter& rewriter) {
131 // Move the first block of `tf_cond_or_body_region` into the first
132 // block of `scf_before_or_after_region` and keep the later's
133 // arguments' type as `scf_block_arguments_type`.
134 moveBlock(&tf_cond_or_body_region.front(),
135 &scf_before_or_after_region.front(),
136 scf_block_arguments_type, rewriter);
137
138 Operation* cond_or_body_terminator =
139 scf_before_or_after_region.front().getTerminator();
140 rewriter.setInsertionPoint(cond_or_body_terminator);
141 return cond_or_body_terminator;
142 };
143
144 ValueRange opInput = op.input();
145 TypeRange scf_block_arguments_type = opInput.getType();
146
147 // Create the `scf.while` op.
148 auto scf_while_op = rewriter.create<scf::WhileOp>(
149 op.getLoc(), op.getResultTypes(), opInput);
150
151 // Create the `before` block of the `scf.while` op (with an `scf.condition`
152 // op as the terminator). Note that the arguments' type of this block is
153 // kept as `opInput`'s type. Note that the input of an `scf.condition` op is
154 // a 1-bit signless integer. But, the condition of the `tf.WhileRegion` op
155 // is a 0-D tensor of 1-bit signless integers. Thus, we use the
156 // `tensor.extract` op to compute the input of `scf.condition`.
157 rewriter.createBlock(&scf_while_op.getBefore());
158 Operation* cond_terminator =
159 createScfCondOrBody(op.cond(), scf_while_op.getBefore(),
160 scf_block_arguments_type, rewriter);
161 auto scf_condition_input = rewriter.create<tensor::ExtractOp>(
162 cond_terminator->getLoc(), cond_terminator->getOperand(0));
163 rewriter.replaceOpWithNewOp<scf::ConditionOp>(
164 cond_terminator, scf_condition_input.getResult(),
165 scf_while_op.getBefore().front().getArguments());
166
167 // Create the `after` block of the `scf.while` op (with an `scf.yield` op as
168 // the terminator). Note that the arguments' type of this block is kept as
169 // `opInput`'s type.
170 rewriter.createBlock(&scf_while_op.getAfter());
171 Operation* body_terminator = createScfCondOrBody(
172 op.body(), scf_while_op.getAfter(), scf_block_arguments_type, rewriter);
173 rewriter.replaceOpWithNewOp<scf::YieldOp>(body_terminator,
174 body_terminator->getOperands());
175
176 // Replace the `tf.WhileRegion` op with the `scf.while` op.
177 rewriter.replaceOp(op, scf_while_op.getResults());
178
179 return success();
180 }
181 };
182
183 } // end anonymous namespace
184
populateTfControlFlowToScfPatterns(MLIRContext * context,RewritePatternSet * patterns)185 void populateTfControlFlowToScfPatterns(MLIRContext* context,
186 RewritePatternSet* patterns) {
187 patterns->add<ConvertIfRegionOp, ConvertWhileRegionOp>(context);
188 }
189
190 struct ConvertTfControlFlowToScf
191 : public ConvertTfControlFlowToScfPassBase<ConvertTfControlFlowToScf> {
runOnOperationmlir::TF::ConvertTfControlFlowToScf192 void runOnOperation() override {
193 RewritePatternSet patterns(&getContext());
194 populateTfControlFlowToScfPatterns(&getContext(), &patterns);
195 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
196 }
197 };
198
createConvertTfControlFlowToScfPass()199 std::unique_ptr<OperationPass<ModuleOp>> createConvertTfControlFlowToScfPass() {
200 return std::make_unique<ConvertTfControlFlowToScf>();
201 }
202
203 } // namespace TF
204 } // end namespace mlir
205