1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for lowering TensorFlow dialect's control flow to
17 // the XLA dialect.
18
19 #include <cstddef>
20 #include <cstdint>
21 #include <iterator>
22 #include <numeric>
23 #include <tuple>
24
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
31 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
32 #include "mlir/IR/Operation.h" // from @llvm-project
33 #include "mlir/IR/OperationSupport.h" // from @llvm-project
34 #include "mlir/IR/Types.h" // from @llvm-project
35 #include "mlir/Pass/Pass.h" // from @llvm-project
36 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
37 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
38 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
40
41 using mlir::PassRegistration;
42
43 namespace mlir {
44 namespace mhlo {
45 namespace {
// Module pass that lowers TensorFlow dialect control flow ops — both the
// functional (`tf.If`/`tf.Case`/`tf.While`) and region-based
// (`tf.IfRegion`/`tf.CaseRegion`/`tf.WhileRegion`) forms — to their `mhlo`
// dialect equivalents. The lowering logic lives in runOnOperation() below.
class LegalizeTFControlFlow
    : public PassWrapper<LegalizeTFControlFlow, OperationPass<ModuleOp>> {
 public:
  void runOnOperation() override;
};
51 } // namespace
52
53 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createLegalizeTFControlFlowPass()54 createLegalizeTFControlFlowPass() {
55 return std::make_unique<LegalizeTFControlFlow>();
56 }
57
58 namespace {
59
Detuple(Value tuple,ValueRange replace,OpBuilder * builder)60 void Detuple(Value tuple, ValueRange replace, OpBuilder* builder) {
61 // De-tuple the results of the xla hlo if result.
62 for (auto result_it : llvm::enumerate(replace)) {
63 auto get_tuple_value = builder->create<mhlo::GetTupleElementOp>(
64 result_it.value().getLoc(), tuple, result_it.index());
65 result_it.value().replaceAllUsesWith(get_tuple_value);
66 }
67 }
68
69 // Imports the source region into the destination region. The XLA if
70 // operation only supports one argument per branch. Therefore any branch that
71 // requires additional arguments requires their values be tupled together. Then,
72 // to support multiple returns (as XLA only supports a single return value) the
73 // results of the if operation are tupled together.
ImportXlaRegion(mlir::FuncOp func,Region * dest_region,Location loc,bool tuple_return=true)74 void ImportXlaRegion(mlir::FuncOp func, Region* dest_region, Location loc,
75 bool tuple_return = true) {
76 OpBuilder builder(dest_region);
77
78 auto entry_block = builder.createBlock(dest_region);
79 auto tuple_arg = entry_block->addArgument(
80 builder.getTupleType(func.getType().getInputs()));
81 llvm::SmallVector<Value, 4> detupled_args;
82 detupled_args.reserve(func.getNumArguments());
83
84 for (int64_t i = 0, s = func.getNumArguments(); i < s; i++) {
85 auto extract = builder.create<GetTupleElementOp>(loc, tuple_arg, i);
86 detupled_args.push_back(extract);
87 }
88
89 auto result = builder.create<CallOp>(loc, func, detupled_args).getResults();
90 if (!tuple_return) {
91 builder.create<mhlo::ReturnOp>(loc, result);
92 } else {
93 auto tuple_op = builder.create<TupleOp>(loc, result);
94 builder.create<mhlo::ReturnOp>(loc, tuple_op.getResult());
95 }
96 }
97
LowerIf(TF::IfOp op)98 void LowerIf(TF::IfOp op) {
99 Location loc = op.getLoc();
100 OpBuilder builder(op);
101
102 // XLA prefers tuple arguments for control flow due to XLA not supporting
103 // multiple return values.
104 SmallVector<Value, 3> inputs(op.input());
105 auto tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);
106
107 // Create the new `mhlo.if` op with tuple inputs.
108 auto result_type = builder.getTupleType(op.getResultTypes());
109 auto if_op = builder.create<mhlo::IfOp>(loc, result_type, op.cond(),
110 tuple_input, tuple_input);
111
112 // Import the regions for both the true and false cases. These regions
113 // must be updated to tuple the return results together and use the xla hlo
114 // return op.
115 ImportXlaRegion(op.then_function(), &if_op.true_branch(), loc);
116 ImportXlaRegion(op.else_function(), &if_op.false_branch(), loc);
117
118 // De-tuple the results of the `mhlo.if`.
119 Detuple(if_op.getResult(), op.getResults(), &builder);
120 op.erase();
121 }
122
LowerCase(TF::CaseOp op)123 void LowerCase(TF::CaseOp op) {
124 Location loc = op.getLoc();
125 OpBuilder builder(op);
126
127 // XLA requires one argument per branch so we create a tuple of inputs to pass
128 // to each branch.
129 SmallVector<Value, 4> inputs(op.input());
130 auto tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);
131
132 // Create replica of input tuple for each branch
133 SmallVector<Value, 4> n_tuple_inputs(op.num_branches(), tuple_input);
134
135 // Create the new `mhlo.case` op with tuple inputs.
136 auto case_op =
137 builder.create<mhlo::CaseOp>(loc, op.getResultTypes(), op.branch_index(),
138 n_tuple_inputs, op.branches().size());
139
140 // Import the regions for all branches.
141 for (unsigned i = 0; i < op.num_branches(); ++i) {
142 mlir::FuncOp branch_func = op.branch_function(i);
143 ImportXlaRegion(branch_func, &case_op.branches()[i], loc,
144 /*tuple_return=*/false);
145 }
146
147 op.replaceAllUsesWith(case_op.getResults());
148 op.erase();
149 }
150
LowerWhile(TF::WhileOp op)151 void LowerWhile(TF::WhileOp op) {
152 Location loc = op.getLoc();
153 OpBuilder builder(op);
154
155 // XLA prefers tuple arguments for control flow due to XLA not supporting
156 // multiple return values.
157 SmallVector<Value, 3> inputs(op.input());
158 builder.setInsertionPoint(op);
159 Value tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);
160
161 // Create the new `mhlo.while` op with tuple inputs.
162 auto while_op = builder.create<mhlo::WhileOp>(
163 loc, builder.getTupleType(op.getResultTypes()), tuple_input);
164
165 // Import the regions for both the cond and body. These regions must be
166 // updated to tuple the return results together and use the xla hlo return op.
167 ImportXlaRegion(op.body_function(), &while_op.body(), loc);
168 ImportXlaRegion(op.cond_function(), &while_op.cond(), loc,
169 /*tuple_return=*/false);
170
171 // De-tuple the results of the `mhlo.while`.
172 Detuple(while_op.getResult(), op.getResults(), &builder);
173 op.erase();
174 }
175
176 // Replaces all block arguments of a block with a single block arg of Tuple
177 // type `tuple_type`. Single block arguments are removed and remapped to
178 // get_tuple_element(tuple_arg, index).
void ReplaceBlockArgs(Block* block, Type tuple_type, OpBuilder* builder) {
  // The tuple argument is appended last, so the original arguments keep their
  // indices while Detuple rewires their uses to get_tuple_element ops.
  auto tuple_arg = block->addArgument(tuple_type);
  // drop_back(1) excludes the freshly added tuple argument itself.
  Detuple(tuple_arg, block->getArguments().drop_back(1), builder);
  // Erase the original arguments back-to-front so remaining indices stay
  // valid; only the tuple argument (the last one) survives.
  for (int i = block->getNumArguments() - 2; i >= 0; --i)
    block->eraseArgument(i);
}
185
186 // Replaces implicitly captured value uses with tuple block argument.
187 // get_tuple_element's are created to extract specific values. Values from
188 // get_tuple_element's are returned in the order of `implicit_inputs`.
llvm::SmallVector<Value, 4> ReplaceImplicitInputs(
    Block* block, int offset, ArrayRef<Value> implicit_inputs,
    OpBuilder* builder) {
  llvm::SmallVector<Value, 4> implicit_input_elements;
  implicit_input_elements.reserve(implicit_inputs.size());

  Region* region = block->getParent();
  // The block is expected to already have its arguments collapsed into a
  // single tuple argument (see ReplaceBlockArgs / TupleImplicitInputs).
  assert(block->getNumArguments() == 1);

  BlockArgument tuple_arg = block->getArgument(0);
  for (auto& implicit_input : llvm::enumerate(implicit_inputs)) {
    Value implicit_input_value = implicit_input.value();
    // Captured value i lives at tuple index `offset + i`, i.e. after any
    // explicit inputs that precede the implicit captures in the tuple.
    auto get_tuple_element = builder->create<mhlo::GetTupleElementOp>(
        implicit_input_value.getLoc(), tuple_arg,
        implicit_input.index() + offset);
    implicit_input_elements.emplace_back(get_tuple_element.getResult());
    // Rewrite only uses inside this region; uses outside (e.g. the tuple
    // construction feeding the control flow op) must keep the original value.
    // early_inc_range allows mutating the use list while iterating it.
    for (auto& use :
         llvm::make_early_inc_range(implicit_input_value.getUses())) {
      if (!region->isAncestor(use.getOwner()->getParentRegion())) continue;
      use.set(get_tuple_element.getResult());
    }
  }

  return implicit_input_elements;
}
214
215 // Finds and replaces implicitly captured value uses with tuple block argument.
216 // A tuple of implicitly captured values is also created and returned, for use
217 // as an operand to the associated mhlo control flow op.
Value TupleImplicitInputs(Region& region, Location loc, OpBuilder* builder) {
  llvm::SetVector<Value> implicit_inputs;
  getUsedValuesDefinedAbove(region, region, implicit_inputs);
  llvm::ArrayRef<Value> implicit_inputs_ref = implicit_inputs.getArrayRef();
  // The tuple is created at the builder's current insertion point (outside
  // `region`) so it can feed the mhlo control flow op as an operand.
  Value tuple_input = builder->create<mhlo::TupleOp>(loc, implicit_inputs_ref);
  Block& block = region.front();
  // `tf.CaseRegion`/`tf.IfRegion` are expected to have no block arguments and
  // instead all inputs used by their branch regions are implicitly captured
  // from above.
  assert(block.getNumArguments() == 0);
  block.addArgument(tuple_input.getType());
  // NOTE: moves the builder's insertion point into the region's entry block;
  // callers must reset it before creating ops outside the region.
  builder->setInsertionPointToStart(&block);
  ReplaceImplicitInputs(&block, /*offset=*/0, implicit_inputs_ref, builder);
  return tuple_input;
}
233
234 // Replaces block terminator (tf.Yield) with `mhlo.return`. Additional results
235 // can be returned if `extra_results` is not empty. If `tuple_return` is
236 // set, a tuple of the return values will be set as the terminator operand.
ReplaceTerminator(Block * block,ArrayRef<Value> extra_results,OpBuilder * builder,bool tuple_return=true)237 void ReplaceTerminator(Block* block, ArrayRef<Value> extra_results,
238 OpBuilder* builder, bool tuple_return = true) {
239 Operation* terminator = block->getTerminator();
240 assert(isa<TF::YieldOp>(terminator));
241 Location loc = terminator->getLoc();
242
243 builder->setInsertionPoint(terminator);
244 auto results = llvm::to_vector<4>(terminator->getOperands());
245 results.append(extra_results.begin(), extra_results.end());
246 if (tuple_return) {
247 auto tuple_results = builder->create<mhlo::TupleOp>(loc, results);
248 builder->create<mhlo::ReturnOp>(loc, tuple_results.getResult());
249 } else {
250 builder->create<mhlo::ReturnOp>(loc, results);
251 }
252
253 terminator->erase();
254 }
255
void LowerIfRegion(TF::IfRegionOp op) {
  Location loc = op.getLoc();
  OpBuilder builder(op);

  // Tuple implicit inputs per region and update terminators to return tuples.
  // The insertion point is reset before each call because TupleImplicitInputs
  // moves it inside the branch region.
  builder.setInsertionPoint(op);
  Value then_input = TupleImplicitInputs(op.then_branch(), loc, &builder);
  ReplaceTerminator(&op.then_branch().front(), /*extra_results=*/{}, &builder);

  builder.setInsertionPoint(op);
  Value else_input = TupleImplicitInputs(op.else_branch(), loc, &builder);
  ReplaceTerminator(&op.else_branch().front(), /*extra_results=*/{}, &builder);

  // Create the new `mhlo.if` op with tuple inputs and take ownership of
  // regions from `tf.IfRegion` op (no re-cloning of the branch bodies).
  builder.setInsertionPoint(op);
  auto result_type = builder.getTupleType(op.getResultTypes());
  auto if_op = builder.create<mhlo::IfOp>(loc, result_type, op.cond(),
                                          then_input, else_input);
  if_op.true_branch().takeBody(op.then_branch());
  if_op.false_branch().takeBody(op.else_branch());

  // De-tuple the results of the `mhlo.if`.
  Detuple(if_op.getResult(), op.getResults(), &builder);
  op.erase();
}
282
LowerCaseRegion(TF::CaseRegionOp op)283 void LowerCaseRegion(TF::CaseRegionOp op) {
284 Location loc = op.getLoc();
285 OpBuilder builder(op);
286
287 llvm::SmallVector<Value, 4> branch_inputs;
288 branch_inputs.reserve(op.branches().size());
289 // Tuple implicit inputs per region and update terminators.
290 for (Region& region : op.branches()) {
291 builder.setInsertionPoint(op);
292 Value branch_input = TupleImplicitInputs(region, loc, &builder);
293 branch_inputs.emplace_back(branch_input);
294 ReplaceTerminator(®ion.front(), /*extra_results=*/{}, &builder,
295 /*tuple_return=*/false);
296 }
297
298 // Create the new `mhlo.case` op with tuple inputs and take ownership of
299 // regions from `tf.CaseRegion` op.
300 builder.setInsertionPoint(op);
301 auto case_op =
302 builder.create<mhlo::CaseOp>(loc, op.getResultTypes(), op.branch_index(),
303 branch_inputs, branch_inputs.size());
304 for (auto region : llvm::zip(case_op.branches(), op.branches()))
305 std::get<0>(region).takeBody(std::get<1>(region));
306
307 op.replaceAllUsesWith(case_op.getResults());
308 op.erase();
309 }
310
// Rewrites `tf.WhileRegion` into `mhlo.while`. Explicit inputs and implicit
// captures are packed together into one tuple of carried state; implicit
// captures are threaded through the loop (returned unchanged by the body) so
// the cond/body regions can reference them via the tuple argument.
void LowerWhileRegion(TF::WhileRegionOp op) {
  Location loc = op.getLoc();
  OpBuilder builder(op);

  // XLA prefers tuple arguments for control flow due to XLA not supporting
  // multiple return values.
  SmallVector<Value, 3> inputs(op.input());
  // Number of explicit inputs; implicit captures are appended after them and
  // addressed at tuple index `inputs_size + i` below.
  const int inputs_size = inputs.size();
  llvm::SetVector<Value> implicit_inputs;
  getUsedValuesDefinedAbove(op.getOperation()->getRegions(), implicit_inputs);
  inputs.append(implicit_inputs.begin(), implicit_inputs.end());

  builder.setInsertionPoint(op);
  Value tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);

  // Create the new `mhlo.while` op with tuple inputs. Implicit inputs are also
  // returned.
  auto while_result_types = llvm::to_vector<4>(op.getResultTypes());
  while_result_types.reserve(while_result_types.size() +
                             implicit_inputs.size());
  for (const auto& implicit_input : implicit_inputs)
    while_result_types.emplace_back(implicit_input.getType());
  auto while_op = builder.create<mhlo::WhileOp>(
      loc, builder.getTupleType(while_result_types), tuple_input);

  // Rewrite cond and associated block arguments and terminator. Ownership of
  // the cond region is transferred over from `tf.WhileRegion` to `mhlo.while`.
  Region& cond = while_op.cond();
  cond.takeBody(op.cond());
  Block& cond_block = cond.front();
  builder.setInsertionPointToStart(&cond_block);
  ReplaceBlockArgs(&cond_block, tuple_input.getType(), &builder);
  ReplaceImplicitInputs(&cond_block, inputs_size, implicit_inputs.getArrayRef(),
                        &builder);
  // Cond always returns a single result of bool type.
  ReplaceTerminator(&cond_block, /*extra_results=*/{}, &builder,
                    /*tuple_return=*/false);

  // Rewrite body and associated block arguments and terminator. Ownership of
  // the body region is transferred over from `tf.WhileRegion` to `mhlo.while`.
  Region& body = while_op.body();
  body.takeBody(op.body());
  Block& body_block = body.front();
  builder.setInsertionPointToStart(&body_block);
  ReplaceBlockArgs(&body_block, tuple_input.getType(), &builder);
  // Capture implicit inputs that were added as tuple block arguments. These
  // are to be returned by the body in addition to explicit inputs, keeping
  // the loop-carried tuple shape consistent between iterations.
  auto implicit_input_elements = ReplaceImplicitInputs(
      &body_block, inputs_size, implicit_inputs.getArrayRef(), &builder);
  ReplaceTerminator(&body_block, implicit_input_elements, &builder);

  // De-tuple the results of the `mhlo.while`.
  builder.setInsertionPoint(op);
  Detuple(while_op.getResult(), op.getResults(), &builder);
  op.erase();
}
367 } // namespace
368
runOnOperation()369 void LegalizeTFControlFlow::runOnOperation() {
370 getOperation().walk([&](Operation* op) {
371 if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
372 LowerWhile(while_op);
373 return;
374 }
375 if (auto while_region_op = dyn_cast<TF::WhileRegionOp>(op)) {
376 LowerWhileRegion(while_region_op);
377 return;
378 }
379 if (auto if_op = dyn_cast<TF::IfOp>(op)) {
380 LowerIf(if_op);
381 return;
382 }
383 if (auto if_region_op = dyn_cast<TF::IfRegionOp>(op)) {
384 LowerIfRegion(if_region_op);
385 return;
386 }
387 if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
388 LowerCase(case_op);
389 return;
390 }
391 if (auto case_region_op = dyn_cast<TF::CaseRegionOp>(op)) {
392 LowerCaseRegion(case_region_op);
393 return;
394 }
395 });
396 }
397 } // namespace mhlo
398 } // namespace mlir
399
// Registers the pass with the global registry under the
// `-xla-legalize-tf-control-flow` command line flag.
static PassRegistration<mlir::mhlo::LegalizeTFControlFlow> cfpass(
    "xla-legalize-tf-control-flow",
    "Legalize TensorFlow control flow to the XLA dialect");
403