/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering TensorFlow dialect's control flow to
// the XLA dialect.

#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>
#include <tuple>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes_detail.h"

using mlir::PassRegistration;

namespace mlir {
namespace mhlo {
namespace {
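// Pass that legalizes TF control flow ops (`tf.If`, `tf.Case`, `tf.While` and
// their region-based counterparts) to the corresponding MHLO control flow ops.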
class LegalizeTFControlFlow
    : public LegalizeTFControlFlowBase<LegalizeTFControlFlow> {
 public:
  void runOnOperation() override;
};
}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createLegalizeTFControlFlowPass() {
  return std::make_unique<LegalizeTFControlFlow>();
}

namespace {

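// Replaces all uses of each value in `replace` with a `mhlo.get_tuple_element`
// that extracts the element at the corresponding index from `tuple`.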
void Detuple(Value tuple, ValueRange replace, OpBuilder* builder) {
  // De-tuple the results of the XLA HLO control flow op.
  for (auto result_it : llvm::enumerate(replace)) {
    auto get_tuple_value = builder->create<mhlo::GetTupleElementOp>(
        result_it.value().getLoc(), tuple, result_it.index());
    result_it.value().replaceAllUsesWith(get_tuple_value);
  }
}

// Imports the source region into the destination region. The XLA if
// operation only supports one argument per branch. Therefore any branch that
// requires additional arguments must have their values tupled together. Then,
// to support multiple returns (as XLA only supports a single return value) the
// results of the if operation are tupled together.
void ImportXlaRegion(mlir::FuncOp func, Region* dest_region, Location loc,
                     bool tuple_return = true) {
  OpBuilder builder(dest_region);

  auto entry_block = builder.createBlock(dest_region);
  auto tuple_arg = entry_block->addArgument(
      builder.getTupleType(func.getType().getInputs()));
  llvm::SmallVector<Value, 4> detupled_args;
  detupled_args.reserve(func.getNumArguments());

  for (int64_t i = 0, s = func.getNumArguments(); i < s; i++) {
    auto extract = builder.create<GetTupleElementOp>(loc, tuple_arg, i);
    detupled_args.push_back(extract);
  }

  auto result = builder.create<CallOp>(loc, func, detupled_args).getResults();
  if (!tuple_return) {
    builder.create<mhlo::ReturnOp>(loc, result);
  } else {
    auto tuple_op = builder.create<TupleOp>(loc, result);
    builder.create<mhlo::ReturnOp>(loc, tuple_op.getResult());
  }
}

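// Lowers a functional `tf.If` op to `mhlo.if` by tupling the inputs, importing
// the then/else functions as branch regions, and de-tupling the results.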
void LowerIf(TF::IfOp op) {
  Location loc = op.getLoc();
  OpBuilder builder(op);

  // XLA prefers tuple arguments for control flow due to XLA not supporting
  // multiple return values.
  SmallVector<Value, 3> inputs(op.input());
  auto tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);

  // Create the new `mhlo.if` op with tuple inputs.
  auto result_type = builder.getTupleType(op.getResultTypes());
  auto if_op = builder.create<mhlo::IfOp>(loc, result_type, op.cond(),
                                          tuple_input, tuple_input);

  // Import the regions for both the true and false cases. These regions
  // must be updated to tuple the return results together and use the xla hlo
  // return op.
  ImportXlaRegion(op.then_function(), &if_op.true_branch(), loc);
  ImportXlaRegion(op.else_function(), &if_op.false_branch(), loc);

  // De-tuple the results of the `mhlo.if`.
  Detuple(if_op.getResult(), op.getResults(), &builder);
  op.erase();
}

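// Lowers a functional `tf.Case` op to `mhlo.case` by passing a replica of the
// tupled inputs to each branch region imported from its branch function.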
void LowerCase(TF::CaseOp op) {
  Location loc = op.getLoc();
  OpBuilder builder(op);

  // XLA requires one argument per branch, so we create a tuple of inputs to
  // pass to each branch.
  SmallVector<Value, 4> inputs(op.input());
  auto tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);

  // Create a replica of the input tuple for each branch.
  SmallVector<Value, 4> n_tuple_inputs(op.num_branches(), tuple_input);

  // Create the new `mhlo.case` op with tuple inputs.
  auto case_op =
      builder.create<mhlo::CaseOp>(loc, op.getResultTypes(), op.branch_index(),
                                   n_tuple_inputs, op.branches().size());

  // Import the regions for all branches.
  for (unsigned i = 0; i < op.num_branches(); ++i) {
    mlir::FuncOp branch_func = op.branch_function(i);
    ImportXlaRegion(branch_func, &case_op.branches()[i], loc,
                    /*tuple_return=*/false);
  }

  op.replaceAllUsesWith(case_op.getResults());
  op.erase();
}

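// Lowers a functional `tf.While` op to `mhlo.while` by tupling its inputs and
// importing the cond and body functions as regions.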
void LowerWhile(TF::WhileOp op) {
  Location loc = op.getLoc();
  OpBuilder builder(op);

  // XLA prefers tuple arguments for control flow due to XLA not supporting
  // multiple return values.
  SmallVector<Value, 3> inputs(op.input());
  builder.setInsertionPoint(op);
  Value tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);

  // Create the new `mhlo.while` op with tuple inputs.
  auto while_op = builder.create<mhlo::WhileOp>(
      loc, builder.getTupleType(op.getResultTypes()), tuple_input);

  // Import the regions for both the cond and body. These regions must be
  // updated to tuple the return results together and use the xla hlo return op.
  ImportXlaRegion(op.body_function(), &while_op.body(), loc);
  ImportXlaRegion(op.cond_function(), &while_op.cond(), loc,
                  /*tuple_return=*/false);

  // De-tuple the results of the `mhlo.while` if needed.
  if (while_op.getNumResults() == 1 && while_op.getType(0).isa<TupleType>())
    Detuple(while_op.getResult(0), op.getResults(), &builder);
  else
    op->replaceAllUsesWith(while_op);
  op.erase();
}

// Replaces all block arguments of a block with a single block arg of Tuple
// type `tuple_type`. The original block arguments are removed and remapped to
// get_tuple_element(tuple_arg, index).
void ReplaceBlockArgs(Block* block, Type tuple_type, OpBuilder* builder) {
  auto tuple_arg = block->addArgument(tuple_type);
  Detuple(tuple_arg, block->getArguments().drop_back(1), builder);
  for (int i = block->getNumArguments() - 2; i >= 0; --i)
    block->eraseArgument(i);
}

// Replaces uses of implicitly captured values with the tuple block argument.
// get_tuple_element ops are created to extract specific values. Values from
// the get_tuple_element ops are returned in the order of `implicit_inputs`.
llvm::SmallVector<Value, 4> ReplaceImplicitInputs(
    Block* block, int offset, ArrayRef<Value> implicit_inputs,
    OpBuilder* builder) {
  llvm::SmallVector<Value, 4> implicit_input_elements;
  implicit_input_elements.reserve(implicit_inputs.size());

  Region* region = block->getParent();
  assert(block->getNumArguments() == 1);

  BlockArgument tuple_arg = block->getArgument(0);
  for (auto& implicit_input : llvm::enumerate(implicit_inputs)) {
    Value implicit_input_value = implicit_input.value();
    auto get_tuple_element = builder->create<mhlo::GetTupleElementOp>(
        implicit_input_value.getLoc(), tuple_arg,
        implicit_input.index() + offset);
    implicit_input_elements.emplace_back(get_tuple_element.getResult());
    for (auto& use :
         llvm::make_early_inc_range(implicit_input_value.getUses())) {
      if (!region->isAncestor(use.getOwner()->getParentRegion())) continue;
      use.set(get_tuple_element.getResult());
    }
  }

  return implicit_input_elements;
}

// Finds and replaces implicitly captured value uses with tuple block argument.
// A tuple of implicitly captured values is also created and returned, for use
// as an operand to the associated mhlo control flow op.
Value TupleImplicitInputs(Region& region, Location loc, OpBuilder* builder) {
  llvm::SetVector<Value> implicit_inputs;
  getUsedValuesDefinedAbove(region, region, implicit_inputs);
  llvm::ArrayRef<Value> implicit_inputs_ref = implicit_inputs.getArrayRef();
  Value tuple_input = builder->create<mhlo::TupleOp>(loc, implicit_inputs_ref);
  Block& block = region.front();
  // `tf.CaseRegion`/`tf.IfRegion` are expected to have no block arguments and
  // instead all inputs used by their branch regions are implicitly captured
  // from above.
  assert(block.getNumArguments() == 0);
  block.addArgument(tuple_input.getType());
  builder->setInsertionPointToStart(&block);
  ReplaceImplicitInputs(&block, /*offset=*/0, implicit_inputs_ref, builder);
  return tuple_input;
}

// Replaces block terminator (tf.Yield) with `mhlo.return`. Additional results
// can be returned if `extra_results` is not empty. If `tuple_return` is
// set, a tuple of the return values will be set as the terminator operand.
void ReplaceTerminator(Block* block, ArrayRef<Value> extra_results,
                       OpBuilder* builder, bool tuple_return = true) {
  Operation* terminator = block->getTerminator();
  assert(isa<TF::YieldOp>(terminator));
  Location loc = terminator->getLoc();

  builder->setInsertionPoint(terminator);
  auto results = llvm::to_vector<4>(terminator->getOperands());
  results.append(extra_results.begin(), extra_results.end());
  if (tuple_return) {
    auto tuple_results = builder->create<mhlo::TupleOp>(loc, results);
    builder->create<mhlo::ReturnOp>(loc, tuple_results.getResult());
  } else {
    builder->create<mhlo::ReturnOp>(loc, results);
  }

  terminator->erase();
}

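// Lowers a `tf.IfRegion` op to `mhlo.if` by tupling each branch's implicitly
// captured inputs and transferring the branch regions to the new op.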
void LowerIfRegion(TF::IfRegionOp op) {
  Location loc = op.getLoc();
  OpBuilder builder(op);

  // Tuple implicit inputs per region and update terminators to return tuples.
  builder.setInsertionPoint(op);
  Value then_input = TupleImplicitInputs(op.then_branch(), loc, &builder);
  ReplaceTerminator(&op.then_branch().front(), /*extra_results=*/{}, &builder);

  builder.setInsertionPoint(op);
  Value else_input = TupleImplicitInputs(op.else_branch(), loc, &builder);
  ReplaceTerminator(&op.else_branch().front(), /*extra_results=*/{}, &builder);

  // Create the new `mhlo.if` op with tuple inputs and take ownership of regions
  // from `tf.IfRegion` op.
  builder.setInsertionPoint(op);
  auto result_type = builder.getTupleType(op.getResultTypes());
  auto if_op = builder.create<mhlo::IfOp>(loc, result_type, op.cond(),
                                          then_input, else_input);
  if_op.true_branch().takeBody(op.then_branch());
  if_op.false_branch().takeBody(op.else_branch());

  // De-tuple the results of the `mhlo.if`.
  Detuple(if_op.getResult(), op.getResults(), &builder);
  op.erase();
}

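// Lowers a `tf.CaseRegion` op to `mhlo.case` by tupling each branch's
// implicitly captured inputs and transferring the branch regions to the new op.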
void LowerCaseRegion(TF::CaseRegionOp op) {
  Location loc = op.getLoc();
  OpBuilder builder(op);

  llvm::SmallVector<Value, 4> branch_inputs;
  branch_inputs.reserve(op.branches().size());
  // Tuple implicit inputs per region and update terminators.
  for (Region& region : op.branches()) {
    builder.setInsertionPoint(op);
    Value branch_input = TupleImplicitInputs(region, loc, &builder);
    branch_inputs.emplace_back(branch_input);
    ReplaceTerminator(&region.front(), /*extra_results=*/{}, &builder,
                      /*tuple_return=*/false);
  }

  // Create the new `mhlo.case` op with tuple inputs and take ownership of
  // regions from `tf.CaseRegion` op.
  builder.setInsertionPoint(op);
  auto case_op =
      builder.create<mhlo::CaseOp>(loc, op.getResultTypes(), op.branch_index(),
                                   branch_inputs, branch_inputs.size());
  for (auto region : llvm::zip(case_op.branches(), op.branches()))
    std::get<0>(region).takeBody(std::get<1>(region));

  op.replaceAllUsesWith(case_op.getResults());
  op.erase();
}

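// Lowers a `tf.WhileRegion` op to `mhlo.while`. Explicit inputs and implicitly
// captured values are tupled together, and the cond and body regions are
// transferred over and rewritten to operate on the tuple argument.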
void LowerWhileRegion(TF::WhileRegionOp op) {
  Location loc = op.getLoc();
  OpBuilder builder(op);

  // XLA prefers tuple arguments for control flow due to XLA not supporting
  // multiple return values.
  SmallVector<Value, 3> inputs(op.input());
  const int inputs_size = inputs.size();
  llvm::SetVector<Value> implicit_inputs;
  getUsedValuesDefinedAbove(op.getOperation()->getRegions(), implicit_inputs);
  inputs.append(implicit_inputs.begin(), implicit_inputs.end());

  builder.setInsertionPoint(op);
  Value tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);

  // Create the new `mhlo.while` op with tuple inputs. Implicit inputs are also
  // returned.
  auto while_result_types = llvm::to_vector<4>(op.getResultTypes());
  while_result_types.reserve(while_result_types.size() +
                             implicit_inputs.size());
  for (const auto& implicit_input : implicit_inputs)
    while_result_types.emplace_back(implicit_input.getType());
  auto while_op = builder.create<mhlo::WhileOp>(
      loc, builder.getTupleType(while_result_types), tuple_input);

  // Rewrite cond and associated block arguments and terminator. Ownership of
  // cond region is transferred over from `tf.WhileRegion` to `mhlo.while`.
  Region& cond = while_op.cond();
  cond.takeBody(op.cond());
  Block& cond_block = cond.front();
  builder.setInsertionPointToStart(&cond_block);
  ReplaceBlockArgs(&cond_block, tuple_input.getType(), &builder);
  ReplaceImplicitInputs(&cond_block, inputs_size, implicit_inputs.getArrayRef(),
                        &builder);
  // Cond always returns a single result of bool type.
  ReplaceTerminator(&cond_block, /*extra_results=*/{}, &builder,
                    /*tuple_return=*/false);

  // Rewrite body and associated block arguments and terminator. Ownership of
  // body region is transferred over from `tf.WhileRegion` to `mhlo.while`.
  Region& body = while_op.body();
  body.takeBody(op.body());
  Block& body_block = body.front();
  builder.setInsertionPointToStart(&body_block);
  ReplaceBlockArgs(&body_block, tuple_input.getType(), &builder);
  // Capture implicit inputs that were added to the tuple block argument. These
  // are to be returned by the body in addition to explicit inputs.
  auto implicit_input_elements = ReplaceImplicitInputs(
      &body_block, inputs_size, implicit_inputs.getArrayRef(), &builder);
  ReplaceTerminator(&body_block, implicit_input_elements, &builder);

  // De-tuple the results of the `mhlo.while`.
  builder.setInsertionPoint(op);
  if (while_op.getNumResults() == 1 && while_op.getType(0).isa<TupleType>())
    Detuple(while_op.getResult(0), op.getResults(), &builder);
  else
    op->replaceAllUsesWith(while_op);
  op.erase();
}
}  // namespace

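// Walks the module and lowers each supported TF control flow op to its MHLO
// equivalent.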
void LegalizeTFControlFlow::runOnOperation() {
  getOperation().walk([&](Operation* op) {
    if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
      LowerWhile(while_op);
      return;
    }
    if (auto while_region_op = dyn_cast<TF::WhileRegionOp>(op)) {
      LowerWhileRegion(while_region_op);
      return;
    }
    if (auto if_op = dyn_cast<TF::IfOp>(op)) {
      LowerIf(if_op);
      return;
    }
    if (auto if_region_op = dyn_cast<TF::IfRegionOp>(op)) {
      LowerIfRegion(if_region_op);
      return;
    }
    if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
      LowerCase(case_op);
      return;
    }
    if (auto case_region_op = dyn_cast<TF::CaseRegionOp>(op)) {
      LowerCaseRegion(case_region_op);
      return;
    }
  });
}
}  // namespace mhlo
}  // namespace mlir