• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This transformation pass transforms region-based control flow operations in
// the TensorFlow dialect to their functional counterparts, i.e.,
// tf.IfRegion ->  tf.If and tf.WhileRegion -> tf.While
19 
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Casting.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
24 #include "mlir/IR/Attributes.h"  // from @llvm-project
25 #include "mlir/IR/Builders.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
27 #include "mlir/IR/Operation.h"  // from @llvm-project
28 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
29 #include "mlir/IR/Value.h"  // from @llvm-project
30 #include "mlir/IR/Verifier.h"  // from @llvm-project
31 #include "mlir/IR/Visitors.h"  // from @llvm-project
32 #include "mlir/Pass/Pass.h"  // from @llvm-project
33 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
34 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
35 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
36 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
38 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
39 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
40 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
41 
42 #define DEBUG_TYPE "tf-region-cf-to-functional"
43 
44 namespace mlir {
45 namespace TF {
46 
47 namespace {
48 
// Optional attributes on tf.IfRegion; when present and non-empty they provide
// the base names for the outlined `then` and `else` functions.
constexpr char kThenFuncNameAttr[] = "_then_func_name";
constexpr char kElseFuncNameAttr[] = "_else_func_name";

// Attribute that is forced to `true` on every functional op created by this
// pass (see CopyAndOverrideAttributes).
constexpr char kXlaPropagateCompileTimeConsts[] =
    "_xla_propagate_compile_time_consts";
53 
54 struct RegionControlFlowToFunctional
55     : public TF::RegionControlFlowToFunctionalPassBase<
56           RegionControlFlowToFunctional> {
57   void runOnOperation() override;
58 
59  private:
60   LogicalResult ConvertIfOp(IfRegionOp if_region);
61   LogicalResult ConvertWhileOp(WhileRegionOp while_region);
62 
63   // Get unique name by using the loc to name mapping.
64   std::string GetName(Operation* op, StringRef suffix);
65 
66   tensorflow::OpOrArgLocNameMapper mapper;
67   llvm::SmallVector<FuncOp, 4> worklist;
68 };
69 
GetName(Operation * op,StringRef suffix)70 std::string RegionControlFlowToFunctional::GetName(Operation* op,
71                                                    StringRef suffix) {
72   return (mapper.GetUniqueName(op) + suffix).str();
73 }
74 
75 // Returns all the external values referenced from the given regions. If the
76 // external value is a constant, sink it into the region instead (and do not
77 // add it to the returned vector).
CollectExternValues(Region & first,Region & second)78 llvm::SmallVector<Value, 4> CollectExternValues(Region& first, Region& second) {
79   llvm::SetVector<Value> extern_values;
80 
81   for (Region* region : {&first, &second}) {
82     llvm::SetVector<Value> region_extern_values;
83     getUsedValuesDefinedAbove(*region, region_extern_values);
84 
85     // Sink down constants into the functions.
86     for (auto extern_value : region_extern_values) {
87       if (!matchPattern(extern_value, m_Constant())) {
88         extern_values.insert(extern_value);
89         continue;
90       }
91       // Add constant at start of region.
92       auto const_builder = OpBuilder::atBlockBegin(&region->front());
93       auto const_value = const_builder.clone(*extern_value.getDefiningOp());
94       replaceAllUsesInRegionWith(extern_value, const_value->getResult(0),
95                                  *region);
96     }
97   }
98 
99   return llvm::to_vector<4>(extern_values);
100 }
101 
102 // Copies over optional attributes from source region op `src` to the given
103 // functional op `dst` and appropriately overrides any necessary attributes.
CopyAndOverrideAttributes(Operation * src,Operation * dst,OpBuilder * builder)104 void CopyAndOverrideAttributes(Operation* src, Operation* dst,
105                                OpBuilder* builder) {
106   CopyDeviceAndUnderscoredAttributes(src, dst);
107 
108   // Explicitly override attribute to propagate constants to the functions
109   // before compiling to XLA. This is necessary along with conversion to
110   // functional format because inlined regions may have moved loop invariant ops
111   // outside of the region which may cause some new legalization failures.
112   // TODO(b/126739593): Enable this attribute in TensorFlow by default. Also,
113   // see b/185542519 for the context.
114   dst->setAttr(kXlaPropagateCompileTimeConsts, builder->getBoolAttr(true));
115 }
116 
117 // Extracts the contents of a region with a single block into a new function.
118 // `extern_values` is the set of external values that the region refers to.
119 //
120 // Inputs to the terminator of the region are converted to return values of
121 // the function. If `extern_values_passthrough` is true, all the extern values
122 // are also added as return values from the function
ExtractSingleBlockRegion(Region & region,StringRef name,llvm::SmallVectorImpl<Value> & extern_values,llvm::SmallVectorImpl<FuncOp> & worklist,bool extern_values_passthrough)123 void ExtractSingleBlockRegion(Region& region, StringRef name,
124                               llvm::SmallVectorImpl<Value>& extern_values,
125                               llvm::SmallVectorImpl<FuncOp>& worklist,
126                               bool extern_values_passthrough) {
127   ModuleOp module = region.getParentOfType<ModuleOp>();
128   auto builder = OpBuilder::atBlockBegin(module.getBody());
129   auto loc = region.getParentOp()->getLoc();
130   Block& entry = region.front();
131   int num_region_arguments = entry.getNumArguments();
132   Operation* terminator = entry.getTerminator();
133 
134   // Build the function type. Region arguments and extern values together
135   // become the function arguments, with region arguments going first.
136   auto input_types = llvm::to_vector<4>(entry.getArgumentTypes());
137   for (auto input : extern_values) input_types.push_back(input.getType());
138 
139   // Terminator operands and pass through extern values (if enabled) together
140   // become the function return values.
141   auto return_types = llvm::to_vector<4>(terminator->getOperandTypes());
142   if (extern_values_passthrough)
143     for (auto input : extern_values) return_types.push_back(input.getType());
144 
145   auto type = FunctionType::get(region.getContext(), input_types, return_types);
146 
147   // Create new function and extract region body into the function.
148   auto outlined_func = builder.create<FuncOp>(loc, name, type);
149   Region& func_region = outlined_func.getBody();
150   func_region.takeBody(region);
151   Block& first_block = func_region.front();
152 
153   // Replace all external uses with function arguments.
154   for (auto it : llvm::enumerate(extern_values)) {
155     Value arg = first_block.addArgument(it.value().getType());
156     replaceAllUsesInRegionWith(it.value(), arg, func_region);
157   }
158 
159   // Function return values are all the terminator operands + pass through
160   // extern values (if enabled).
161   auto return_values = llvm::to_vector<4>(terminator->getOperands());
162   if (extern_values_passthrough)
163     return_values.insert(return_values.end(),
164                          first_block.args_begin() + num_region_arguments,
165                          first_block.args_end());
166 
167   // Replace the existing terminator with a return.
168   terminator = first_block.getTerminator();
169   builder.setInsertionPoint(terminator);
170   builder.create<ReturnOp>(terminator->getLoc(), return_values);
171   terminator->erase();
172 
173   outlined_func.setPrivate();
174 
175   // Add the outlined function to the worklist in case its body has
176   // IfRegion or WhileRegion ops that need to converted.
177   worklist.push_back(outlined_func);
178 }
179 
180 // Returns call for region with single call whose result feeds into the
181 // terminator of the region. if `allow_to_bool` is true, also allows a single
182 // ToBoolOp between the region yield and the call. Returns none if the region
183 // does not conform to this pattern.
IsSingleCallRegion(Region & region,bool allow_to_bool=false)184 llvm::Optional<CallOp> IsSingleCallRegion(Region& region,
185                                           bool allow_to_bool = false) {
186   if (!llvm::hasSingleElement(region)) return llvm::None;
187 
188   Block& block = region.front();
189   auto it = block.rbegin();
190   YieldOp yield = dyn_cast<YieldOp>(*it++);
191 
192   if (it == block.rend()) return llvm::None;
193 
194   // Operation which is expected to consume all the call results.
195   Operation* call_consumer = yield;
196 
197   // Allow a single ToBoolOp between the call and the yield (valid only
198   // when the yield has a single operand)
199   if (allow_to_bool && yield.getNumOperands() == 1 && isa<ToBoolOp>(*it)) {
200     if (it->getResult(0) != yield.getOperand(0)) return llvm::None;
201     call_consumer = cast<ToBoolOp>(*it);
202     it++;
203   }
204 
205   // Check if there is a Call before the Yield.
206   CallOp call = dyn_cast<CallOp>(*it++);
207   if (!call) return llvm::None;
208 
209   // All call results should feed into expected consumer
210   // All results of the call should feed into the yield.
211   if (call.getNumResults() != call_consumer->getNumOperands())
212     return llvm::None;
213 
214   for (auto res_it : llvm::zip(call.getResults(), call_consumer->getOperands()))
215     if (std::get<0>(res_it) != std::get<1>(res_it)) return llvm::None;
216 
217   // There can only be non-truncating cast op's prior to the call.
218   for (; it != block.rend(); ++it) {
219     CastOp cast = dyn_cast<CastOp>(*it);
220     if (!cast || cast.Truncate()) return llvm::None;
221   }
222 
223   return call;
224 }
225 
226 using ArgMatcherFn = function_ref<bool(Value, Region&, Value, Region&)>;
227 
228 // Returns whether the arguments of the given 2 calls are match (after looking
229 // through cast ops). `matcher` is the predicate used to check if two arguments
230 // match.
MatchCallArgs(CallOp first,CallOp second,ArgMatcherFn matcher)231 bool MatchCallArgs(CallOp first, CallOp second, ArgMatcherFn matcher) {
232   if (first.getNumOperands() != second.getNumOperands()) return false;
233 
234   Region& first_region = *first->getParentRegion();
235   Region& second_region = *second->getParentRegion();
236 
237   for (auto it : llvm::zip(first.getArgOperands(), second.getArgOperands())) {
238     // Get the defining Op, skipping over casts.
239     auto get_defining_op = [](Value value) {
240       while (auto cast_op =
241                  llvm::dyn_cast_or_null<CastOp>(value.getDefiningOp())) {
242         // Consider cast compatibility in case
243         //    %cast = "tf.Cast"(%0) : (tensor<2xi64>) -> tensor<2xf32>
244         // is skipped.
245         if (cast_op.SrcT() != cast_op.DstT()) {
246           break;
247         }
248         value = cast_op.getOperand();
249       }
250       return value;
251     };
252     Value first_arg = get_defining_op(std::get<0>(it));
253     Value second_arg = get_defining_op(std::get<1>(it));
254 
255     if (!matcher(first_arg, first_region, second_arg, second_region))
256       return false;
257   }
258   return true;
259 }
260 
261 // Summary information for trivially transforming region based op's to
262 // functional ops. A trivial transformation can be done when the regions are
263 // just calls to functions, in which case no outlining is needed.
264 struct TrivialTransformInfo {
265   // Can the op be transformed trivially?
266   bool can_transform = false;
267 
268   // List of callee names (one for each region).
269   llvm::SmallVector<StringRef, 2> callee_names;
270 
271   // Analyzes the given calls (from regions attached to the same parent op) to
272   // check if the parent op be transformed to functional form trivially (i.e.,
273   // reusing existing functions and without outlining). This is possible when
274   // all the regions are single call regions (checked using matchers outside
275   // this class) and the all the calls match using the given argument matcher.
276   //
277   // If such a trivial transformation is possible, stash the relevant
278   // information needed for the transformation, else indicate that a trivial
279   // transformation is not possible by setting `can_transform` to false.
TrivialTransformInfomlir::TF::__anonc19eae910111::TrivialTransformInfo280   TrivialTransformInfo(llvm::Optional<CallOp> first_call,
281                        llvm::Optional<CallOp> second_call,
282                        ArgMatcherFn arg_matcher) {
283     if (!first_call || !second_call) return;
284 
285     if (!MatchCallArgs(first_call.getValue(), second_call.getValue(),
286                        arg_matcher))
287       return;
288 
289     can_transform = true;
290     callee_names = {first_call.getValue().getCallee(),
291                     second_call.getValue().getCallee()};
292   }
293 };
294 
295 // Transform IfRegionOp to IfOp.
ConvertIfOp(IfRegionOp if_region)296 LogicalResult RegionControlFlowToFunctional::ConvertIfOp(IfRegionOp if_region) {
297   llvm::SmallVector<Value, 4> extern_values;
298 
299   // For IfOp, arguments of calls in the then and else regions match if they
300   // are the same value.
301   auto if_arg_matcher = [&](Value first, Region&, Value second, Region&) {
302     if (first != second) return false;
303 
304     // collect the call arguments post lookup through cast Op's
305     extern_values.push_back(first);
306     return true;
307   };
308 
309   const TrivialTransformInfo tti(IsSingleCallRegion(if_region.then_branch()),
310                                  IsSingleCallRegion(if_region.else_branch()),
311                                  if_arg_matcher);
312 
313   std::string then_name, else_name;
314 
315   if (tti.can_transform) {
316     // We can transform to functional form trivially without outlining.
317     then_name = tti.callee_names[0].str();
318     else_name = tti.callee_names[1].str();
319   } else {
320     // Collect external values that are used within the else and then bodies.
321     extern_values =
322         CollectExternValues(if_region.then_branch(), if_region.else_branch());
323 
324     // These external values need to be added as inputs to the generated If. The
325     // order is determined by the order of these values the `extern_vales`.
326 
327     // Create 2 new functions with the input signature matching this order,
328     // and outline the `then` and `else` regions by moving the bodies of these
329     // regions into these functions. Replace tf.yield with a regular return.
330     if (if_region->hasAttrOfType<StringAttr>(kThenFuncNameAttr) &&
331         !if_region._then_func_nameAttr().getValue().empty()) {
332       then_name =
333           mapper.GetUniqueName(if_region._then_func_nameAttr().getValue())
334               .str();
335     } else {
336       then_name = GetName(if_region, "_then");
337     }
338     ExtractSingleBlockRegion(if_region.then_branch(), then_name, extern_values,
339                              worklist, /*extern_values_passthrough=*/false);
340 
341     if (if_region->hasAttrOfType<StringAttr>(kElseFuncNameAttr) &&
342         !if_region._else_func_nameAttr().getValue().empty()) {
343       else_name =
344           mapper.GetUniqueName(if_region._else_func_nameAttr().getValue())
345               .str();
346     } else {
347       else_name = GetName(if_region, "_else");
348     }
349     ExtractSingleBlockRegion(if_region.else_branch(), else_name, extern_values,
350                              worklist, /*extern_values_passthrough=*/false);
351   }
352 
353   // Look through ToBool operations for the condition.
354   Value cond = if_region.cond();
355   auto to_bool = dyn_cast_or_null<ToBoolOp>(cond.getDefiningOp());
356   if (to_bool) cond = to_bool.getOperand();
357 
358   // Once we have the `then` and `else` functions ready (either outlined or
359   // existing ones), replace the region based op with a functional control flow
360   // op.
361   OpBuilder builder(if_region);
362   auto if_op = builder.create<IfOp>(
363       if_region.getLoc(), if_region.getResultTypes(), cond, extern_values,
364       then_name, else_name, if_region.is_stateless());
365   CopyAndOverrideAttributes(if_region, if_op, &builder);
366 
367   if_region.replaceAllUsesWith(if_op.getResults());
368   if_region.erase();
369 
370   if (to_bool && to_bool.use_empty()) to_bool.erase();
371   return success();
372 }
373 
374 // Transform WhileRegion to WhileOp.
ConvertWhileOp(WhileRegionOp while_region)375 LogicalResult RegionControlFlowToFunctional::ConvertWhileOp(
376     WhileRegionOp while_region) {
377   // For While, the arguments of the calls in the body and cond regions match
378   // if they are region arguments with the same region argument numbers. If the
379   // 2 calls have the same value (an extern value) used as an argument, we
380   // cannot do a trivial transformation because post transform, we will need to
381   // pass this extern value as an argument to the function, so we cannot use the
382   // existing function as is.
383   auto while_arg_matcher = [](Value first, Region& first_region, Value second,
384                               Region& second_region) {
385     if (!first.isa<BlockArgument>() || !second.isa<BlockArgument>())
386       return false;
387     BlockArgument first_block_arg = first.cast<BlockArgument>();
388     BlockArgument second_block_arg = second.cast<BlockArgument>();
389 
390     // 2 block arguments will match if they are the same argument number, and
391     // are block arguments of the corresponding containing regions.
392     return first_block_arg.getArgNumber() == second_block_arg.getArgNumber() &&
393            first_block_arg.getParentBlock() == &first_region.front() &&
394            second_block_arg.getParentBlock() == &second_region.front();
395   };
396 
397   const TrivialTransformInfo tti(
398       IsSingleCallRegion(while_region.cond(), /*allow_to_bool=*/true),
399       IsSingleCallRegion(while_region.body()), while_arg_matcher);
400 
401   // All existing inputs to while region are inputs to the functional while.
402   auto new_inputs = llvm::to_vector<4>(while_region.getOperands());
403 
404   // All existing results will also be generated by the functional while.
405   auto new_result_types = llvm::to_vector<4>(while_region.getResultTypes());
406 
407   std::string cond_name, body_name;
408   if (tti.can_transform) {
409     // We can transform to functional form trivially without outlining.
410     cond_name = tti.callee_names[0].str();
411     body_name = tti.callee_names[1].str();
412   } else {
413     // The WhileRegion regions can refer to either arguments of the region, or
414     // external values implicitly captured by the region. When converting to
415     // functional form, all such external values need to become function
416     // arguments of the outlined functions, and become pass through values in
417     // the outlined body function. So when outlining the while body, in addition
418     // to the region arguments, all these external references need to be added
419     // as function arguments.
420     llvm::SmallVector<Value, 4> extern_values =
421         CollectExternValues(while_region.cond(), while_region.body());
422 
423     // Outline the `cond` and `body` regions by moving the bodies of these
424     // regions into new functions. Replace tf.yield with a regular return.
425     cond_name = GetName(while_region, "_cond");
426     ExtractSingleBlockRegion(while_region.cond(), cond_name, extern_values,
427                              worklist, /*extern_values_passthrough=*/false);
428 
429     body_name = GetName(while_region, "_body");
430     ExtractSingleBlockRegion(while_region.body(), body_name, extern_values,
431                              worklist, /*extern_values_passthrough=*/true);
432 
433     // All extern values become additional inputs and additional output types
434     // for the functional while.
435     new_inputs.append(extern_values.begin(), extern_values.end());
436     for (auto ext : extern_values) new_result_types.push_back(ext.getType());
437   }
438 
439   // Once we have the `cond` and `body` functions ready (either outlined or
440   // existing ones), replace the region based op with a functional op.
441   OpBuilder builder(while_region);
442   auto while_op = builder.create<WhileOp>(
443       while_region.getLoc(), new_result_types, new_inputs, cond_name, body_name,
444       while_region.parallel_iterations(), while_region.is_stateless(),
445       while_region.shape_invariant());
446   CopyAndOverrideAttributes(while_region, while_op, &builder);
447 
448   // Redirect old results to new results.
449   for (auto it : llvm::zip(
450            while_region.getResults(),
451            while_op.getResults().take_front(while_region.getNumResults())))
452     std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
453 
454   while_region.erase();
455   return success();
456 }
457 
runOnOperation()458 void RegionControlFlowToFunctional::runOnOperation() {
459   ModuleOp module = getOperation();
460 
461   // Seed worklist with all functions in the module.
462   worklist = llvm::to_vector<4>(module.getOps<FuncOp>());
463   while (!worklist.empty()) {
464     FuncOp function = worklist.pop_back_val();
465 
466     auto result = function.walk([&](Operation* op) {
467       if (auto if_region = llvm::dyn_cast<IfRegionOp>(op)) {
468         if (failed(ConvertIfOp(if_region))) {
469           op->emitOpError() << "failed to convert to functional form";
470           return WalkResult::interrupt();
471         }
472       } else if (auto while_region = llvm::dyn_cast<WhileRegionOp>(op)) {
473         if (failed(ConvertWhileOp(while_region))) {
474           op->emitOpError() << "failed to convert to functional form";
475           return WalkResult::interrupt();
476         }
477       }
478       return WalkResult::advance();
479     });
480 
481     if (result.wasInterrupted()) return signalPassFailure();
482   }
483 }
484 
485 }  // namespace
486 
487 std::unique_ptr<OperationPass<ModuleOp>>
CreateTFRegionControlFlowToFunctional()488 CreateTFRegionControlFlowToFunctional() {
489   return std::make_unique<RegionControlFlowToFunctional>();
490 }
491 
492 }  // namespace TF
493 }  // namespace mlir
494