• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This transformation pass transforms functional control flow operations in the
17 // TensorFlow dialect to their region based counterparts, i.e.,
18 // tf.If -> tf.IfRegion and tf.While -> tf.WhileRegion
19 
20 #include "llvm/Support/raw_ostream.h"
21 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
22 #include "mlir/IR/Attributes.h"  // from @llvm-project
23 #include "mlir/IR/Builders.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
26 #include "mlir/IR/Operation.h"  // from @llvm-project
27 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
28 #include "mlir/IR/Value.h"  // from @llvm-project
29 #include "mlir/IR/Verifier.h"  // from @llvm-project
30 #include "mlir/IR/Visitors.h"  // from @llvm-project
31 #include "mlir/Pass/Pass.h"  // from @llvm-project
32 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
35 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
36 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
37 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
38 
39 #define DEBUG_TYPE "tf-functional-cf-to-region"
40 
41 namespace mlir {
42 namespace TF {
43 
44 namespace {
45 
46 struct FunctionalControlFlowToRegions
47     : public TF::FunctionalControlFlowToRegionsPassBase<
48           FunctionalControlFlowToRegions> {
49   void runOnOperation() override;
50 };
51 
52 // Creates a call to function `func` in region `caller_region`. Use `args` as
53 // the call arguments, and terminate the region with a yield. The arguments are
54 // cast to the required type before the call. `use_region_args` control whether
55 // the input arguments are used as is (for IfOp) or block arguments of the same
56 // type as the input arguments are created and then used as call arguments (for
57 // While).
CreateCall(Operation * op,FuncOp func,Region & caller_region,ValueRange args,bool use_region_args)58 YieldOp CreateCall(Operation* op, FuncOp func, Region& caller_region,
59                    ValueRange args, bool use_region_args) {
60   assert(caller_region.empty() &&
61          "Expected empty region for newly created ops");
62   OpBuilder builder(caller_region);
63   Block* entry = builder.createBlock(&caller_region);
64 
65   if (use_region_args) {
66     entry->addArguments(func.getType().getInputs());
67     args = entry->getArguments();
68   }
69   llvm::SmallVector<Value, 4> casted_args;
70   casted_args.reserve(func.getNumArguments());
71   for (const auto& ArgAndType : zip(args, func.getType().getInputs())) {
72     Value arg = std::get<0>(ArgAndType);
73     Type expected_type = std::get<1>(ArgAndType);
74     if (arg.getType() != expected_type) {
75       arg = builder.create<CastOp>(op->getLoc(), expected_type, arg,
76                                    /*Truncate=*/builder.getBoolAttr(false));
77     }
78     casted_args.push_back(arg);
79   }
80   auto call = builder.create<CallOp>(op->getLoc(), func, casted_args);
81   return builder.create<YieldOp>(op->getLoc(), call.getResults());
82 }
83 
84 // Converts the condition for an IfOp/WhileOp to a boolean value.
ConvertConditionToBoolean(Operation * op,Value cond)85 Value ConvertConditionToBoolean(Operation* op, Value cond) {
86   if (auto ranked_type = cond.getType().dyn_cast<RankedTensorType>())
87     if (ranked_type.getRank() == 0 &&
88         ranked_type.getElementType().isSignlessInteger(1))
89       return cond;
90 
91   OpBuilder builder(op);
92   return builder.create<TF::ToBoolOp>(op->getLoc(), cond);
93 }
94 
95 // Transform a functional IfOp to a region based IfRegionOp.
ConvertIfOp(IfOp if_op)96 LogicalResult ConvertIfOp(IfOp if_op) {
97   Value cond = ConvertConditionToBoolean(if_op, if_op.cond());
98   OpBuilder builder(if_op);
99   auto if_region = builder.create<TF::IfRegionOp>(
100       if_op.getLoc(), if_op.getResultTypes(), cond, if_op.is_stateless(),
101       builder.getStringAttr(if_op.then_function().getName()),
102       builder.getStringAttr(if_op.else_function().getName()));
103   CopyDeviceAndUnderscoredAttributes(if_op, if_region);
104 
105   CreateCall(if_op, if_op.then_function(),
106              /*caller_region=*/if_region.then_branch(), if_op.input(),
107              /*use_region_args=*/false);
108   CreateCall(if_op, if_op.else_function(),
109              /*caller_region=*/if_region.else_branch(), if_op.input(),
110              /*use_region_args=*/false);
111   if_op.replaceAllUsesWith(if_region.getResults());
112   if_op.erase();
113   return success();
114 }
115 
ConvertWhileOp(WhileOp while_op)116 LogicalResult ConvertWhileOp(WhileOp while_op) {
117   auto while_region = OpBuilder(while_op).create<TF::WhileRegionOp>(
118       while_op.getLoc(), while_op.getResultTypes(), while_op.input(),
119       while_op.parallel_iterations(), while_op.is_stateless(),
120       while_op.shape_invariant());
121   CopyDeviceAndUnderscoredAttributes(while_op, while_region);
122 
123   YieldOp cond_yield =
124       CreateCall(while_op, while_op.cond_function(),
125                  /*caller_region=*/while_region.cond(), while_op.input(),
126                  /*use_region_args=*/true);
127   Value i1_cond =
128       ConvertConditionToBoolean(cond_yield, cond_yield.getOperand(0));
129   cond_yield.setOperand(0, i1_cond);
130 
131   CreateCall(while_op, while_op.body_function(),
132              /*caller_region=*/while_region.body(), while_op.input(),
133              /*use_region_args=*/true);
134   while_op.replaceAllUsesWith(while_region.getResults());
135   while_op.erase();
136   return success();
137 }
138 
runOnOperation()139 void FunctionalControlFlowToRegions::runOnOperation() {
140   ModuleOp module = getOperation();
141   auto result = module.walk([](Operation* op) {
142     if (IfOp if_op = llvm::dyn_cast<IfOp>(op)) {
143       if (failed(ConvertIfOp(if_op))) {
144         op->emitOpError() << "failed to convert to region form";
145         return WalkResult::interrupt();
146       }
147     } else if (auto while_op = llvm::dyn_cast<WhileOp>(op)) {
148       if (failed(ConvertWhileOp(while_op))) {
149         op->emitOpError() << "failed to convert to region form";
150         return WalkResult::interrupt();
151       }
152     }
153     return WalkResult::advance();
154   });
155   if (result.wasInterrupted()) return signalPassFailure();
156 }
157 }  // namespace
158 
159 std::unique_ptr<OperationPass<ModuleOp>>
CreateTFFunctionalControlFlowToRegions()160 CreateTFFunctionalControlFlowToRegions() {
161   return std::make_unique<FunctionalControlFlowToRegions>();
162 }
163 
164 }  // namespace TF
165 }  // namespace mlir
166