• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This transformation pass transforms functional control flow operations in the
17 // TensorFlow dialect to MLIR Control Flow Graph (CFG) form.
18 
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
21 #include "mlir/IR/Attributes.h"  // from @llvm-project
22 #include "mlir/IR/Builders.h"  // from @llvm-project
23 #include "mlir/IR/Operation.h"  // from @llvm-project
24 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
25 #include "mlir/IR/Value.h"  // from @llvm-project
26 #include "mlir/Pass/Pass.h"  // from @llvm-project
27 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
30 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
31 
32 namespace mlir {
33 namespace TF {
34 
35 namespace {
36 
37 struct FunctionalControlFlowToCFG
38     : public PassWrapper<FunctionalControlFlowToCFG, FunctionPass> {
39   void runOnFunction() override;
40 };
41 
42 // Lowers a general tensor argument that is used as a condition to a functional
43 // control flow op into an i1 value.
LowerCondition(Location loc,Value value,OpBuilder * builder)44 static Value LowerCondition(Location loc, Value value, OpBuilder* builder) {
45   auto zero_d = builder->create<ToBoolOp>(loc, value);
46   auto scalar = builder->create<tensor::ExtractOp>(loc, zero_d);
47   return scalar.getResult();
48 }
49 
50 // Calls the function `fn` with arguments provided by the given function and
51 // return the CallOp. Arguments are cast to the required type before calling
52 // the function.
53 //
54 // Requires the function to provide arguments for each of the `fn` operands
55 // that is compatible for tensor cast.
CallFn(Location loc,const std::function<Value (int)> & get_arg,FuncOp fn,OpBuilder * builder)56 static Operation* CallFn(Location loc, const std::function<Value(int)>& get_arg,
57                          FuncOp fn, OpBuilder* builder) {
58   FunctionType fn_type = fn.getType();
59   llvm::SmallVector<Value, 4> operands;
60   int num_operands = fn_type.getNumInputs();
61   operands.reserve(num_operands);
62   for (int i = 0; i < num_operands; ++i) {
63     Value val = get_arg(i);
64     Type expected = fn_type.getInput(i);
65     if (val.getType() != expected) {
66       val =
67           builder->create<TF::CastOp>(loc, expected, val,
68                                       /*Truncate=*/builder->getBoolAttr(false));
69     }
70     operands.push_back(val);
71   }
72   return builder->create<CallOp>(loc, fn, operands).getOperation();
73 }
74 
75 // Prepares for jump to the given block by introducing necessary tensor_cast
76 // operations and returning Values of types required by the block.
77 //
78 // Requires the function to provide values for each of the block arguments and
79 // they should be pair-wise compatible for tensor cast.
PrepareValsForJump(Location loc,const std::function<Value (int)> & get_val,Block * block,OpBuilder * builder)80 static llvm::SmallVector<Value, 4> PrepareValsForJump(
81     Location loc, const std::function<Value(int)>& get_val, Block* block,
82     OpBuilder* builder) {
83   llvm::SmallVector<Value, 4> result;
84   int num_vals = block->getNumArguments();
85   result.reserve(num_vals);
86   for (int i = 0; i < num_vals; ++i) {
87     Value val = get_val(i);
88     Type expected = block->getArgument(i).getType();
89     if (val.getType() != expected) {
90       val =
91           builder->create<TF::CastOp>(loc, expected, val,
92                                       /*Truncate=*/builder->getBoolAttr(false));
93     }
94     result.push_back(val);
95   }
96   return result;
97 }
98 
99 // Jumps to the given block with arguments provided by the function. Arguments
100 // are cast to the required type before the jump.
101 //
102 // Requires the function to provide values for each of the block arguments and
103 // they should be pair-wise compatible for tensor cast.
JumpToBlock(Location loc,const std::function<Value (int)> & get_arg,Block * block,OpBuilder * builder)104 static void JumpToBlock(Location loc, const std::function<Value(int)>& get_arg,
105                         Block* block, OpBuilder* builder) {
106   auto operands = PrepareValsForJump(loc, get_arg, block, builder);
107   builder->create<BranchOp>(loc, block, operands);
108 }
109 
110 // Replaces all uses of the operation results in this block with block
111 // arguments.
112 //
113 // Requires that the block has same number of arguments as number of results of
114 // the operation and either they have same types or are more generic types and
115 // it is possible to cast them to results' types.
ReplaceOpResultWithBlockArgs(Location loc,Operation * op,Block * block,OpBuilder * builder)116 static void ReplaceOpResultWithBlockArgs(Location loc, Operation* op,
117                                          Block* block, OpBuilder* builder) {
118   assert(op->getNumResults() == block->getNumArguments());
119   for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
120     Value arg = block->getArgument(i);
121     Value result = op->getResult(i);
122     if (arg.getType() != result.getType()) {
123       arg =
124           builder->create<TF::CastOp>(loc, result.getType(), arg,
125                                       /*Truncate=*/builder->getBoolAttr(false));
126     }
127     result.replaceAllUsesWith(arg);
128   }
129 }
130 
131 // Given a functional IfOp, transforms the enclosing code to eliminate it
132 // completely from the IR, breaking it into operations to evaluate the condition
133 // as a bool, plus some branches.
LowerIfOp(IfOp op)134 static LogicalResult LowerIfOp(IfOp op) {
135   Operation* op_inst = op.getOperation();
136   Location loc = op_inst->getLoc();
137 
138   OpBuilder builder(op_inst);
139 
140   // Lower the condition to a boolean value (i1).
141   Value cond_i1 = LowerCondition(loc, op.cond(), &builder);
142   if (!cond_i1) return failure();
143 
144   // Split the basic block before the 'if'.  The new dest will be our merge
145   // point.
146   Block* orig_block = op_inst->getBlock();
147   Block* merge_block = orig_block->splitBlock(op);
148 
149   // Add the block arguments to the merge point, and replace all uses of the
150   // original operation results with them.
151   for (Value value : op_inst->getResults())
152     merge_block->addArgument(value.getType());
153   ReplaceOpResultWithBlockArgs(loc, op_inst, merge_block, &builder);
154 
155   // Get arguments to the branches after dropping the condition which is the
156   // first operand.
157   auto get_operand = [&](int i) { return op_inst->getOperand(i + 1); };
158 
159   // Set up the 'then' block.
160   Block* then_block = builder.createBlock(merge_block);
161   Operation* call_op = CallFn(loc, get_operand, op.then_function(), &builder);
162 
163   auto get_then_result = [&](int i) { return call_op->getResult(i); };
164   JumpToBlock(loc, get_then_result, merge_block, &builder);
165 
166   // Set up the 'else' block.
167   Block* else_block = builder.createBlock(merge_block);
168   call_op = CallFn(loc, get_operand, op.else_function(), &builder);
169 
170   auto get_else_result = [&](int i) { return call_op->getResult(i); };
171   JumpToBlock(loc, get_else_result, merge_block, &builder);
172 
173   // Now that we have the then and else blocks, replace the terminator of the
174   // orig_block with a conditional branch.
175   builder.setInsertionPointToEnd(orig_block);
176   builder.create<CondBranchOp>(loc, cond_i1, then_block,
177                                llvm::ArrayRef<Value>(), else_block,
178                                llvm::ArrayRef<Value>());
179 
180   // Finally, delete the op in question.
181   op_inst->erase();
182   return success();
183 }
184 
185 // Given a functional WhileOp, transforms the enclosing code to eliminate it
186 // completely from the IR, breaking it into operations to execute the loop body
187 // repeatedly while the loop condition is true.
LowerWhileOp(WhileOp op)188 static LogicalResult LowerWhileOp(WhileOp op) {
189   Operation* op_inst = op.getOperation();
190   Location loc = op_inst->getLoc();
191 
192   OpBuilder builder(op_inst);
193 
194   auto cond_fn = op.cond_function();
195   auto body_fn = op.body_function();
196 
197   // Split the block containing the While op into two blocks.  One containing
198   // operations before the While op and other containing the rest.  Create two
199   // new blocks to call condition and body functions.
200   //
201   // The final control flow graph would be as follows:
202   //
203   // ...
204   // orig_block_head(...):
205   //   ...
206   //   br cond_block(...)
207   // cond_block(...):
208   //   %A = call @cond(...)
209   //   cond br %A, body_block(...), orig_block_tail(...)
210   // body_block(...):
211   //   %B = call @body(...)
212   //   br cond_block(...)
213   // orig_block_tail(...):
214   //   ...
215   //
216   Block* orig_block_head = op_inst->getBlock();
217   Block* orig_block_tail = orig_block_head->splitBlock(op);
218   Block* cond_block = builder.createBlock(orig_block_tail);
219   Block* body_block = builder.createBlock(orig_block_tail);
220 
221   // Set argument types for the cond_block to be same as the types of the
222   // condition function and argument types for the other two blocks to be same
223   // as the input types of the body function. Note that it is always possible
224   // for body_block and orig_block_tail to have arguments of the same types as
225   // they have exactly one call-site and they are sharing the operands.
226   for (Type type : cond_fn.getType().getInputs()) {
227     cond_block->addArgument(type);
228   }
229   for (Type type : body_fn.getType().getInputs()) {
230     body_block->addArgument(type);
231     orig_block_tail->addArgument(type);
232   }
233 
234   auto get_operand = [&](int i) { return op_inst->getOperand(i); };
235 
236   // Unconditionally branch from the original block to the block containing the
237   // condition.
238   builder.setInsertionPointToEnd(orig_block_head);
239   JumpToBlock(loc, get_operand, cond_block, &builder);
240 
241   // Call condition function in the condition block and then branch to the body
242   // block or remainder of the original block depending on condition function
243   // result.
244   builder.setInsertionPointToEnd(cond_block);
245 
246   auto get_cond_arg = [&](int i) { return cond_block->getArgument(i); };
247   Operation* cond_call_op = CallFn(loc, get_cond_arg, cond_fn, &builder);
248 
249   assert(cond_call_op->getNumResults() == 1);
250   Value condition = LowerCondition(loc, cond_call_op->getResult(0), &builder);
251   auto br_operands =
252       PrepareValsForJump(loc, get_cond_arg, body_block, &builder);
253   builder.create<CondBranchOp>(loc, condition, body_block, br_operands,
254                                orig_block_tail, br_operands);
255 
256   // Call body function in the body block and then unconditionally branch back
257   // to the condition block.
258   builder.setInsertionPointToEnd(body_block);
259   auto get_body_arg = [&](int i) { return body_block->getArgument(i); };
260   Operation* body_call_op = CallFn(loc, get_body_arg, body_fn, &builder);
261 
262   auto get_body_result = [&](int i) { return body_call_op->getResult(i); };
263   JumpToBlock(loc, get_body_result, cond_block, &builder);
264 
265   // Replace use of the while loop results with block inputs in the remainder of
266   // the original block and then delete the original While operation.
267   builder.setInsertionPoint(&orig_block_tail->front());
268   ReplaceOpResultWithBlockArgs(loc, op_inst, orig_block_tail, &builder);
269   op_inst->erase();
270 
271   return success();
272 }
273 
runOnFunction()274 void FunctionalControlFlowToCFG::runOnFunction() {
275   // Scan the function looking for these ops.
276   for (Block& block : getFunction()) {
277     for (Operation& op : block) {
278       // If the operation is one of the control flow ops we know, lower it.
279       // If we lower an operation, then the current basic block will be split,
280       // and the operation will be removed, so we should continue looking at
281       // subsequent blocks.
282       //
283       // TODO: Use PatternRewriter to eliminate these function control flow ops.
284 
285       if (IfOp if_op = llvm::dyn_cast<IfOp>(op)) {
286         if (failed(LowerIfOp(if_op))) {
287           return signalPassFailure();
288         }
289         break;
290       }
291       if (WhileOp while_op = llvm::dyn_cast<WhileOp>(op)) {
292         if (failed(LowerWhileOp(while_op))) {
293           return signalPassFailure();
294         }
295         break;
296       }
297     }
298   }
299 }
300 
301 }  // namespace
302 
CreateTFFunctionalControlFlowToCFG()303 std::unique_ptr<OperationPass<FuncOp>> CreateTFFunctionalControlFlowToCFG() {
304   return std::make_unique<FunctionalControlFlowToCFG>();
305 }
306 
307 static PassRegistration<FunctionalControlFlowToCFG> pass(
308     "tf-functional-control-flow-to-cfg",
309     "Transform functional control flow Ops to MLIR Control Form Graph "
310     "(CFG) form");
311 
312 }  // namespace TF
313 }  // namespace mlir
314