• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for lowering TensorFlow dialect's control flow to
17 // the XLA dialect.
18 
19 #include <cstddef>
20 #include <cstdint>
21 #include <iterator>
22 #include <numeric>
23 #include <tuple>
24 
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
31 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
32 #include "mlir/IR/Operation.h"  // from @llvm-project
33 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
34 #include "mlir/IR/Types.h"  // from @llvm-project
35 #include "mlir/Pass/Pass.h"  // from @llvm-project
36 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
37 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
38 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
40 
41 using mlir::PassRegistration;
42 
43 namespace mlir {
44 namespace mhlo {
45 namespace {
// Module pass that lowers TensorFlow dialect control flow ops (`tf.If`,
// `tf.Case`, `tf.While` and their region-based variants) to the corresponding
// mhlo (XLA HLO) control flow ops.
class LegalizeTFControlFlow
    : public PassWrapper<LegalizeTFControlFlow, OperationPass<ModuleOp>> {
 public:
  void runOnOperation() override;
};
51 }  // namespace
52 
53 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createLegalizeTFControlFlowPass()54 createLegalizeTFControlFlowPass() {
55   return std::make_unique<LegalizeTFControlFlow>();
56 }
57 
58 namespace {
59 
Detuple(Value tuple,ValueRange replace,OpBuilder * builder)60 void Detuple(Value tuple, ValueRange replace, OpBuilder* builder) {
61   // De-tuple the results of the xla hlo if result.
62   for (auto result_it : llvm::enumerate(replace)) {
63     auto get_tuple_value = builder->create<mhlo::GetTupleElementOp>(
64         result_it.value().getLoc(), tuple, result_it.index());
65     result_it.value().replaceAllUsesWith(get_tuple_value);
66   }
67 }
68 
69 // Imports the source region into the destination region. The XLA if
70 // operation only supports one argument per branch. Therefore any branch that
71 // requires additional arguments requires their values be tupled together. Then,
72 // to support multiple returns (as XLA only supports a single return value) the
73 // results of the if operation are tupled together.
ImportXlaRegion(mlir::FuncOp func,Region * dest_region,Location loc,bool tuple_return=true)74 void ImportXlaRegion(mlir::FuncOp func, Region* dest_region, Location loc,
75                      bool tuple_return = true) {
76   OpBuilder builder(dest_region);
77 
78   auto entry_block = builder.createBlock(dest_region);
79   auto tuple_arg = entry_block->addArgument(
80       builder.getTupleType(func.getType().getInputs()));
81   llvm::SmallVector<Value, 4> detupled_args;
82   detupled_args.reserve(func.getNumArguments());
83 
84   for (int64_t i = 0, s = func.getNumArguments(); i < s; i++) {
85     auto extract = builder.create<GetTupleElementOp>(loc, tuple_arg, i);
86     detupled_args.push_back(extract);
87   }
88 
89   auto result = builder.create<CallOp>(loc, func, detupled_args).getResults();
90   if (!tuple_return) {
91     builder.create<mhlo::ReturnOp>(loc, result);
92   } else {
93     auto tuple_op = builder.create<TupleOp>(loc, result);
94     builder.create<mhlo::ReturnOp>(loc, tuple_op.getResult());
95   }
96 }
97 
// Lowers a functional `tf.If` to `mhlo.if`. The then/else functions referenced
// by the op are imported as the new op's branch regions, and the op's results
// are rewired to get_tuple_element's of the new op's tuple result.
void LowerIf(TF::IfOp op) {
  Location loc = op.getLoc();
  OpBuilder builder(op);

  // XLA prefers tuple arguments for control flow due to XLA not supporting
  // multiple return values.
  SmallVector<Value, 3> inputs(op.input());
  auto tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);

  // Create the new `mhlo.if` op with tuple inputs. Both branches receive the
  // same packed operand.
  auto result_type = builder.getTupleType(op.getResultTypes());
  auto if_op = builder.create<mhlo::IfOp>(loc, result_type, op.cond(),
                                          tuple_input, tuple_input);

  // Import the regions for both the true and false cases. These regions
  // must be updated to tuple the return results together and use the xla hlo
  // return op.
  ImportXlaRegion(op.then_function(), &if_op.true_branch(), loc);
  ImportXlaRegion(op.else_function(), &if_op.false_branch(), loc);

  // De-tuple the results of the `mhlo.if`.
  Detuple(if_op.getResult(), op.getResults(), &builder);
  op.erase();
}
122 
LowerCase(TF::CaseOp op)123 void LowerCase(TF::CaseOp op) {
124   Location loc = op.getLoc();
125   OpBuilder builder(op);
126 
127   // XLA requires one argument per branch so we create a tuple of inputs to pass
128   // to each branch.
129   SmallVector<Value, 4> inputs(op.input());
130   auto tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);
131 
132   // Create replica of input tuple for each branch
133   SmallVector<Value, 4> n_tuple_inputs(op.num_branches(), tuple_input);
134 
135   // Create the new `mhlo.case` op with tuple inputs.
136   auto case_op =
137       builder.create<mhlo::CaseOp>(loc, op.getResultTypes(), op.branch_index(),
138                                    n_tuple_inputs, op.branches().size());
139 
140   // Import the regions for all branches.
141   for (unsigned i = 0; i < op.num_branches(); ++i) {
142     mlir::FuncOp branch_func = op.branch_function(i);
143     ImportXlaRegion(branch_func, &case_op.branches()[i], loc,
144                     /*tuple_return=*/false);
145   }
146 
147   op.replaceAllUsesWith(case_op.getResults());
148   op.erase();
149 }
150 
LowerWhile(TF::WhileOp op)151 void LowerWhile(TF::WhileOp op) {
152   Location loc = op.getLoc();
153   OpBuilder builder(op);
154 
155   // XLA prefers tuple arguments for control flow due to XLA not supporting
156   // multiple return values.
157   SmallVector<Value, 3> inputs(op.input());
158   builder.setInsertionPoint(op);
159   Value tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);
160 
161   // Create the new `mhlo.while` op with tuple inputs.
162   auto while_op = builder.create<mhlo::WhileOp>(
163       loc, builder.getTupleType(op.getResultTypes()), tuple_input);
164 
165   // Import the regions for both the cond and body. These regions must be
166   // updated to tuple the return results together and use the xla hlo return op.
167   ImportXlaRegion(op.body_function(), &while_op.body(), loc);
168   ImportXlaRegion(op.cond_function(), &while_op.cond(), loc,
169                   /*tuple_return=*/false);
170 
171   // De-tuple the results of the `mhlo.while`.
172   Detuple(while_op.getResult(), op.getResults(), &builder);
173   op.erase();
174 }
175 
// Replaces all block arguments of a block with a single block arg of Tuple
// type `tuple_type`. Single block arguments are removed and remapped to
// get_tuple_element(tuple_arg, index).
void ReplaceBlockArgs(Block* block, Type tuple_type, OpBuilder* builder) {
  // Append the tuple argument last so the original arguments keep their
  // indices while their uses are being rewritten.
  auto tuple_arg = block->addArgument(tuple_type);
  // Rewire every original argument (all but the newly appended one) to a
  // get_tuple_element of the tuple argument at the same index.
  Detuple(tuple_arg, block->getArguments().drop_back(1), builder);
  // Erase the now-unused original arguments back-to-front so the indices of
  // the not-yet-erased arguments remain stable.
  for (int i = block->getNumArguments() - 2; i >= 0; --i)
    block->eraseArgument(i);
}
185 
// Replaces implicitly captured value uses with tuple block argument.
// get_tuple_element's are created to extract specific values. Values from
// get_tuple_element's are returned in the order of `implicit_inputs`.
llvm::SmallVector<Value, 4> ReplaceImplicitInputs(
    Block* block, int offset, ArrayRef<Value> implicit_inputs,
    OpBuilder* builder) {
  llvm::SmallVector<Value, 4> implicit_input_elements;
  implicit_input_elements.reserve(implicit_inputs.size());

  Region* region = block->getParent();
  // The block is expected to have already been rewritten to take a single
  // tuple argument (see ReplaceBlockArgs / TupleImplicitInputs).
  assert(block->getNumArguments() == 1);

  BlockArgument tuple_arg = block->getArgument(0);
  for (auto& implicit_input : llvm::enumerate(implicit_inputs)) {
    Value implicit_input_value = implicit_input.value();
    // Captured values live at positions starting at `offset` within the tuple
    // argument (explicit inputs, if any, occupy the leading positions).
    auto get_tuple_element = builder->create<mhlo::GetTupleElementOp>(
        implicit_input_value.getLoc(), tuple_arg,
        implicit_input.index() + offset);
    implicit_input_elements.emplace_back(get_tuple_element.getResult());
    // Rewrite only uses nested inside this region; uses elsewhere must keep
    // referencing the original value. make_early_inc_range allows mutating
    // the use list (via use.set) while iterating over it.
    for (auto& use :
         llvm::make_early_inc_range(implicit_input_value.getUses())) {
      if (!region->isAncestor(use.getOwner()->getParentRegion())) continue;
      use.set(get_tuple_element.getResult());
    }
  }

  return implicit_input_elements;
}
214 
// Finds and replaces implicitly captured value uses with tuple block argument.
// A tuple of implicitly captured values is also created and returned, for use
// as an operand to the associated mhlo control flow op.
Value TupleImplicitInputs(Region& region, Location loc, OpBuilder* builder) {
  // Collect every value defined outside `region` but used inside it.
  llvm::SetVector<Value> implicit_inputs;
  getUsedValuesDefinedAbove(region, region, implicit_inputs);
  llvm::ArrayRef<Value> implicit_inputs_ref = implicit_inputs.getArrayRef();
  // Pack the captures into a tuple at the builder's current insertion point
  // (outside the region); the caller passes this as the region's operand.
  Value tuple_input = builder->create<mhlo::TupleOp>(loc, implicit_inputs_ref);
  Block& block = region.front();
  // `tf.CaseRegion`/`tf.IfRegion` are expected to have no block arguments and
  // instead all inputs used by their branch regions are implicitly captured
  // from above.
  assert(block.getNumArguments() == 0);
  block.addArgument(tuple_input.getType());
  // Rewrite the captured uses inside the region to unpack from the new tuple
  // block argument.
  builder->setInsertionPointToStart(&block);
  ReplaceImplicitInputs(&block, /*offset=*/0, implicit_inputs_ref, builder);
  return tuple_input;
}
233 
234 // Replaces block terminator (tf.Yield) with `mhlo.return`. Additional results
235 // can be returned if `extra_results` is not empty. If `tuple_return` is
236 // set, a tuple of the return values will be set as the terminator operand.
ReplaceTerminator(Block * block,ArrayRef<Value> extra_results,OpBuilder * builder,bool tuple_return=true)237 void ReplaceTerminator(Block* block, ArrayRef<Value> extra_results,
238                        OpBuilder* builder, bool tuple_return = true) {
239   Operation* terminator = block->getTerminator();
240   assert(isa<TF::YieldOp>(terminator));
241   Location loc = terminator->getLoc();
242 
243   builder->setInsertionPoint(terminator);
244   auto results = llvm::to_vector<4>(terminator->getOperands());
245   results.append(extra_results.begin(), extra_results.end());
246   if (tuple_return) {
247     auto tuple_results = builder->create<mhlo::TupleOp>(loc, results);
248     builder->create<mhlo::ReturnOp>(loc, tuple_results.getResult());
249   } else {
250     builder->create<mhlo::ReturnOp>(loc, results);
251   }
252 
253   terminator->erase();
254 }
255 
// Lowers a `tf.IfRegion` to `mhlo.if`. Each branch's implicit captures are
// packed into a per-branch tuple operand, the branch regions are moved onto
// the new op, and the op's results are rewired to get_tuple_element's of the
// new op's tuple result.
void LowerIfRegion(TF::IfRegionOp op) {
  Location loc = op.getLoc();
  OpBuilder builder(op);

  // Tuple implicit inputs per region and update terminators to return tuples.
  builder.setInsertionPoint(op);
  Value then_input = TupleImplicitInputs(op.then_branch(), loc, &builder);
  ReplaceTerminator(&op.then_branch().front(), /*extra_results=*/{}, &builder);

  builder.setInsertionPoint(op);
  Value else_input = TupleImplicitInputs(op.else_branch(), loc, &builder);
  ReplaceTerminator(&op.else_branch().front(), /*extra_results=*/{}, &builder);

  // Create the new `mhlo.if` op with tuple inputs and take ownership of regions
  // from `tf.IfRegion` op.
  builder.setInsertionPoint(op);
  auto result_type = builder.getTupleType(op.getResultTypes());
  auto if_op = builder.create<mhlo::IfOp>(loc, result_type, op.cond(),
                                          then_input, else_input);
  if_op.true_branch().takeBody(op.then_branch());
  if_op.false_branch().takeBody(op.else_branch());

  // De-tuple the results of the `mhlo.if`.
  Detuple(if_op.getResult(), op.getResults(), &builder);
  op.erase();
}
282 
LowerCaseRegion(TF::CaseRegionOp op)283 void LowerCaseRegion(TF::CaseRegionOp op) {
284   Location loc = op.getLoc();
285   OpBuilder builder(op);
286 
287   llvm::SmallVector<Value, 4> branch_inputs;
288   branch_inputs.reserve(op.branches().size());
289   // Tuple implicit inputs per region and update terminators.
290   for (Region& region : op.branches()) {
291     builder.setInsertionPoint(op);
292     Value branch_input = TupleImplicitInputs(region, loc, &builder);
293     branch_inputs.emplace_back(branch_input);
294     ReplaceTerminator(&region.front(), /*extra_results=*/{}, &builder,
295                       /*tuple_return=*/false);
296   }
297 
298   // Create the new `mhlo.case` op with tuple inputs and take ownership of
299   // regions from `tf.CaseRegion` op.
300   builder.setInsertionPoint(op);
301   auto case_op =
302       builder.create<mhlo::CaseOp>(loc, op.getResultTypes(), op.branch_index(),
303                                    branch_inputs, branch_inputs.size());
304   for (auto region : llvm::zip(case_op.branches(), op.branches()))
305     std::get<0>(region).takeBody(std::get<1>(region));
306 
307   op.replaceAllUsesWith(case_op.getResults());
308   op.erase();
309 }
310 
// Lowers a `tf.WhileRegion` to `mhlo.while`. Explicit inputs and implicit
// captures are packed into a single tuple operand; the captures are also
// threaded through the loop as extra results so the body can yield them back
// unchanged.
void LowerWhileRegion(TF::WhileRegionOp op) {
  Location loc = op.getLoc();
  OpBuilder builder(op);

  // XLA prefers tuple arguments for control flow due to XLA not supporting
  // multiple return values.
  SmallVector<Value, 3> inputs(op.input());
  const int inputs_size = inputs.size();
  // Values defined above the op but used inside its regions are appended to
  // the operand tuple after the explicit inputs.
  llvm::SetVector<Value> implicit_inputs;
  getUsedValuesDefinedAbove(op.getOperation()->getRegions(), implicit_inputs);
  inputs.append(implicit_inputs.begin(), implicit_inputs.end());

  builder.setInsertionPoint(op);
  Value tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);

  // Create the new `mhlo.while` op with tuple inputs. Implicit inputs are also
  // returned.
  auto while_result_types = llvm::to_vector<4>(op.getResultTypes());
  while_result_types.reserve(while_result_types.size() +
                             implicit_inputs.size());
  for (const auto& implicit_input : implicit_inputs)
    while_result_types.emplace_back(implicit_input.getType());
  auto while_op = builder.create<mhlo::WhileOp>(
      loc, builder.getTupleType(while_result_types), tuple_input);

  // Rewrite cond and associated block arguments and terminator. Ownership of
  // cond region is transferred over from `tf.WhileRegion` to `mhlo.while`.
  Region& cond = while_op.cond();
  cond.takeBody(op.cond());
  Block& cond_block = cond.front();
  builder.setInsertionPointToStart(&cond_block);
  ReplaceBlockArgs(&cond_block, tuple_input.getType(), &builder);
  // Captured values start at position `inputs_size` within the tuple arg.
  ReplaceImplicitInputs(&cond_block, inputs_size, implicit_inputs.getArrayRef(),
                        &builder);
  // Cond always returns a single result of bool type.
  ReplaceTerminator(&cond_block, /*extra_results=*/{}, &builder,
                    /*tuple_return=*/false);

  // Rewrite body and associated block arguments and terminator. Ownership of
  // body region is transferred over from `tf.WhileRegion` to `mhlo.while`.
  Region& body = while_op.body();
  body.takeBody(op.body());
  Block& body_block = body.front();
  builder.setInsertionPointToStart(&body_block);
  ReplaceBlockArgs(&body_block, tuple_input.getType(), &builder);
  // Capture implicit inputs that were added as a tuple block arguments. These
  // are to be returned by the body in addition to explicit inputs.
  auto implicit_input_elements = ReplaceImplicitInputs(
      &body_block, inputs_size, implicit_inputs.getArrayRef(), &builder);
  ReplaceTerminator(&body_block, implicit_input_elements, &builder);

  // De-tuple the results of the `mhlo.while`.
  builder.setInsertionPoint(op);
  Detuple(while_op.getResult(), op.getResults(), &builder);
  op.erase();
}
367 }  // namespace
368 
runOnOperation()369 void LegalizeTFControlFlow::runOnOperation() {
370   getOperation().walk([&](Operation* op) {
371     if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
372       LowerWhile(while_op);
373       return;
374     }
375     if (auto while_region_op = dyn_cast<TF::WhileRegionOp>(op)) {
376       LowerWhileRegion(while_region_op);
377       return;
378     }
379     if (auto if_op = dyn_cast<TF::IfOp>(op)) {
380       LowerIf(if_op);
381       return;
382     }
383     if (auto if_region_op = dyn_cast<TF::IfRegionOp>(op)) {
384       LowerIfRegion(if_region_op);
385       return;
386     }
387     if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
388       LowerCase(case_op);
389       return;
390     }
391     if (auto case_region_op = dyn_cast<TF::CaseRegionOp>(op)) {
392       LowerCaseRegion(case_region_op);
393       return;
394     }
395   });
396 }
397 }  // namespace mhlo
398 }  // namespace mlir
399 
// Registers the pass with the global pass registry under the
// `xla-legalize-tf-control-flow` command-line flag.
static PassRegistration<mlir::mhlo::LegalizeTFControlFlow> cfpass(
    "xla-legalize-tf-control-flow",
    "Legalize TensorFlow control flow to the XLA dialect");
403