1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/SmallVector.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
18 #include "mlir/IR/Attributes.h" // from @llvm-project
19 #include "mlir/IR/Block.h" // from @llvm-project
20 #include "mlir/IR/Builders.h" // from @llvm-project
21 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
22 #include "mlir/IR/Operation.h" // from @llvm-project
23 #include "mlir/Pass/Pass.h" // from @llvm-project
24 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
25 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
28 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
29 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
30
31 namespace mlir {
32 namespace TFDevice {
33
34 namespace {
35
36 constexpr char kFuncAttr[] = "func";
37
38 struct ClusterOutliningPass
39 : public TF::ClusterOutliningPassBase<ClusterOutliningPass> {
40 void runOnOperation() override;
41 };
42
ReplaceClusterReturnWithReturn(tf_device::ReturnOp cluster_return_op,OpBuilder * builder)43 void ReplaceClusterReturnWithReturn(tf_device::ReturnOp cluster_return_op,
44 OpBuilder* builder) {
45 builder->create<ReturnOp>(cluster_return_op.getLoc(),
46 cluster_return_op.getOperands());
47 cluster_return_op.erase();
48 }
49
50 // Builds a function that outlines region attached to cluster_op and inserts
51 // built function into given module.
BuildFunction(llvm::ArrayRef<Value> live_ins,tf_device::ClusterOp cluster_op,SymbolTable * symbol_table,OpBuilder * builder)52 FuncOp BuildFunction(llvm::ArrayRef<Value> live_ins,
53 tf_device::ClusterOp cluster_op, SymbolTable* symbol_table,
54 OpBuilder* builder) {
55 llvm::SmallVector<Type, 4> operand_types;
56 operand_types.reserve(live_ins.size());
57 for (Value v : live_ins) operand_types.emplace_back(v.getType());
58
59 auto func_type =
60 builder->getFunctionType(operand_types, cluster_op.getResultTypes());
61
62 // TODO(lyandy): Define better name for outlined function. Potentially some
63 // name can be added during cluster formation.
64 FuncOp outlined_func =
65 FuncOp::create(cluster_op.getLoc(), "_func", func_type);
66
67 // This function is not externally visible and marking it private would allow
68 // symbol-dce pass to remove it when it is not referenced anymore.
69 outlined_func.setPrivate();
70
71 // Create function body.
72 Block* outlined_func_block = outlined_func.addEntryBlock();
73
74 // Replace uses of live-in values within cluster_op region with function
75 // arguments.
76 Region& cluster_op_region = cluster_op.body();
77 for (auto p : llvm::zip(live_ins, outlined_func_block->getArguments())) {
78 replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p),
79 cluster_op_region);
80 }
81
82 // Move all instructions in cluster_op into outlined_function's only block.
83 auto& cluster_op_body = cluster_op.GetBody().getOperations();
84 outlined_func_block->getOperations().splice(
85 outlined_func_block->end(), cluster_op_body, cluster_op_body.begin(),
86 cluster_op_body.end());
87
88 // Replace `tf_device.return` terminator with `std.return` in function
89 // body.
90 auto cluster_return_op =
91 cast<tf_device::ReturnOp>(outlined_func_block->getTerminator());
92 builder->setInsertionPoint(cluster_return_op);
93 ReplaceClusterReturnWithReturn(cluster_return_op, builder);
94
95 symbol_table->insert(outlined_func);
96 return outlined_func;
97 }
98
99 // Outlines body of `tf_device.cluster` into a function and create a
100 // `tf_device.cluster_func` to invoke that function. `tf_device.cluster` is
101 // removed afterwards.`
OutlineCluster(tf_device::ClusterOp cluster_op,SymbolTable * symbol_table,OpBuilder * builder)102 void OutlineCluster(tf_device::ClusterOp cluster_op, SymbolTable* symbol_table,
103 OpBuilder* builder) {
104 llvm::SetVector<Value> live_ins;
105 getUsedValuesDefinedAbove(cluster_op.body(), cluster_op.body(), live_ins);
106
107 FuncOp outlined_func =
108 BuildFunction(live_ins.getArrayRef(), cluster_op, symbol_table, builder);
109 cluster_op->setAttr(builder->getIdentifier(kFuncAttr),
110 builder->getSymbolRefAttr(outlined_func.getName()));
111
112 builder->setInsertionPoint(cluster_op);
113 auto cluster_func_op = builder->create<tf_device::ClusterFuncOp>(
114 cluster_op.getLoc(), outlined_func.getType().getResults(),
115 live_ins.getArrayRef(), cluster_op.getAttrs());
116
117 cluster_op.replaceAllUsesWith(cluster_func_op);
118 cluster_op.erase();
119 }
120
runOnOperation()121 void ClusterOutliningPass::runOnOperation() {
122 ModuleOp module = getOperation();
123 SymbolTable symbol_table(module);
124 OpBuilder builder(module.getContext());
125 module.walk([&](tf_device::ClusterOp cluster) {
126 OutlineCluster(cluster, &symbol_table, &builder);
127 });
128 }
129
130 } // namespace
131
CreateClusterOutliningPass()132 std::unique_ptr<OperationPass<ModuleOp>> CreateClusterOutliningPass() {
133 return std::make_unique<ClusterOutliningPass>();
134 }
135
136 } // namespace TFDevice
137 } // namespace mlir
138