• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/SmallVector.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
18 #include "mlir/IR/Attributes.h"  // from @llvm-project
19 #include "mlir/IR/Block.h"  // from @llvm-project
20 #include "mlir/IR/Builders.h"  // from @llvm-project
21 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
22 #include "mlir/IR/Operation.h"  // from @llvm-project
23 #include "mlir/Pass/Pass.h"  // from @llvm-project
24 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
25 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
28 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
29 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
30 
31 namespace mlir {
32 namespace TFDevice {
33 
34 namespace {
35 
// Name of the attribute set on the device op (and thereby copied onto the
// replacement `*_func` op) that holds a symbol reference to the outlined
// function.
constexpr char kFuncAttr[]{"func"};
37 
38 struct ClusterOutliningPass
39     : public TF::ClusterOutliningPassBase<ClusterOutliningPass> {
40   void runOnOperation() override;
41 };
42 
43 struct LaunchOutliningPass
44     : public TF::LaunchOutliningPassBase<LaunchOutliningPass> {
45   void runOnOperation() override;
46 };
47 
ReplaceClusterReturnWithReturn(tf_device::ReturnOp cluster_return_op,OpBuilder * builder)48 void ReplaceClusterReturnWithReturn(tf_device::ReturnOp cluster_return_op,
49                                     OpBuilder* builder) {
50   builder->create<ReturnOp>(cluster_return_op.getLoc(),
51                             cluster_return_op.getOperands());
52   cluster_return_op.erase();
53 }
54 
55 // Builds a function that outlines region attached to cluster_op or launch_op,
56 // and inserts built function into given module.
57 template <typename ClusterOrLaunchOp>
BuildFunction(llvm::ArrayRef<Value> live_ins,ClusterOrLaunchOp op,SymbolTable * symbol_table,OpBuilder * builder)58 FuncOp BuildFunction(llvm::ArrayRef<Value> live_ins, ClusterOrLaunchOp op,
59                      SymbolTable* symbol_table, OpBuilder* builder) {
60   llvm::SmallVector<Type, 4> operand_types;
61   operand_types.reserve(live_ins.size());
62   for (Value v : live_ins) operand_types.emplace_back(v.getType());
63 
64   auto func_type = builder->getFunctionType(operand_types, op.getResultTypes());
65 
66   // TODO(lyandy): Define better name for outlined function. Potentially some
67   // name can be added during cluster formation.
68   FuncOp outlined_func = FuncOp::create(op.getLoc(), "_func", func_type);
69 
70   // This function is not externally visible and marking it private would allow
71   // symbol-dce pass to remove it when it is not referenced anymore.
72   outlined_func.setPrivate();
73 
74   // Create function body.
75   Block* outlined_func_block = outlined_func.addEntryBlock();
76 
77   // Replace uses of live-in values within cluster_op region with function
78   // arguments.
79   Region& op_region = op.body();
80   for (auto p : llvm::zip(live_ins, outlined_func_block->getArguments())) {
81     replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op_region);
82   }
83 
84   // Move all instructions in cluster_op into outlined_function's only block.
85   auto& op_body = op.GetBody().getOperations();
86   outlined_func_block->getOperations().splice(
87       outlined_func_block->end(), op_body, op_body.begin(), op_body.end());
88 
89   // Replace `tf_device.return` terminator with `std.return` in function
90   // body.
91   auto return_op =
92       cast<tf_device::ReturnOp>(outlined_func_block->getTerminator());
93   builder->setInsertionPoint(return_op);
94   ReplaceClusterReturnWithReturn(return_op, builder);
95 
96   symbol_table->insert(outlined_func);
97   return outlined_func;
98 }
99 
100 // Outlines body of `tf_device.cluster` into a function and create a
101 // `tf_device.cluster_func` to invoke that function. `tf_device.cluster` is
102 // removed afterwards.`
OutlineCluster(tf_device::ClusterOp cluster_op,SymbolTable * symbol_table,OpBuilder * builder)103 void OutlineCluster(tf_device::ClusterOp cluster_op, SymbolTable* symbol_table,
104                     OpBuilder* builder) {
105   llvm::SetVector<Value> live_ins;
106   getUsedValuesDefinedAbove(cluster_op.body(), cluster_op.body(), live_ins);
107 
108   FuncOp outlined_func =
109       BuildFunction(live_ins.getArrayRef(), cluster_op, symbol_table, builder);
110   cluster_op->setAttr(builder->getIdentifier(kFuncAttr),
111                       builder->getSymbolRefAttr(outlined_func.getName()));
112 
113   builder->setInsertionPoint(cluster_op);
114   auto cluster_func_op = builder->create<tf_device::ClusterFuncOp>(
115       cluster_op.getLoc(), outlined_func.getType().getResults(),
116       live_ins.getArrayRef(), cluster_op->getAttrs());
117 
118   cluster_op.replaceAllUsesWith(cluster_func_op);
119   cluster_op.erase();
120 }
121 
122 // Outlines body of `tf_device.launch` into a function and create a
123 // `tf_device.launch_func` to invoke that function. `tf_device.launch` is
124 // removed afterwards.`
OutlineLaunch(tf_device::LaunchOp launch_op,SymbolTable * symbol_table,OpBuilder * builder)125 void OutlineLaunch(tf_device::LaunchOp launch_op, SymbolTable* symbol_table,
126                    OpBuilder* builder) {
127   llvm::SetVector<Value> live_ins;
128   getUsedValuesDefinedAbove(launch_op.body(), launch_op.body(), live_ins);
129 
130   FuncOp outlined_func =
131       BuildFunction(live_ins.getArrayRef(), launch_op, symbol_table, builder);
132   launch_op->setAttr(builder->getIdentifier(kFuncAttr),
133                      builder->getSymbolRefAttr(outlined_func.getName()));
134 
135   builder->setInsertionPoint(launch_op);
136   auto cluster_func_op = builder->create<tf_device::LaunchFuncOp>(
137       launch_op.getLoc(), outlined_func.getType().getResults(),
138       live_ins.getArrayRef(), launch_op->getAttrs());
139 
140   launch_op.replaceAllUsesWith(cluster_func_op);
141   launch_op.erase();
142 }
143 
runOnOperation()144 void ClusterOutliningPass::runOnOperation() {
145   ModuleOp module = getOperation();
146   SymbolTable symbol_table(module);
147   OpBuilder builder(module.getContext());
148   module.walk([&](tf_device::ClusterOp cluster) {
149     OutlineCluster(cluster, &symbol_table, &builder);
150   });
151 }
152 
runOnOperation()153 void LaunchOutliningPass::runOnOperation() {
154   ModuleOp module = getOperation();
155   SymbolTable symbol_table(module);
156   OpBuilder builder(module.getContext());
157   module.walk([&](tf_device::LaunchOp launch) {
158     OutlineLaunch(launch, &symbol_table, &builder);
159   });
160 }
161 
162 }  // namespace
163 
CreateClusterOutliningPass()164 std::unique_ptr<OperationPass<ModuleOp>> CreateClusterOutliningPass() {
165   return std::make_unique<ClusterOutliningPass>();
166 }
167 
CreateLaunchOutliningPass()168 std::unique_ptr<OperationPass<ModuleOp>> CreateLaunchOutliningPass() {
169   return std::make_unique<LaunchOutliningPass>();
170 }
171 
172 }  // namespace TFDevice
173 }  // namespace mlir
174