1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This transformation pass transforms region-based control flow operations in
// the TensorFlow dialect to their functional counterparts, i.e.,
// tf.IfRegion -> tf.If and tf.WhileRegion -> tf.While
19
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Casting.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
24 #include "mlir/IR/Attributes.h" // from @llvm-project
25 #include "mlir/IR/Builders.h" // from @llvm-project
26 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
27 #include "mlir/IR/Operation.h" // from @llvm-project
28 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
29 #include "mlir/IR/Value.h" // from @llvm-project
30 #include "mlir/IR/Verifier.h" // from @llvm-project
31 #include "mlir/IR/Visitors.h" // from @llvm-project
32 #include "mlir/Pass/Pass.h" // from @llvm-project
33 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
34 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
35 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
36 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
38 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
39 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
40 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
41
42 #define DEBUG_TYPE "tf-region-cf-to-functional"
43
44 namespace mlir {
45 namespace TF {
46
47 namespace {
48
// Names of the string attributes consulted on tf.IfRegion when choosing the
// names of the outlined else/then functions.
constexpr char kElseFuncNameAttr[] = "_else_func_name";
constexpr char kThenFuncNameAttr[] = "_then_func_name";
// Name of the boolean attribute that is forced to true on every functional op
// generated by this pass.
constexpr char kXlaPropagateCompileTimeConsts[] =
    "_xla_propagate_compile_time_consts";
53
// Module pass that rewrites tf.IfRegion and tf.WhileRegion ops into their
// functional counterparts (tf.If and tf.While), outlining region bodies into
// new functions whenever an existing function cannot be reused as is.
struct RegionControlFlowToFunctional
    : public TF::RegionControlFlowToFunctionalPassBase<
          RegionControlFlowToFunctional> {
  // Pass entry point: converts every tf.IfRegion/tf.WhileRegion reachable
  // from the functions of the module.
  void runOnOperation() override;

 private:
  // Replaces `if_region` with an equivalent tf.If op.
  LogicalResult ConvertIfOp(IfRegionOp if_region);
  // Replaces `while_region` with an equivalent tf.While op.
  LogicalResult ConvertWhileOp(WhileRegionOp while_region);

  // Get unique name by using the loc to name mapping, with `suffix` appended.
  std::string GetName(Operation* op, StringRef suffix);

  // Source of unique names for the outlined functions.
  tensorflow::OpOrArgLocNameMapper mapper;
  // Functions that still have to be scanned for region based control flow
  // ops. Outlined functions are appended to it as they are created.
  llvm::SmallVector<FuncOp, 4> worklist;
};
69
GetName(Operation * op,StringRef suffix)70 std::string RegionControlFlowToFunctional::GetName(Operation* op,
71 StringRef suffix) {
72 return (mapper.GetUniqueName(op) + suffix).str();
73 }
74
75 // Returns all the external values referenced from the given regions. If the
76 // external value is a constant, sink it into the region instead (and do not
77 // add it to the returned vector).
CollectExternValues(Region & first,Region & second)78 llvm::SmallVector<Value, 4> CollectExternValues(Region& first, Region& second) {
79 llvm::SetVector<Value> extern_values;
80
81 for (Region* region : {&first, &second}) {
82 llvm::SetVector<Value> region_extern_values;
83 getUsedValuesDefinedAbove(*region, region_extern_values);
84
85 // Sink down constants into the functions.
86 for (auto extern_value : region_extern_values) {
87 if (!matchPattern(extern_value, m_Constant())) {
88 extern_values.insert(extern_value);
89 continue;
90 }
91 // Add constant at start of region.
92 auto const_builder = OpBuilder::atBlockBegin(®ion->front());
93 auto const_value = const_builder.clone(*extern_value.getDefiningOp());
94 replaceAllUsesInRegionWith(extern_value, const_value->getResult(0),
95 *region);
96 }
97 }
98
99 return llvm::to_vector<4>(extern_values);
100 }
101
102 // Copies over optional attributes from source region op `src` to the given
103 // functional op `dst` and appropriately overrides any necessary attributes.
CopyAndOverrideAttributes(Operation * src,Operation * dst,OpBuilder * builder)104 void CopyAndOverrideAttributes(Operation* src, Operation* dst,
105 OpBuilder* builder) {
106 CopyDeviceAndUnderscoredAttributes(src, dst);
107
108 // Explicitly override attribute to propagate constants to the functions
109 // before compiling to XLA. This is necessary along with conversion to
110 // functional format because inlined regions may have moved loop invariant ops
111 // outside of the region which may cause some new legalization failures.
112 // TODO(b/126739593): Enable this attribute in TensorFlow by default. Also,
113 // see b/185542519 for the context.
114 dst->setAttr(kXlaPropagateCompileTimeConsts, builder->getBoolAttr(true));
115 }
116
117 // Extracts the contents of a region with a single block into a new function.
118 // `extern_values` is the set of external values that the region refers to.
119 //
120 // Inputs to the terminator of the region are converted to return values of
121 // the function. If `extern_values_passthrough` is true, all the extern values
122 // are also added as return values from the function
ExtractSingleBlockRegion(Region & region,StringRef name,llvm::SmallVectorImpl<Value> & extern_values,llvm::SmallVectorImpl<FuncOp> & worklist,bool extern_values_passthrough)123 void ExtractSingleBlockRegion(Region& region, StringRef name,
124 llvm::SmallVectorImpl<Value>& extern_values,
125 llvm::SmallVectorImpl<FuncOp>& worklist,
126 bool extern_values_passthrough) {
127 ModuleOp module = region.getParentOfType<ModuleOp>();
128 auto builder = OpBuilder::atBlockBegin(module.getBody());
129 auto loc = region.getParentOp()->getLoc();
130 Block& entry = region.front();
131 int num_region_arguments = entry.getNumArguments();
132 Operation* terminator = entry.getTerminator();
133
134 // Build the function type. Region arguments and extern values together
135 // become the function arguments, with region arguments going first.
136 auto input_types = llvm::to_vector<4>(entry.getArgumentTypes());
137 for (auto input : extern_values) input_types.push_back(input.getType());
138
139 // Terminator operands and pass through extern values (if enabled) together
140 // become the function return values.
141 auto return_types = llvm::to_vector<4>(terminator->getOperandTypes());
142 if (extern_values_passthrough)
143 for (auto input : extern_values) return_types.push_back(input.getType());
144
145 auto type = FunctionType::get(region.getContext(), input_types, return_types);
146
147 // Create new function and extract region body into the function.
148 auto outlined_func = builder.create<FuncOp>(loc, name, type);
149 Region& func_region = outlined_func.getBody();
150 func_region.takeBody(region);
151 Block& first_block = func_region.front();
152
153 // Replace all external uses with function arguments.
154 for (auto it : llvm::enumerate(extern_values)) {
155 Value arg = first_block.addArgument(it.value().getType());
156 replaceAllUsesInRegionWith(it.value(), arg, func_region);
157 }
158
159 // Function return values are all the terminator operands + pass through
160 // extern values (if enabled).
161 auto return_values = llvm::to_vector<4>(terminator->getOperands());
162 if (extern_values_passthrough)
163 return_values.insert(return_values.end(),
164 first_block.args_begin() + num_region_arguments,
165 first_block.args_end());
166
167 // Replace the existing terminator with a return.
168 terminator = first_block.getTerminator();
169 builder.setInsertionPoint(terminator);
170 builder.create<ReturnOp>(terminator->getLoc(), return_values);
171 terminator->erase();
172
173 outlined_func.setPrivate();
174
175 // Add the outlined function to the worklist in case its body has
176 // IfRegion or WhileRegion ops that need to converted.
177 worklist.push_back(outlined_func);
178 }
179
180 // Returns call for region with single call whose result feeds into the
181 // terminator of the region. if `allow_to_bool` is true, also allows a single
182 // ToBoolOp between the region yield and the call. Returns none if the region
183 // does not conform to this pattern.
IsSingleCallRegion(Region & region,bool allow_to_bool=false)184 llvm::Optional<CallOp> IsSingleCallRegion(Region& region,
185 bool allow_to_bool = false) {
186 if (!llvm::hasSingleElement(region)) return llvm::None;
187
188 Block& block = region.front();
189 auto it = block.rbegin();
190 YieldOp yield = dyn_cast<YieldOp>(*it++);
191
192 if (it == block.rend()) return llvm::None;
193
194 // Operation which is expected to consume all the call results.
195 Operation* call_consumer = yield;
196
197 // Allow a single ToBoolOp between the call and the yield (valid only
198 // when the yield has a single operand)
199 if (allow_to_bool && yield.getNumOperands() == 1 && isa<ToBoolOp>(*it)) {
200 if (it->getResult(0) != yield.getOperand(0)) return llvm::None;
201 call_consumer = cast<ToBoolOp>(*it);
202 it++;
203 }
204
205 // Check if there is a Call before the Yield.
206 CallOp call = dyn_cast<CallOp>(*it++);
207 if (!call) return llvm::None;
208
209 // All call results should feed into expected consumer
210 // All results of the call should feed into the yield.
211 if (call.getNumResults() != call_consumer->getNumOperands())
212 return llvm::None;
213
214 for (auto res_it : llvm::zip(call.getResults(), call_consumer->getOperands()))
215 if (std::get<0>(res_it) != std::get<1>(res_it)) return llvm::None;
216
217 // There can only be non-truncating cast op's prior to the call.
218 for (; it != block.rend(); ++it) {
219 CastOp cast = dyn_cast<CastOp>(*it);
220 if (!cast || cast.Truncate()) return llvm::None;
221 }
222
223 return call;
224 }
225
226 using ArgMatcherFn = function_ref<bool(Value, Region&, Value, Region&)>;
227
228 // Returns whether the arguments of the given 2 calls are match (after looking
229 // through cast ops). `matcher` is the predicate used to check if two arguments
230 // match.
MatchCallArgs(CallOp first,CallOp second,ArgMatcherFn matcher)231 bool MatchCallArgs(CallOp first, CallOp second, ArgMatcherFn matcher) {
232 if (first.getNumOperands() != second.getNumOperands()) return false;
233
234 Region& first_region = *first->getParentRegion();
235 Region& second_region = *second->getParentRegion();
236
237 for (auto it : llvm::zip(first.getArgOperands(), second.getArgOperands())) {
238 // Get the defining Op, skipping over casts.
239 auto get_defining_op = [](Value value) {
240 while (auto cast_op =
241 llvm::dyn_cast_or_null<CastOp>(value.getDefiningOp())) {
242 // Consider cast compatibility in case
243 // %cast = "tf.Cast"(%0) : (tensor<2xi64>) -> tensor<2xf32>
244 // is skipped.
245 if (cast_op.SrcT() != cast_op.DstT()) {
246 break;
247 }
248 value = cast_op.getOperand();
249 }
250 return value;
251 };
252 Value first_arg = get_defining_op(std::get<0>(it));
253 Value second_arg = get_defining_op(std::get<1>(it));
254
255 if (!matcher(first_arg, first_region, second_arg, second_region))
256 return false;
257 }
258 return true;
259 }
260
261 // Summary information for trivially transforming region based op's to
262 // functional ops. A trivial transformation can be done when the regions are
263 // just calls to functions, in which case no outlining is needed.
264 struct TrivialTransformInfo {
265 // Can the op be transformed trivially?
266 bool can_transform = false;
267
268 // List of callee names (one for each region).
269 llvm::SmallVector<StringRef, 2> callee_names;
270
271 // Analyzes the given calls (from regions attached to the same parent op) to
272 // check if the parent op be transformed to functional form trivially (i.e.,
273 // reusing existing functions and without outlining). This is possible when
274 // all the regions are single call regions (checked using matchers outside
275 // this class) and the all the calls match using the given argument matcher.
276 //
277 // If such a trivial transformation is possible, stash the relevant
278 // information needed for the transformation, else indicate that a trivial
279 // transformation is not possible by setting `can_transform` to false.
TrivialTransformInfomlir::TF::__anonc19eae910111::TrivialTransformInfo280 TrivialTransformInfo(llvm::Optional<CallOp> first_call,
281 llvm::Optional<CallOp> second_call,
282 ArgMatcherFn arg_matcher) {
283 if (!first_call || !second_call) return;
284
285 if (!MatchCallArgs(first_call.getValue(), second_call.getValue(),
286 arg_matcher))
287 return;
288
289 can_transform = true;
290 callee_names = {first_call.getValue().getCallee(),
291 second_call.getValue().getCallee()};
292 }
293 };
294
295 // Transform IfRegionOp to IfOp.
ConvertIfOp(IfRegionOp if_region)296 LogicalResult RegionControlFlowToFunctional::ConvertIfOp(IfRegionOp if_region) {
297 llvm::SmallVector<Value, 4> extern_values;
298
299 // For IfOp, arguments of calls in the then and else regions match if they
300 // are the same value.
301 auto if_arg_matcher = [&](Value first, Region&, Value second, Region&) {
302 if (first != second) return false;
303
304 // collect the call arguments post lookup through cast Op's
305 extern_values.push_back(first);
306 return true;
307 };
308
309 const TrivialTransformInfo tti(IsSingleCallRegion(if_region.then_branch()),
310 IsSingleCallRegion(if_region.else_branch()),
311 if_arg_matcher);
312
313 std::string then_name, else_name;
314
315 if (tti.can_transform) {
316 // We can transform to functional form trivially without outlining.
317 then_name = tti.callee_names[0].str();
318 else_name = tti.callee_names[1].str();
319 } else {
320 // Collect external values that are used within the else and then bodies.
321 extern_values =
322 CollectExternValues(if_region.then_branch(), if_region.else_branch());
323
324 // These external values need to be added as inputs to the generated If. The
325 // order is determined by the order of these values the `extern_vales`.
326
327 // Create 2 new functions with the input signature matching this order,
328 // and outline the `then` and `else` regions by moving the bodies of these
329 // regions into these functions. Replace tf.yield with a regular return.
330 if (if_region->hasAttrOfType<StringAttr>(kThenFuncNameAttr) &&
331 !if_region._then_func_nameAttr().getValue().empty()) {
332 then_name =
333 mapper.GetUniqueName(if_region._then_func_nameAttr().getValue())
334 .str();
335 } else {
336 then_name = GetName(if_region, "_then");
337 }
338 ExtractSingleBlockRegion(if_region.then_branch(), then_name, extern_values,
339 worklist, /*extern_values_passthrough=*/false);
340
341 if (if_region->hasAttrOfType<StringAttr>(kElseFuncNameAttr) &&
342 !if_region._else_func_nameAttr().getValue().empty()) {
343 else_name =
344 mapper.GetUniqueName(if_region._else_func_nameAttr().getValue())
345 .str();
346 } else {
347 else_name = GetName(if_region, "_else");
348 }
349 ExtractSingleBlockRegion(if_region.else_branch(), else_name, extern_values,
350 worklist, /*extern_values_passthrough=*/false);
351 }
352
353 // Look through ToBool operations for the condition.
354 Value cond = if_region.cond();
355 auto to_bool = dyn_cast_or_null<ToBoolOp>(cond.getDefiningOp());
356 if (to_bool) cond = to_bool.getOperand();
357
358 // Once we have the `then` and `else` functions ready (either outlined or
359 // existing ones), replace the region based op with a functional control flow
360 // op.
361 OpBuilder builder(if_region);
362 auto if_op = builder.create<IfOp>(
363 if_region.getLoc(), if_region.getResultTypes(), cond, extern_values,
364 then_name, else_name, if_region.is_stateless());
365 CopyAndOverrideAttributes(if_region, if_op, &builder);
366
367 if_region.replaceAllUsesWith(if_op.getResults());
368 if_region.erase();
369
370 if (to_bool && to_bool.use_empty()) to_bool.erase();
371 return success();
372 }
373
374 // Transform WhileRegion to WhileOp.
ConvertWhileOp(WhileRegionOp while_region)375 LogicalResult RegionControlFlowToFunctional::ConvertWhileOp(
376 WhileRegionOp while_region) {
377 // For While, the arguments of the calls in the body and cond regions match
378 // if they are region arguments with the same region argument numbers. If the
379 // 2 calls have the same value (an extern value) used as an argument, we
380 // cannot do a trivial transformation because post transform, we will need to
381 // pass this extern value as an argument to the function, so we cannot use the
382 // existing function as is.
383 auto while_arg_matcher = [](Value first, Region& first_region, Value second,
384 Region& second_region) {
385 if (!first.isa<BlockArgument>() || !second.isa<BlockArgument>())
386 return false;
387 BlockArgument first_block_arg = first.cast<BlockArgument>();
388 BlockArgument second_block_arg = second.cast<BlockArgument>();
389
390 // 2 block arguments will match if they are the same argument number, and
391 // are block arguments of the corresponding containing regions.
392 return first_block_arg.getArgNumber() == second_block_arg.getArgNumber() &&
393 first_block_arg.getParentBlock() == &first_region.front() &&
394 second_block_arg.getParentBlock() == &second_region.front();
395 };
396
397 const TrivialTransformInfo tti(
398 IsSingleCallRegion(while_region.cond(), /*allow_to_bool=*/true),
399 IsSingleCallRegion(while_region.body()), while_arg_matcher);
400
401 // All existing inputs to while region are inputs to the functional while.
402 auto new_inputs = llvm::to_vector<4>(while_region.getOperands());
403
404 // All existing results will also be generated by the functional while.
405 auto new_result_types = llvm::to_vector<4>(while_region.getResultTypes());
406
407 std::string cond_name, body_name;
408 if (tti.can_transform) {
409 // We can transform to functional form trivially without outlining.
410 cond_name = tti.callee_names[0].str();
411 body_name = tti.callee_names[1].str();
412 } else {
413 // The WhileRegion regions can refer to either arguments of the region, or
414 // external values implicitly captured by the region. When converting to
415 // functional form, all such external values need to become function
416 // arguments of the outlined functions, and become pass through values in
417 // the outlined body function. So when outlining the while body, in addition
418 // to the region arguments, all these external references need to be added
419 // as function arguments.
420 llvm::SmallVector<Value, 4> extern_values =
421 CollectExternValues(while_region.cond(), while_region.body());
422
423 // Outline the `cond` and `body` regions by moving the bodies of these
424 // regions into new functions. Replace tf.yield with a regular return.
425 cond_name = GetName(while_region, "_cond");
426 ExtractSingleBlockRegion(while_region.cond(), cond_name, extern_values,
427 worklist, /*extern_values_passthrough=*/false);
428
429 body_name = GetName(while_region, "_body");
430 ExtractSingleBlockRegion(while_region.body(), body_name, extern_values,
431 worklist, /*extern_values_passthrough=*/true);
432
433 // All extern values become additional inputs and additional output types
434 // for the functional while.
435 new_inputs.append(extern_values.begin(), extern_values.end());
436 for (auto ext : extern_values) new_result_types.push_back(ext.getType());
437 }
438
439 // Once we have the `cond` and `body` functions ready (either outlined or
440 // existing ones), replace the region based op with a functional op.
441 OpBuilder builder(while_region);
442 auto while_op = builder.create<WhileOp>(
443 while_region.getLoc(), new_result_types, new_inputs, cond_name, body_name,
444 while_region.parallel_iterations(), while_region.is_stateless(),
445 while_region.shape_invariant());
446 CopyAndOverrideAttributes(while_region, while_op, &builder);
447
448 // Redirect old results to new results.
449 for (auto it : llvm::zip(
450 while_region.getResults(),
451 while_op.getResults().take_front(while_region.getNumResults())))
452 std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
453
454 while_region.erase();
455 return success();
456 }
457
runOnOperation()458 void RegionControlFlowToFunctional::runOnOperation() {
459 ModuleOp module = getOperation();
460
461 // Seed worklist with all functions in the module.
462 worklist = llvm::to_vector<4>(module.getOps<FuncOp>());
463 while (!worklist.empty()) {
464 FuncOp function = worklist.pop_back_val();
465
466 auto result = function.walk([&](Operation* op) {
467 if (auto if_region = llvm::dyn_cast<IfRegionOp>(op)) {
468 if (failed(ConvertIfOp(if_region))) {
469 op->emitOpError() << "failed to convert to functional form";
470 return WalkResult::interrupt();
471 }
472 } else if (auto while_region = llvm::dyn_cast<WhileRegionOp>(op)) {
473 if (failed(ConvertWhileOp(while_region))) {
474 op->emitOpError() << "failed to convert to functional form";
475 return WalkResult::interrupt();
476 }
477 }
478 return WalkResult::advance();
479 });
480
481 if (result.wasInterrupted()) return signalPassFailure();
482 }
483 }
484
485 } // namespace
486
487 std::unique_ptr<OperationPass<ModuleOp>>
CreateTFRegionControlFlowToFunctional()488 CreateTFRegionControlFlowToFunctional() {
489 return std::make_unique<RegionControlFlowToFunctional>();
490 }
491
492 } // namespace TF
493 } // namespace mlir
494