1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This transformation pass transforms functional control flow operations in the
17 // TensorFlow dialect to their region based counterparts, i.e.,
18 // tf.If -> tf.IfRegion and tf.While -> tf.WhileRegion
19
20 #include "llvm/Support/raw_ostream.h"
21 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
22 #include "mlir/IR/Attributes.h" // from @llvm-project
23 #include "mlir/IR/Builders.h" // from @llvm-project
24 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
25 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
26 #include "mlir/IR/Operation.h" // from @llvm-project
27 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
28 #include "mlir/IR/Value.h" // from @llvm-project
29 #include "mlir/IR/Verifier.h" // from @llvm-project
30 #include "mlir/IR/Visitors.h" // from @llvm-project
31 #include "mlir/Pass/Pass.h" // from @llvm-project
32 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
35 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
36 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
37 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
38
39 #define DEBUG_TYPE "tf-functional-cf-to-region"
40
41 namespace mlir {
42 namespace TF {
43
44 namespace {
45
46 struct FunctionalControlFlowToRegions
47 : public TF::FunctionalControlFlowToRegionsPassBase<
48 FunctionalControlFlowToRegions> {
49 void runOnOperation() override;
50 };
51
52 // Creates a call to function `func` in region `caller_region`. Use `args` as
53 // the call arguments, and terminate the region with a yield. The arguments are
54 // cast to the required type before the call. `use_region_args` control whether
55 // the input arguments are used as is (for IfOp) or block arguments of the same
56 // type as the input arguments are created and then used as call arguments (for
57 // While).
CreateCall(Operation * op,FuncOp func,Region & caller_region,ValueRange args,bool use_region_args)58 YieldOp CreateCall(Operation* op, FuncOp func, Region& caller_region,
59 ValueRange args, bool use_region_args) {
60 assert(caller_region.empty() &&
61 "Expected empty region for newly created ops");
62 OpBuilder builder(caller_region);
63 Block* entry = builder.createBlock(&caller_region);
64
65 if (use_region_args) {
66 entry->addArguments(args.getType());
67 args = entry->getArguments();
68 }
69 llvm::SmallVector<Value, 4> casted_args;
70 casted_args.reserve(func.getNumArguments());
71 for (const auto& ArgAndType : zip(args, func.getType().getInputs())) {
72 Value arg = std::get<0>(ArgAndType);
73 Type expected_type = std::get<1>(ArgAndType);
74 if (arg.getType() != expected_type) {
75 arg = builder.create<CastOp>(op->getLoc(), expected_type, arg,
76 /*Truncate=*/builder.getBoolAttr(false));
77 }
78 casted_args.push_back(arg);
79 }
80 auto call = builder.create<CallOp>(op->getLoc(), func, casted_args);
81 return builder.create<YieldOp>(op->getLoc(), call.getResults());
82 }
83
84 // Converts the condition for an IfOp/WhileOp to a boolean value.
ConvertConditionToBoolean(Operation * op,Value cond)85 Value ConvertConditionToBoolean(Operation* op, Value cond) {
86 if (auto ranked_type = cond.getType().dyn_cast<RankedTensorType>())
87 if (ranked_type.getRank() == 0 &&
88 ranked_type.getElementType().isSignlessInteger(1))
89 return cond;
90
91 OpBuilder builder(op);
92 return builder.create<TF::ToBoolOp>(op->getLoc(), cond);
93 }
94
95 // Transform a functional IfOp to a region based IfRegionOp.
ConvertIfOp(IfOp if_op)96 LogicalResult ConvertIfOp(IfOp if_op) {
97 Value cond = ConvertConditionToBoolean(if_op, if_op.cond());
98 OpBuilder builder(if_op);
99 auto if_region = builder.create<TF::IfRegionOp>(
100 if_op.getLoc(), if_op.getResultTypes(), cond, if_op.is_stateless(),
101 builder.getStringAttr(if_op.then_function().getName()),
102 builder.getStringAttr(if_op.else_function().getName()));
103 CopyDeviceAndUnderscoredAttributes(if_op, if_region);
104
105 CreateCall(if_op, if_op.then_function(),
106 /*caller_region=*/if_region.then_branch(), if_op.input(),
107 /*use_region_args=*/false);
108 CreateCall(if_op, if_op.else_function(),
109 /*caller_region=*/if_region.else_branch(), if_op.input(),
110 /*use_region_args=*/false);
111 if_op.replaceAllUsesWith(if_region.getResults());
112 if_op.erase();
113 return success();
114 }
115
ConvertWhileOp(WhileOp while_op)116 LogicalResult ConvertWhileOp(WhileOp while_op) {
117 auto while_region = OpBuilder(while_op).create<TF::WhileRegionOp>(
118 while_op.getLoc(), while_op.getResultTypes(), while_op.input(),
119 while_op.parallel_iterations(), while_op.is_stateless(),
120 while_op.shape_invariant());
121 CopyDeviceAndUnderscoredAttributes(while_op, while_region);
122
123 YieldOp cond_yield =
124 CreateCall(while_op, while_op.cond_function(),
125 /*caller_region=*/while_region.cond(), while_op.input(),
126 /*use_region_args=*/true);
127 Value i1_cond =
128 ConvertConditionToBoolean(cond_yield, cond_yield.getOperand(0));
129 cond_yield.setOperand(0, i1_cond);
130
131 CreateCall(while_op, while_op.body_function(),
132 /*caller_region=*/while_region.body(), while_op.input(),
133 /*use_region_args=*/true);
134 while_op.replaceAllUsesWith(while_region.getResults());
135 while_op.erase();
136 return success();
137 }
138
runOnOperation()139 void FunctionalControlFlowToRegions::runOnOperation() {
140 ModuleOp module = getOperation();
141 auto result = module.walk([](Operation* op) {
142 if (IfOp if_op = llvm::dyn_cast<IfOp>(op)) {
143 if (failed(ConvertIfOp(if_op))) {
144 op->emitOpError() << "failed to convert to region form";
145 return WalkResult::interrupt();
146 }
147 } else if (auto while_op = llvm::dyn_cast<WhileOp>(op)) {
148 if (failed(ConvertWhileOp(while_op))) {
149 op->emitOpError() << "failed to convert to region form";
150 return WalkResult::interrupt();
151 }
152 }
153 return WalkResult::advance();
154 });
155 if (result.wasInterrupted()) return signalPassFailure();
156 }
157 } // namespace
158
159 std::unique_ptr<OperationPass<ModuleOp>>
CreateTFFunctionalControlFlowToRegions()160 CreateTFFunctionalControlFlowToRegions() {
161 return std::make_unique<FunctionalControlFlowToRegions>();
162 }
163
164 } // namespace TF
165 } // namespace mlir
166