1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/SmallVector.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
18 #include "mlir/IR/Attributes.h" // from @llvm-project
19 #include "mlir/IR/Block.h" // from @llvm-project
20 #include "mlir/IR/Builders.h" // from @llvm-project
21 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
22 #include "mlir/IR/Operation.h" // from @llvm-project
23 #include "mlir/Pass/Pass.h" // from @llvm-project
24 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
25 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
28 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
29 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
30
31 namespace mlir {
32 namespace TFDevice {
33
34 namespace {
35
36 constexpr char kFuncAttr[] = "func";
37
38 struct ClusterOutliningPass
39 : public TF::ClusterOutliningPassBase<ClusterOutliningPass> {
40 void runOnOperation() override;
41 };
42
43 struct LaunchOutliningPass
44 : public TF::LaunchOutliningPassBase<LaunchOutliningPass> {
45 void runOnOperation() override;
46 };
47
ReplaceClusterReturnWithReturn(tf_device::ReturnOp cluster_return_op,OpBuilder * builder)48 void ReplaceClusterReturnWithReturn(tf_device::ReturnOp cluster_return_op,
49 OpBuilder* builder) {
50 builder->create<ReturnOp>(cluster_return_op.getLoc(),
51 cluster_return_op.getOperands());
52 cluster_return_op.erase();
53 }
54
55 // Builds a function that outlines region attached to cluster_op or launch_op,
56 // and inserts built function into given module.
57 template <typename ClusterOrLaunchOp>
BuildFunction(llvm::ArrayRef<Value> live_ins,ClusterOrLaunchOp op,SymbolTable * symbol_table,OpBuilder * builder)58 FuncOp BuildFunction(llvm::ArrayRef<Value> live_ins, ClusterOrLaunchOp op,
59 SymbolTable* symbol_table, OpBuilder* builder) {
60 llvm::SmallVector<Type, 4> operand_types;
61 operand_types.reserve(live_ins.size());
62 for (Value v : live_ins) operand_types.emplace_back(v.getType());
63
64 auto func_type = builder->getFunctionType(operand_types, op.getResultTypes());
65
66 // TODO(lyandy): Define better name for outlined function. Potentially some
67 // name can be added during cluster formation.
68 FuncOp outlined_func = FuncOp::create(op.getLoc(), "_func", func_type);
69
70 // This function is not externally visible and marking it private would allow
71 // symbol-dce pass to remove it when it is not referenced anymore.
72 outlined_func.setPrivate();
73
74 // Create function body.
75 Block* outlined_func_block = outlined_func.addEntryBlock();
76
77 // Replace uses of live-in values within cluster_op region with function
78 // arguments.
79 Region& op_region = op.body();
80 for (auto p : llvm::zip(live_ins, outlined_func_block->getArguments())) {
81 replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op_region);
82 }
83
84 // Move all instructions in cluster_op into outlined_function's only block.
85 auto& op_body = op.GetBody().getOperations();
86 outlined_func_block->getOperations().splice(
87 outlined_func_block->end(), op_body, op_body.begin(), op_body.end());
88
89 // Replace `tf_device.return` terminator with `std.return` in function
90 // body.
91 auto return_op =
92 cast<tf_device::ReturnOp>(outlined_func_block->getTerminator());
93 builder->setInsertionPoint(return_op);
94 ReplaceClusterReturnWithReturn(return_op, builder);
95
96 symbol_table->insert(outlined_func);
97 return outlined_func;
98 }
99
100 // Outlines body of `tf_device.cluster` into a function and create a
101 // `tf_device.cluster_func` to invoke that function. `tf_device.cluster` is
102 // removed afterwards.`
OutlineCluster(tf_device::ClusterOp cluster_op,SymbolTable * symbol_table,OpBuilder * builder)103 void OutlineCluster(tf_device::ClusterOp cluster_op, SymbolTable* symbol_table,
104 OpBuilder* builder) {
105 llvm::SetVector<Value> live_ins;
106 getUsedValuesDefinedAbove(cluster_op.body(), cluster_op.body(), live_ins);
107
108 FuncOp outlined_func =
109 BuildFunction(live_ins.getArrayRef(), cluster_op, symbol_table, builder);
110 cluster_op->setAttr(builder->getIdentifier(kFuncAttr),
111 builder->getSymbolRefAttr(outlined_func.getName()));
112
113 builder->setInsertionPoint(cluster_op);
114 auto cluster_func_op = builder->create<tf_device::ClusterFuncOp>(
115 cluster_op.getLoc(), outlined_func.getType().getResults(),
116 live_ins.getArrayRef(), cluster_op->getAttrs());
117
118 cluster_op.replaceAllUsesWith(cluster_func_op);
119 cluster_op.erase();
120 }
121
122 // Outlines body of `tf_device.launch` into a function and create a
123 // `tf_device.launch_func` to invoke that function. `tf_device.launch` is
124 // removed afterwards.`
OutlineLaunch(tf_device::LaunchOp launch_op,SymbolTable * symbol_table,OpBuilder * builder)125 void OutlineLaunch(tf_device::LaunchOp launch_op, SymbolTable* symbol_table,
126 OpBuilder* builder) {
127 llvm::SetVector<Value> live_ins;
128 getUsedValuesDefinedAbove(launch_op.body(), launch_op.body(), live_ins);
129
130 FuncOp outlined_func =
131 BuildFunction(live_ins.getArrayRef(), launch_op, symbol_table, builder);
132 launch_op->setAttr(builder->getIdentifier(kFuncAttr),
133 builder->getSymbolRefAttr(outlined_func.getName()));
134
135 builder->setInsertionPoint(launch_op);
136 auto cluster_func_op = builder->create<tf_device::LaunchFuncOp>(
137 launch_op.getLoc(), outlined_func.getType().getResults(),
138 live_ins.getArrayRef(), launch_op->getAttrs());
139
140 launch_op.replaceAllUsesWith(cluster_func_op);
141 launch_op.erase();
142 }
143
runOnOperation()144 void ClusterOutliningPass::runOnOperation() {
145 ModuleOp module = getOperation();
146 SymbolTable symbol_table(module);
147 OpBuilder builder(module.getContext());
148 module.walk([&](tf_device::ClusterOp cluster) {
149 OutlineCluster(cluster, &symbol_table, &builder);
150 });
151 }
152
runOnOperation()153 void LaunchOutliningPass::runOnOperation() {
154 ModuleOp module = getOperation();
155 SymbolTable symbol_table(module);
156 OpBuilder builder(module.getContext());
157 module.walk([&](tf_device::LaunchOp launch) {
158 OutlineLaunch(launch, &symbol_table, &builder);
159 });
160 }
161
162 } // namespace
163
CreateClusterOutliningPass()164 std::unique_ptr<OperationPass<ModuleOp>> CreateClusterOutliningPass() {
165 return std::make_unique<ClusterOutliningPass>();
166 }
167
CreateLaunchOutliningPass()168 std::unique_ptr<OperationPass<ModuleOp>> CreateLaunchOutliningPass() {
169 return std::make_unique<LaunchOutliningPass>();
170 }
171
172 } // namespace TFDevice
173 } // namespace mlir
174