• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir/Dialect/SCF/IR/SCF.h"  // from @llvm-project
17 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
19 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
20 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
21 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
22 
23 namespace mlir {
24 namespace TF {
25 
26 namespace {
27 
28 /// Move the ops of `source_block` into `destination_block`, keeping the later's
29 /// block arguments' type as `block_arguments_type`.
moveBlock(Block * source_block,Block * destination_block,TypeRange block_arguments_type,PatternRewriter & rewriter)30 static void moveBlock(Block* source_block, Block* destination_block,
31                       TypeRange block_arguments_type,
32                       PatternRewriter& rewriter) {
33   // If `destination_block` isn't empty, erase its terminator to ensure that it
34   // never contains two terminator-like ops after merging.
35   if (!destination_block->empty())
36     rewriter.eraseOp(destination_block->getTerminator());
37 
38   destination_block->addArguments(
39       block_arguments_type,
40       SmallVector<Location>(block_arguments_type.size(),
41                             source_block->getParent()->getLoc()));
42   rewriter.mergeBlocks(source_block, destination_block,
43                        destination_block->getArguments());
44 }
45 
46 /// Convert the `tf.IfRegion` op to the `scf.if` op.
47 class ConvertIfRegionOp : public OpRewritePattern<IfRegionOp> {
48  public:
49   using OpRewritePattern<IfRegionOp>::OpRewritePattern;
50 
matchAndRewrite(IfRegionOp op,PatternRewriter & rewriter) const51   LogicalResult matchAndRewrite(IfRegionOp op,
52                                 PatternRewriter& rewriter) const override {
53     // Creates the `then` or `else` region of the `scf.if` op. Note that
54     // `tf_then_or_else_region` is the `then` or `else` region of the
55     // `tf.IfRegion` op and `scf_then_or_else_region` is the `then` or `else`
56     // region of the new `scf.if` op. Further, `tf_if_region_return_type` is the
57     // list of return types of the `tf.IfRegion` op.
58     auto createScfThenOrElse = [](Region& tf_then_or_else_region,
59                                   Region& scf_then_or_else_region,
60                                   TypeRange tf_if_region_return_type,
61                                   PatternRewriter& rewriter) {
62       // Move the first block of `tf_then_or_else_region` into the first block
63       // of `scf_then_or_else_region` and do not add any arguments to the block.
64       moveBlock(&tf_then_or_else_region.front(),
65                 &scf_then_or_else_region.front(), TypeRange(), rewriter);
66 
67       // Replace the current terminator (a `tf.Yield` op) with an `scf.yield`
68       // op. The input of the `scf.yield` op is a list of results of `tf.Cast`
69       // ops, each of which casts an operand of the current terminator to the
70       // corresponding result type of the `tf.IfRegion` op.
71       Operation* current_terminator =
72           scf_then_or_else_region.front().getTerminator();
73       rewriter.setInsertionPoint(current_terminator);
74       SmallVector<Value, 4> scf_yield_input;
75       for (auto it : llvm::zip(tf_if_region_return_type,
76                                current_terminator->getOperands())) {
77         scf_yield_input.push_back(rewriter.create<CastOp>(
78             current_terminator->getLoc(), std::get<0>(it), std::get<1>(it)));
79       }
80 
81       rewriter.replaceOpWithNewOp<scf::YieldOp>(current_terminator,
82                                                 scf_yield_input);
83     };
84 
85     Location loc = op.getLoc();
86 
87     // The condition of an `scf.if` op is a 1-bit signless integer. Whereas, the
88     // condition of the `tf.IfRegion` op is a 0-D tensor of 1-bit signless
89     // integers. Thus, we use the `tensor.extract` op to compute the condition
90     // of `scf.if` from that of `tf.IfRegion`.
91     auto scf_if_condition = rewriter.create<tensor::ExtractOp>(loc, op.cond());
92 
93     TypeRange tf_if_region_return_type = op.getResultTypes();
94 
95     // Create the `scf.if` op.
96     auto scf_if_op =
97         rewriter.create<scf::IfOp>(loc, tf_if_region_return_type,
98                                    scf_if_condition, /*withElseRegion=*/true);
99 
100     Region& then_region = op.then_branch();
101     Region& else_region = op.else_branch();
102 
103     // Create the `then` and `else` regions of the `scf.if` op.
104     createScfThenOrElse(then_region, scf_if_op.getThenRegion(),
105                         tf_if_region_return_type, rewriter);
106     createScfThenOrElse(else_region, scf_if_op.getElseRegion(),
107                         tf_if_region_return_type, rewriter);
108 
109     // Replace the `tf.IfRegion` op with the results of the `scf.if` op.
110     rewriter.replaceOp(op, scf_if_op.getResults());
111     return success();
112   }
113 };
114 
115 /// Convert the `tf.WhileRegion` op to the `scf.while` op.
116 class ConvertWhileRegionOp : public OpRewritePattern<WhileRegionOp> {
117  public:
118   using OpRewritePattern<WhileRegionOp>::OpRewritePattern;
119 
matchAndRewrite(WhileRegionOp op,PatternRewriter & rewriter) const120   LogicalResult matchAndRewrite(WhileRegionOp op,
121                                 PatternRewriter& rewriter) const override {
122     // Creates the `before` or `after` region of the `scf.while` op. Note that
123     // `tf_cond_or_body_region` is the `cond` or `body` region of the
124     // `tf.WhileRegion` op. `scf_before_or_after_region` is the `before` or
125     // `after` region of the new `scf.while` op. `scf_block_arguments_type` is
126     // the type of arguments that need to be in the first block of
127     // `scf_before_or_after_region`.
128     auto createScfCondOrBody =
129         [](Region& tf_cond_or_body_region, Region& scf_before_or_after_region,
130            TypeRange scf_block_arguments_type, PatternRewriter& rewriter) {
131           // Move the first block of `tf_cond_or_body_region` into the first
132           // block of `scf_before_or_after_region` and keep the later's
133           // arguments' type as `scf_block_arguments_type`.
134           moveBlock(&tf_cond_or_body_region.front(),
135                     &scf_before_or_after_region.front(),
136                     scf_block_arguments_type, rewriter);
137 
138           Operation* cond_or_body_terminator =
139               scf_before_or_after_region.front().getTerminator();
140           rewriter.setInsertionPoint(cond_or_body_terminator);
141           return cond_or_body_terminator;
142         };
143 
144     ValueRange opInput = op.input();
145     TypeRange scf_block_arguments_type = opInput.getType();
146 
147     // Create the `scf.while` op.
148     auto scf_while_op = rewriter.create<scf::WhileOp>(
149         op.getLoc(), op.getResultTypes(), opInput);
150 
151     // Create the `before` block of the `scf.while` op (with an `scf.condition`
152     // op as the terminator). Note that the arguments' type of this block is
153     // kept as `opInput`'s type. Note that the input of an `scf.condition` op is
154     // a 1-bit signless integer. But, the condition of the `tf.WhileRegion` op
155     // is a 0-D tensor of 1-bit signless integers. Thus, we use the
156     // `tensor.extract` op to compute the input of `scf.condition`.
157     rewriter.createBlock(&scf_while_op.getBefore());
158     Operation* cond_terminator =
159         createScfCondOrBody(op.cond(), scf_while_op.getBefore(),
160                             scf_block_arguments_type, rewriter);
161     auto scf_condition_input = rewriter.create<tensor::ExtractOp>(
162         cond_terminator->getLoc(), cond_terminator->getOperand(0));
163     rewriter.replaceOpWithNewOp<scf::ConditionOp>(
164         cond_terminator, scf_condition_input.getResult(),
165         scf_while_op.getBefore().front().getArguments());
166 
167     // Create the `after` block of the `scf.while` op (with an `scf.yield` op as
168     // the terminator). Note that the arguments' type of this block is kept as
169     // `opInput`'s type.
170     rewriter.createBlock(&scf_while_op.getAfter());
171     Operation* body_terminator = createScfCondOrBody(
172         op.body(), scf_while_op.getAfter(), scf_block_arguments_type, rewriter);
173     rewriter.replaceOpWithNewOp<scf::YieldOp>(body_terminator,
174                                               body_terminator->getOperands());
175 
176     // Replace the `tf.WhileRegion` op with the `scf.while` op.
177     rewriter.replaceOp(op, scf_while_op.getResults());
178 
179     return success();
180   }
181 };
182 
183 }  // end anonymous namespace
184 
populateTfControlFlowToScfPatterns(MLIRContext * context,RewritePatternSet * patterns)185 void populateTfControlFlowToScfPatterns(MLIRContext* context,
186                                         RewritePatternSet* patterns) {
187   patterns->add<ConvertIfRegionOp, ConvertWhileRegionOp>(context);
188 }
189 
190 struct ConvertTfControlFlowToScf
191     : public ConvertTfControlFlowToScfPassBase<ConvertTfControlFlowToScf> {
runOnOperationmlir::TF::ConvertTfControlFlowToScf192   void runOnOperation() override {
193     RewritePatternSet patterns(&getContext());
194     populateTfControlFlowToScfPatterns(&getContext(), &patterns);
195     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
196   }
197 };
198 
createConvertTfControlFlowToScfPass()199 std::unique_ptr<OperationPass<ModuleOp>> createConvertTfControlFlowToScfPass() {
200   return std::make_unique<ConvertTfControlFlowToScf>();
201 }
202 
203 }  // namespace TF
204 }  // end namespace mlir
205