• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/SmallVector.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
18 #include "mlir/IR/Attributes.h"  // from @llvm-project
19 #include "mlir/IR/Block.h"  // from @llvm-project
20 #include "mlir/IR/Builders.h"  // from @llvm-project
21 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
22 #include "mlir/IR/Operation.h"  // from @llvm-project
23 #include "mlir/Pass/Pass.h"  // from @llvm-project
24 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
25 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
28 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
29 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
30 
31 namespace mlir {
32 namespace TFDevice {
33 
34 namespace {
35 
36 constexpr char kFuncAttr[] = "func";
37 
38 struct ClusterOutliningPass
39     : public TF::ClusterOutliningPassBase<ClusterOutliningPass> {
40   void runOnOperation() override;
41 };
42 
ReplaceClusterReturnWithReturn(tf_device::ReturnOp cluster_return_op,OpBuilder * builder)43 void ReplaceClusterReturnWithReturn(tf_device::ReturnOp cluster_return_op,
44                                     OpBuilder* builder) {
45   builder->create<ReturnOp>(cluster_return_op.getLoc(),
46                             cluster_return_op.getOperands());
47   cluster_return_op.erase();
48 }
49 
50 // Builds a function that outlines region attached to cluster_op and inserts
51 // built function into given module.
BuildFunction(llvm::ArrayRef<Value> live_ins,tf_device::ClusterOp cluster_op,SymbolTable * symbol_table,OpBuilder * builder)52 FuncOp BuildFunction(llvm::ArrayRef<Value> live_ins,
53                      tf_device::ClusterOp cluster_op, SymbolTable* symbol_table,
54                      OpBuilder* builder) {
55   llvm::SmallVector<Type, 4> operand_types;
56   operand_types.reserve(live_ins.size());
57   for (Value v : live_ins) operand_types.emplace_back(v.getType());
58 
59   auto func_type =
60       builder->getFunctionType(operand_types, cluster_op.getResultTypes());
61 
62   // TODO(lyandy): Define better name for outlined function. Potentially some
63   // name can be added during cluster formation.
64   FuncOp outlined_func =
65       FuncOp::create(cluster_op.getLoc(), "_func", func_type);
66 
67   // This function is not externally visible and marking it private would allow
68   // symbol-dce pass to remove it when it is not referenced anymore.
69   outlined_func.setPrivate();
70 
71   // Create function body.
72   Block* outlined_func_block = outlined_func.addEntryBlock();
73 
74   // Replace uses of live-in values within cluster_op region with function
75   // arguments.
76   Region& cluster_op_region = cluster_op.body();
77   for (auto p : llvm::zip(live_ins, outlined_func_block->getArguments())) {
78     replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p),
79                                cluster_op_region);
80   }
81 
82   // Move all instructions in cluster_op into outlined_function's only block.
83   auto& cluster_op_body = cluster_op.GetBody().getOperations();
84   outlined_func_block->getOperations().splice(
85       outlined_func_block->end(), cluster_op_body, cluster_op_body.begin(),
86       cluster_op_body.end());
87 
88   // Replace `tf_device.return` terminator with `std.return` in function
89   // body.
90   auto cluster_return_op =
91       cast<tf_device::ReturnOp>(outlined_func_block->getTerminator());
92   builder->setInsertionPoint(cluster_return_op);
93   ReplaceClusterReturnWithReturn(cluster_return_op, builder);
94 
95   symbol_table->insert(outlined_func);
96   return outlined_func;
97 }
98 
99 // Outlines body of `tf_device.cluster` into a function and create a
100 // `tf_device.cluster_func` to invoke that function. `tf_device.cluster` is
101 // removed afterwards.`
OutlineCluster(tf_device::ClusterOp cluster_op,SymbolTable * symbol_table,OpBuilder * builder)102 void OutlineCluster(tf_device::ClusterOp cluster_op, SymbolTable* symbol_table,
103                     OpBuilder* builder) {
104   llvm::SetVector<Value> live_ins;
105   getUsedValuesDefinedAbove(cluster_op.body(), cluster_op.body(), live_ins);
106 
107   FuncOp outlined_func =
108       BuildFunction(live_ins.getArrayRef(), cluster_op, symbol_table, builder);
109   cluster_op->setAttr(builder->getIdentifier(kFuncAttr),
110                       builder->getSymbolRefAttr(outlined_func.getName()));
111 
112   builder->setInsertionPoint(cluster_op);
113   auto cluster_func_op = builder->create<tf_device::ClusterFuncOp>(
114       cluster_op.getLoc(), outlined_func.getType().getResults(),
115       live_ins.getArrayRef(), cluster_op.getAttrs());
116 
117   cluster_op.replaceAllUsesWith(cluster_func_op);
118   cluster_op.erase();
119 }
120 
runOnOperation()121 void ClusterOutliningPass::runOnOperation() {
122   ModuleOp module = getOperation();
123   SymbolTable symbol_table(module);
124   OpBuilder builder(module.getContext());
125   module.walk([&](tf_device::ClusterOp cluster) {
126     OutlineCluster(cluster, &symbol_table, &builder);
127   });
128 }
129 
130 }  // namespace
131 
CreateClusterOutliningPass()132 std::unique_ptr<OperationPass<ModuleOp>> CreateClusterOutliningPass() {
133   return std::make_unique<ClusterOutliningPass>();
134 }
135 
136 }  // namespace TFDevice
137 }  // namespace mlir
138