1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for lowering TensorFlow dialect's control flow to
17 // the XLA dialect.
18
19 #include <cstddef>
20 #include <cstdint>
21 #include <iterator>
22 #include <numeric>
23 #include <tuple>
24
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
31 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
32 #include "mlir/IR/Operation.h" // from @llvm-project
33 #include "mlir/IR/OperationSupport.h" // from @llvm-project
34 #include "mlir/IR/Types.h" // from @llvm-project
35 #include "mlir/Pass/Pass.h" // from @llvm-project
36 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
37 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
38 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
40 #include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes_detail.h"
41
42 using mlir::PassRegistration;
43
44 namespace mlir {
45 namespace mhlo {
46 namespace {
// Pass that lowers TensorFlow dialect control flow ops (`tf.If`, `tf.Case`,
// `tf.While` and their region-based variants) to the corresponding MHLO
// control flow ops. Runs on a ModuleOp (see the factory function below).
class LegalizeTFControlFlow
    : public LegalizeTFControlFlowBase<LegalizeTFControlFlow> {
 public:
  void runOnOperation() override;
};
52 } // namespace
53
54 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createLegalizeTFControlFlowPass()55 createLegalizeTFControlFlowPass() {
56 return std::make_unique<LegalizeTFControlFlow>();
57 }
58
59 namespace {
60
Detuple(Value tuple,ValueRange replace,OpBuilder * builder)61 void Detuple(Value tuple, ValueRange replace, OpBuilder* builder) {
62 // De-tuple the results of the xla hlo if result.
63 for (auto result_it : llvm::enumerate(replace)) {
64 auto get_tuple_value = builder->create<mhlo::GetTupleElementOp>(
65 result_it.value().getLoc(), tuple, result_it.index());
66 result_it.value().replaceAllUsesWith(get_tuple_value);
67 }
68 }
69
70 // Imports the source region into the destination region. The XLA if
71 // operation only supports one argument per branch. Therefore any branch that
72 // requires additional arguments requires their values be tupled together. Then,
73 // to support multiple returns (as XLA only supports a single return value) the
74 // results of the if operation are tupled together.
ImportXlaRegion(mlir::FuncOp func,Region * dest_region,Location loc,bool tuple_return=true)75 void ImportXlaRegion(mlir::FuncOp func, Region* dest_region, Location loc,
76 bool tuple_return = true) {
77 OpBuilder builder(dest_region);
78
79 auto entry_block = builder.createBlock(dest_region);
80 auto tuple_arg = entry_block->addArgument(
81 builder.getTupleType(func.getType().getInputs()));
82 llvm::SmallVector<Value, 4> detupled_args;
83 detupled_args.reserve(func.getNumArguments());
84
85 for (int64_t i = 0, s = func.getNumArguments(); i < s; i++) {
86 auto extract = builder.create<GetTupleElementOp>(loc, tuple_arg, i);
87 detupled_args.push_back(extract);
88 }
89
90 auto result = builder.create<CallOp>(loc, func, detupled_args).getResults();
91 if (!tuple_return) {
92 builder.create<mhlo::ReturnOp>(loc, result);
93 } else {
94 auto tuple_op = builder.create<TupleOp>(loc, result);
95 builder.create<mhlo::ReturnOp>(loc, tuple_op.getResult());
96 }
97 }
98
LowerIf(TF::IfOp op)99 void LowerIf(TF::IfOp op) {
100 Location loc = op.getLoc();
101 OpBuilder builder(op);
102
103 // XLA prefers tuple arguments for control flow due to XLA not supporting
104 // multiple return values.
105 SmallVector<Value, 3> inputs(op.input());
106 auto tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);
107
108 // Create the new `mhlo.if` op with tuple inputs.
109 auto result_type = builder.getTupleType(op.getResultTypes());
110 auto if_op = builder.create<mhlo::IfOp>(loc, result_type, op.cond(),
111 tuple_input, tuple_input);
112
113 // Import the regions for both the true and false cases. These regions
114 // must be updated to tuple the return results together and use the xla hlo
115 // return op.
116 ImportXlaRegion(op.then_function(), &if_op.true_branch(), loc);
117 ImportXlaRegion(op.else_function(), &if_op.false_branch(), loc);
118
119 // De-tuple the results of the `mhlo.if`.
120 Detuple(if_op.getResult(), op.getResults(), &builder);
121 op.erase();
122 }
123
LowerCase(TF::CaseOp op)124 void LowerCase(TF::CaseOp op) {
125 Location loc = op.getLoc();
126 OpBuilder builder(op);
127
128 // XLA requires one argument per branch so we create a tuple of inputs to pass
129 // to each branch.
130 SmallVector<Value, 4> inputs(op.input());
131 auto tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);
132
133 // Create replica of input tuple for each branch
134 SmallVector<Value, 4> n_tuple_inputs(op.num_branches(), tuple_input);
135
136 // Create the new `mhlo.case` op with tuple inputs.
137 auto case_op =
138 builder.create<mhlo::CaseOp>(loc, op.getResultTypes(), op.branch_index(),
139 n_tuple_inputs, op.branches().size());
140
141 // Import the regions for all branches.
142 for (unsigned i = 0; i < op.num_branches(); ++i) {
143 mlir::FuncOp branch_func = op.branch_function(i);
144 ImportXlaRegion(branch_func, &case_op.branches()[i], loc,
145 /*tuple_return=*/false);
146 }
147
148 op.replaceAllUsesWith(case_op.getResults());
149 op.erase();
150 }
151
LowerWhile(TF::WhileOp op)152 void LowerWhile(TF::WhileOp op) {
153 Location loc = op.getLoc();
154 OpBuilder builder(op);
155
156 // XLA prefers tuple arguments for control flow due to XLA not supporting
157 // multiple return values.
158 SmallVector<Value, 3> inputs(op.input());
159 builder.setInsertionPoint(op);
160 Value tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);
161
162 // Create the new `mhlo.while` op with tuple inputs.
163 auto while_op = builder.create<mhlo::WhileOp>(
164 loc, builder.getTupleType(op.getResultTypes()), tuple_input);
165
166 // Import the regions for both the cond and body. These regions must be
167 // updated to tuple the return results together and use the xla hlo return op.
168 ImportXlaRegion(op.body_function(), &while_op.body(), loc);
169 ImportXlaRegion(op.cond_function(), &while_op.cond(), loc,
170 /*tuple_return=*/false);
171
172 // De-tuple the results of the `mhlo.while` if needed.
173 if (while_op.getNumResults() == 1 && while_op.getType(0).isa<TupleType>())
174 Detuple(while_op.getResult(0), op.getResults(), &builder);
175 else
176 op->replaceAllUsesWith(while_op);
177 op.erase();
178 }
179
180 // Replaces all block arguments of a block with a single block arg of Tuple
181 // type `tuple_type`. Single block arguments are removed and remapped to
182 // get_tuple_element(tuple_arg, index).
ReplaceBlockArgs(Block * block,Type tuple_type,OpBuilder * builder)183 void ReplaceBlockArgs(Block* block, Type tuple_type, OpBuilder* builder) {
184 auto tuple_arg = block->addArgument(tuple_type);
185 Detuple(tuple_arg, block->getArguments().drop_back(1), builder);
186 for (int i = block->getNumArguments() - 2; i >= 0; --i)
187 block->eraseArgument(i);
188 }
189
190 // Replaces implicitly captured value uses with tuple block argument.
191 // get_tuple_element's are created to extract specific values. Values from
192 // get_tuple_element's are returned in the order of `implicit_inputs`.
ReplaceImplicitInputs(Block * block,int offset,ArrayRef<Value> implicit_inputs,OpBuilder * builder)193 llvm::SmallVector<Value, 4> ReplaceImplicitInputs(
194 Block* block, int offset, ArrayRef<Value> implicit_inputs,
195 OpBuilder* builder) {
196 llvm::SmallVector<Value, 4> implicit_input_elements;
197 implicit_input_elements.reserve(implicit_inputs.size());
198
199 Region* region = block->getParent();
200 assert(block->getNumArguments() == 1);
201
202 BlockArgument tuple_arg = block->getArgument(0);
203 for (auto& implicit_input : llvm::enumerate(implicit_inputs)) {
204 Value implicit_input_value = implicit_input.value();
205 auto get_tuple_element = builder->create<mhlo::GetTupleElementOp>(
206 implicit_input_value.getLoc(), tuple_arg,
207 implicit_input.index() + offset);
208 implicit_input_elements.emplace_back(get_tuple_element.getResult());
209 for (auto& use :
210 llvm::make_early_inc_range(implicit_input_value.getUses())) {
211 if (!region->isAncestor(use.getOwner()->getParentRegion())) continue;
212 use.set(get_tuple_element.getResult());
213 }
214 }
215
216 return implicit_input_elements;
217 }
218
219 // Finds and replaces implicitly captured value uses with tuple block argument.
220 // A tuple of implicitly captured values is also created and returned, for use
221 // as an operand to the associated mhlo control flow op.
TupleImplicitInputs(Region & region,Location loc,OpBuilder * builder)222 Value TupleImplicitInputs(Region& region, Location loc, OpBuilder* builder) {
223 llvm::SetVector<Value> implicit_inputs;
224 getUsedValuesDefinedAbove(region, region, implicit_inputs);
225 llvm::ArrayRef<Value> implicit_inputs_ref = implicit_inputs.getArrayRef();
226 Value tuple_input = builder->create<mhlo::TupleOp>(loc, implicit_inputs_ref);
227 Block& block = region.front();
228 // `tf.CaseRegion`/`tf.IfRegion` are expected to have no block arguments and
229 // instead all inputs used by their branch regions are implicitly captured
230 // from above.
231 assert(block.getNumArguments() == 0);
232 block.addArgument(tuple_input.getType());
233 builder->setInsertionPointToStart(&block);
234 ReplaceImplicitInputs(&block, /*offset=*/0, implicit_inputs_ref, builder);
235 return tuple_input;
236 }
237
238 // Replaces block terminator (tf.Yield) with `mhlo.return`. Additional results
239 // can be returned if `extra_results` is not empty. If `tuple_return` is
240 // set, a tuple of the return values will be set as the terminator operand.
ReplaceTerminator(Block * block,ArrayRef<Value> extra_results,OpBuilder * builder,bool tuple_return=true)241 void ReplaceTerminator(Block* block, ArrayRef<Value> extra_results,
242 OpBuilder* builder, bool tuple_return = true) {
243 Operation* terminator = block->getTerminator();
244 assert(isa<TF::YieldOp>(terminator));
245 Location loc = terminator->getLoc();
246
247 builder->setInsertionPoint(terminator);
248 auto results = llvm::to_vector<4>(terminator->getOperands());
249 results.append(extra_results.begin(), extra_results.end());
250 if (tuple_return) {
251 auto tuple_results = builder->create<mhlo::TupleOp>(loc, results);
252 builder->create<mhlo::ReturnOp>(loc, tuple_results.getResult());
253 } else {
254 builder->create<mhlo::ReturnOp>(loc, results);
255 }
256
257 terminator->erase();
258 }
259
LowerIfRegion(TF::IfRegionOp op)260 void LowerIfRegion(TF::IfRegionOp op) {
261 Location loc = op.getLoc();
262 OpBuilder builder(op);
263
264 // Tuple implicit inputs per region and update terminators to return tuples.
265 builder.setInsertionPoint(op);
266 Value then_input = TupleImplicitInputs(op.then_branch(), loc, &builder);
267 ReplaceTerminator(&op.then_branch().front(), /*extra_results=*/{}, &builder);
268
269 builder.setInsertionPoint(op);
270 Value else_input = TupleImplicitInputs(op.else_branch(), loc, &builder);
271 ReplaceTerminator(&op.else_branch().front(), /*extra_results=*/{}, &builder);
272
273 // Create the new `mhlo.if` op with tuple inputs and take ownership of regions
274 // from `tf.IfRegion` op.
275 builder.setInsertionPoint(op);
276 auto result_type = builder.getTupleType(op.getResultTypes());
277 auto if_op = builder.create<mhlo::IfOp>(loc, result_type, op.cond(),
278 then_input, else_input);
279 if_op.true_branch().takeBody(op.then_branch());
280 if_op.false_branch().takeBody(op.else_branch());
281
282 // De-tuple the results of the `mhlo.if`.
283 Detuple(if_op.getResult(), op.getResults(), &builder);
284 op.erase();
285 }
286
LowerCaseRegion(TF::CaseRegionOp op)287 void LowerCaseRegion(TF::CaseRegionOp op) {
288 Location loc = op.getLoc();
289 OpBuilder builder(op);
290
291 llvm::SmallVector<Value, 4> branch_inputs;
292 branch_inputs.reserve(op.branches().size());
293 // Tuple implicit inputs per region and update terminators.
294 for (Region& region : op.branches()) {
295 builder.setInsertionPoint(op);
296 Value branch_input = TupleImplicitInputs(region, loc, &builder);
297 branch_inputs.emplace_back(branch_input);
298 ReplaceTerminator(®ion.front(), /*extra_results=*/{}, &builder,
299 /*tuple_return=*/false);
300 }
301
302 // Create the new `mhlo.case` op with tuple inputs and take ownership of
303 // regions from `tf.CaseRegion` op.
304 builder.setInsertionPoint(op);
305 auto case_op =
306 builder.create<mhlo::CaseOp>(loc, op.getResultTypes(), op.branch_index(),
307 branch_inputs, branch_inputs.size());
308 for (auto region : llvm::zip(case_op.branches(), op.branches()))
309 std::get<0>(region).takeBody(std::get<1>(region));
310
311 op.replaceAllUsesWith(case_op.getResults());
312 op.erase();
313 }
314
// Lowers `tf.WhileRegion` to `mhlo.while`. Explicit loop-carried inputs plus
// any values implicitly captured by the cond/body regions are packed into a
// single tuple operand (XLA supports one loop-carried value); the implicit
// inputs are also threaded through the body's return so they stay available
// across iterations.
void LowerWhileRegion(TF::WhileRegionOp op) {
  Location loc = op.getLoc();
  OpBuilder builder(op);

  // XLA prefers tuple arguments for control flow due to XLA not supporting
  // multiple return values.
  SmallVector<Value, 3> inputs(op.input());
  // Remember where the explicit inputs end; implicit captures are appended
  // after this offset in the tuple.
  const int inputs_size = inputs.size();
  llvm::SetVector<Value> implicit_inputs;
  getUsedValuesDefinedAbove(op.getOperation()->getRegions(), implicit_inputs);
  inputs.append(implicit_inputs.begin(), implicit_inputs.end());

  builder.setInsertionPoint(op);
  Value tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);

  // Create the new `mhlo.while` op with tuple inputs. Implicit inputs are also
  // returned.
  auto while_result_types = llvm::to_vector<4>(op.getResultTypes());
  while_result_types.reserve(while_result_types.size() +
                             implicit_inputs.size());
  for (const auto& implicit_input : implicit_inputs)
    while_result_types.emplace_back(implicit_input.getType());
  auto while_op = builder.create<mhlo::WhileOp>(
      loc, builder.getTupleType(while_result_types), tuple_input);

  // Rewrite cond and associated block arguments and terminator. Ownership of
  // cond region is transferred over from `tf.WhileRegion` to `mhlo.while`.
  Region& cond = while_op.cond();
  cond.takeBody(op.cond());
  Block& cond_block = cond.front();
  builder.setInsertionPointToStart(&cond_block);
  // Collapse the block args into one tuple arg, then remap implicit captures
  // to tuple elements starting at `inputs_size`.
  ReplaceBlockArgs(&cond_block, tuple_input.getType(), &builder);
  ReplaceImplicitInputs(&cond_block, inputs_size, implicit_inputs.getArrayRef(),
                        &builder);
  // Cond always returns a single result of bool type.
  ReplaceTerminator(&cond_block, /*extra_results=*/{}, &builder,
                    /*tuple_return=*/false);

  // Rewrite body and associated block arguments and terminator. Ownership of
  // body region is transferred over from `tf.WhileRegion` to `mhlo.while`.
  Region& body = while_op.body();
  body.takeBody(op.body());
  Block& body_block = body.front();
  builder.setInsertionPointToStart(&body_block);
  ReplaceBlockArgs(&body_block, tuple_input.getType(), &builder);
  // Capture implicit inputs that were added as a tuple block arguments. These
  // are to be returned by the body in addition to explicit inputs.
  auto implicit_input_elements = ReplaceImplicitInputs(
      &body_block, inputs_size, implicit_inputs.getArrayRef(), &builder);
  ReplaceTerminator(&body_block, implicit_input_elements, &builder);

  // De-tuple the results of the `mhlo.while`.
  builder.setInsertionPoint(op);
  if (while_op.getNumResults() == 1 && while_op.getType(0).isa<TupleType>())
    Detuple(while_op.getResult(0), op.getResults(), &builder);
  else
    op->replaceAllUsesWith(while_op);
  op.erase();
}
374 } // namespace
375
runOnOperation()376 void LegalizeTFControlFlow::runOnOperation() {
377 getOperation().walk([&](Operation* op) {
378 if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
379 LowerWhile(while_op);
380 return;
381 }
382 if (auto while_region_op = dyn_cast<TF::WhileRegionOp>(op)) {
383 LowerWhileRegion(while_region_op);
384 return;
385 }
386 if (auto if_op = dyn_cast<TF::IfOp>(op)) {
387 LowerIf(if_op);
388 return;
389 }
390 if (auto if_region_op = dyn_cast<TF::IfRegionOp>(op)) {
391 LowerIfRegion(if_region_op);
392 return;
393 }
394 if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
395 LowerCase(case_op);
396 return;
397 }
398 if (auto case_region_op = dyn_cast<TF::CaseRegionOp>(op)) {
399 LowerCaseRegion(case_region_op);
400 return;
401 }
402 });
403 }
404 } // namespace mhlo
405 } // namespace mlir
406