• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This transformation pass transforms functional control flow operations in the
17 // TensorFlow dialect to MLIR Control Flow Graph (CFG) form.
18 
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
21 #include "mlir/IR/Attributes.h"  // from @llvm-project
22 #include "mlir/IR/Builders.h"  // from @llvm-project
23 #include "mlir/IR/Operation.h"  // from @llvm-project
24 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
25 #include "mlir/IR/Value.h"  // from @llvm-project
26 #include "mlir/Pass/Pass.h"  // from @llvm-project
27 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
30 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
31 
32 namespace mlir {
33 namespace TF {
34 
35 namespace {
36 
37 struct FunctionalControlFlowToCFG
38     : public PassWrapper<FunctionalControlFlowToCFG, FunctionPass> {
getDependentDialectsmlir::TF::__anone55cbe7d0111::FunctionalControlFlowToCFG39   void getDependentDialects(mlir::DialectRegistry& registry) const override {
40     registry.insert<tensor::TensorDialect>();
41   }
42 
getArgumentmlir::TF::__anone55cbe7d0111::FunctionalControlFlowToCFG43   StringRef getArgument() const final {
44     return "tf-functional-control-flow-to-cfg";
45   }
46 
getDescriptionmlir::TF::__anone55cbe7d0111::FunctionalControlFlowToCFG47   StringRef getDescription() const final {
48     return "Transform functional control flow Ops to MLIR Control Form Graph "
49            "(CFG) form";
50   }
51 
52   void runOnFunction() override;
53 };
54 
55 // Lowers a general tensor argument that is used as a condition to a functional
56 // control flow op into an i1 value.
LowerCondition(Location loc,Value value,OpBuilder * builder)57 static Value LowerCondition(Location loc, Value value, OpBuilder* builder) {
58   auto zero_d = builder->create<ToBoolOp>(loc, value);
59   auto scalar = builder->create<tensor::ExtractOp>(loc, zero_d);
60   return scalar.getResult();
61 }
62 
63 // Calls the function `fn` with arguments provided by the given function and
64 // return the CallOp. Arguments are cast to the required type before calling
65 // the function.
66 //
67 // Requires the function to provide arguments for each of the `fn` operands
68 // that is compatible for tensor cast.
CallFn(Location loc,const std::function<Value (int)> & get_arg,FuncOp fn,OpBuilder * builder)69 static Operation* CallFn(Location loc, const std::function<Value(int)>& get_arg,
70                          FuncOp fn, OpBuilder* builder) {
71   FunctionType fn_type = fn.getType();
72   llvm::SmallVector<Value, 4> operands;
73   int num_operands = fn_type.getNumInputs();
74   operands.reserve(num_operands);
75   for (int i = 0; i < num_operands; ++i) {
76     Value val = get_arg(i);
77     Type expected = fn_type.getInput(i);
78     if (val.getType() != expected) {
79       val =
80           builder->create<TF::CastOp>(loc, expected, val,
81                                       /*Truncate=*/builder->getBoolAttr(false));
82     }
83     operands.push_back(val);
84   }
85   return builder->create<CallOp>(loc, fn, operands).getOperation();
86 }
87 
88 // Prepares for jump to the given block by introducing necessary tensor_cast
89 // operations and returning Values of types required by the block.
90 //
91 // Requires the function to provide values for each of the block arguments and
92 // they should be pair-wise compatible for tensor cast.
PrepareValsForJump(Location loc,const std::function<Value (int)> & get_val,Block * block,OpBuilder * builder)93 static llvm::SmallVector<Value, 4> PrepareValsForJump(
94     Location loc, const std::function<Value(int)>& get_val, Block* block,
95     OpBuilder* builder) {
96   llvm::SmallVector<Value, 4> result;
97   int num_vals = block->getNumArguments();
98   result.reserve(num_vals);
99   for (int i = 0; i < num_vals; ++i) {
100     Value val = get_val(i);
101     Type expected = block->getArgument(i).getType();
102     if (val.getType() != expected) {
103       val =
104           builder->create<TF::CastOp>(loc, expected, val,
105                                       /*Truncate=*/builder->getBoolAttr(false));
106     }
107     result.push_back(val);
108   }
109   return result;
110 }
111 
112 // Jumps to the given block with arguments provided by the function. Arguments
113 // are cast to the required type before the jump.
114 //
115 // Requires the function to provide values for each of the block arguments and
116 // they should be pair-wise compatible for tensor cast.
JumpToBlock(Location loc,const std::function<Value (int)> & get_arg,Block * block,OpBuilder * builder)117 static void JumpToBlock(Location loc, const std::function<Value(int)>& get_arg,
118                         Block* block, OpBuilder* builder) {
119   auto operands = PrepareValsForJump(loc, get_arg, block, builder);
120   builder->create<BranchOp>(loc, block, operands);
121 }
122 
123 // Replaces all uses of the operation results in this block with block
124 // arguments.
125 //
126 // Requires that the block has same number of arguments as number of results of
127 // the operation and either they have same types or are more generic types and
128 // it is possible to cast them to results' types.
ReplaceOpResultWithBlockArgs(Location loc,Operation * op,Block * block,OpBuilder * builder)129 static void ReplaceOpResultWithBlockArgs(Location loc, Operation* op,
130                                          Block* block, OpBuilder* builder) {
131   assert(op->getNumResults() == block->getNumArguments());
132   for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
133     Value arg = block->getArgument(i);
134     Value result = op->getResult(i);
135     if (arg.getType() != result.getType()) {
136       arg =
137           builder->create<TF::CastOp>(loc, result.getType(), arg,
138                                       /*Truncate=*/builder->getBoolAttr(false));
139     }
140     result.replaceAllUsesWith(arg);
141   }
142 }
143 
144 // Given a functional IfOp, transforms the enclosing code to eliminate it
145 // completely from the IR, breaking it into operations to evaluate the condition
146 // as a bool, plus some branches.
LowerIfOp(IfOp op)147 static LogicalResult LowerIfOp(IfOp op) {
148   Operation* op_inst = op.getOperation();
149   Location loc = op_inst->getLoc();
150 
151   OpBuilder builder(op_inst);
152 
153   // Lower the condition to a boolean value (i1).
154   Value cond_i1 = LowerCondition(loc, op.cond(), &builder);
155   if (!cond_i1) return failure();
156 
157   // Split the basic block before the 'if'.  The new dest will be our merge
158   // point.
159   Block* orig_block = op_inst->getBlock();
160   Block* merge_block = orig_block->splitBlock(op);
161 
162   // Add the block arguments to the merge point, and replace all uses of the
163   // original operation results with them.
164   for (Value value : op_inst->getResults())
165     merge_block->addArgument(value.getType());
166   ReplaceOpResultWithBlockArgs(loc, op_inst, merge_block, &builder);
167 
168   // Get arguments to the branches after dropping the condition which is the
169   // first operand.
170   auto get_operand = [&](int i) { return op_inst->getOperand(i + 1); };
171 
172   // Set up the 'then' block.
173   Block* then_block = builder.createBlock(merge_block);
174   Operation* call_op = CallFn(loc, get_operand, op.then_function(), &builder);
175 
176   auto get_then_result = [&](int i) { return call_op->getResult(i); };
177   JumpToBlock(loc, get_then_result, merge_block, &builder);
178 
179   // Set up the 'else' block.
180   Block* else_block = builder.createBlock(merge_block);
181   call_op = CallFn(loc, get_operand, op.else_function(), &builder);
182 
183   auto get_else_result = [&](int i) { return call_op->getResult(i); };
184   JumpToBlock(loc, get_else_result, merge_block, &builder);
185 
186   // Now that we have the then and else blocks, replace the terminator of the
187   // orig_block with a conditional branch.
188   builder.setInsertionPointToEnd(orig_block);
189   builder.create<CondBranchOp>(loc, cond_i1, then_block,
190                                llvm::ArrayRef<Value>(), else_block,
191                                llvm::ArrayRef<Value>());
192 
193   // Finally, delete the op in question.
194   op_inst->erase();
195   return success();
196 }
197 
198 // Given a functional WhileOp, transforms the enclosing code to eliminate it
199 // completely from the IR, breaking it into operations to execute the loop body
200 // repeatedly while the loop condition is true.
LowerWhileOp(WhileOp op)201 static LogicalResult LowerWhileOp(WhileOp op) {
202   Operation* op_inst = op.getOperation();
203   Location loc = op_inst->getLoc();
204 
205   OpBuilder builder(op_inst);
206 
207   auto cond_fn = op.cond_function();
208   auto body_fn = op.body_function();
209 
210   // Split the block containing the While op into two blocks.  One containing
211   // operations before the While op and other containing the rest.  Create two
212   // new blocks to call condition and body functions.
213   //
214   // The final control flow graph would be as follows:
215   //
216   // ...
217   // orig_block_head(...):
218   //   ...
219   //   br cond_block(...)
220   // cond_block(...):
221   //   %A = call @cond(...)
222   //   cond br %A, body_block(...), orig_block_tail(...)
223   // body_block(...):
224   //   %B = call @body(...)
225   //   br cond_block(...)
226   // orig_block_tail(...):
227   //   ...
228   //
229   Block* orig_block_head = op_inst->getBlock();
230   Block* orig_block_tail = orig_block_head->splitBlock(op);
231   Block* cond_block = builder.createBlock(orig_block_tail);
232   Block* body_block = builder.createBlock(orig_block_tail);
233 
234   // Set argument types for the cond_block to be same as the types of the
235   // condition function and argument types for the other two blocks to be same
236   // as the input types of the body function. Note that it is always possible
237   // for body_block and orig_block_tail to have arguments of the same types as
238   // they have exactly one call-site and they are sharing the operands.
239   for (Type type : cond_fn.getType().getInputs()) {
240     cond_block->addArgument(type);
241   }
242   for (Type type : body_fn.getType().getInputs()) {
243     body_block->addArgument(type);
244     orig_block_tail->addArgument(type);
245   }
246 
247   auto get_operand = [&](int i) { return op_inst->getOperand(i); };
248 
249   // Unconditionally branch from the original block to the block containing the
250   // condition.
251   builder.setInsertionPointToEnd(orig_block_head);
252   JumpToBlock(loc, get_operand, cond_block, &builder);
253 
254   // Call condition function in the condition block and then branch to the body
255   // block or remainder of the original block depending on condition function
256   // result.
257   builder.setInsertionPointToEnd(cond_block);
258 
259   auto get_cond_arg = [&](int i) { return cond_block->getArgument(i); };
260   Operation* cond_call_op = CallFn(loc, get_cond_arg, cond_fn, &builder);
261 
262   assert(cond_call_op->getNumResults() == 1);
263   Value condition = LowerCondition(loc, cond_call_op->getResult(0), &builder);
264   auto br_operands =
265       PrepareValsForJump(loc, get_cond_arg, body_block, &builder);
266   builder.create<CondBranchOp>(loc, condition, body_block, br_operands,
267                                orig_block_tail, br_operands);
268 
269   // Call body function in the body block and then unconditionally branch back
270   // to the condition block.
271   builder.setInsertionPointToEnd(body_block);
272   auto get_body_arg = [&](int i) { return body_block->getArgument(i); };
273   Operation* body_call_op = CallFn(loc, get_body_arg, body_fn, &builder);
274 
275   auto get_body_result = [&](int i) { return body_call_op->getResult(i); };
276   JumpToBlock(loc, get_body_result, cond_block, &builder);
277 
278   // Replace use of the while loop results with block inputs in the remainder of
279   // the original block and then delete the original While operation.
280   builder.setInsertionPoint(&orig_block_tail->front());
281   ReplaceOpResultWithBlockArgs(loc, op_inst, orig_block_tail, &builder);
282   op_inst->erase();
283 
284   return success();
285 }
286 
runOnFunction()287 void FunctionalControlFlowToCFG::runOnFunction() {
288   // Scan the function looking for these ops.
289   for (Block& block : getFunction()) {
290     for (Operation& op : block) {
291       // If the operation is one of the control flow ops we know, lower it.
292       // If we lower an operation, then the current basic block will be split,
293       // and the operation will be removed, so we should continue looking at
294       // subsequent blocks.
295       //
296       // TODO: Use PatternRewriter to eliminate these function control flow ops.
297 
298       if (IfOp if_op = llvm::dyn_cast<IfOp>(op)) {
299         if (failed(LowerIfOp(if_op))) {
300           return signalPassFailure();
301         }
302         break;
303       }
304       if (WhileOp while_op = llvm::dyn_cast<WhileOp>(op)) {
305         if (failed(LowerWhileOp(while_op))) {
306           return signalPassFailure();
307         }
308         break;
309       }
310     }
311   }
312 }
313 
314 }  // namespace
315 
CreateTFFunctionalControlFlowToCFG()316 std::unique_ptr<OperationPass<FuncOp>> CreateTFFunctionalControlFlowToCFG() {
317   return std::make_unique<FunctionalControlFlowToCFG>();
318 }
319 
320 static PassRegistration<FunctionalControlFlowToCFG> pass;
321 
322 }  // namespace TF
323 }  // namespace mlir
324