1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/SetVector.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/ADT/Twine.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
20 #include "mlir/IR/Attributes.h" // from @llvm-project
21 #include "mlir/IR/Builders.h" // from @llvm-project
22 #include "mlir/IR/SymbolTable.h" // from @llvm-project
23 #include "mlir/Pass/Pass.h" // from @llvm-project
24 #include "mlir/Pass/PassManager.h" // from @llvm-project
25 #include "mlir/Support/LLVM.h" // from @llvm-project
26 #include "mlir/Transforms/Passes.h" // from @llvm-project
27 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
30 #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
31 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
32 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
33
34 namespace mlir {
35 namespace tf_executor {
36
37 namespace {
38 constexpr llvm::StringRef kNestedModule = "_tpu_v1_compat_outlined";
39 constexpr llvm::StringRef kOutlinedFuncPrefix = "_tpu_v1_compat_outlined_func";
40
41 // Extract the islands containing a TPU cluster computation into an outlined
42 // function in a nested module. This will allow to run the usual bridge on this
43 // nested module which exhibit a more friendly "V2-like" structure.
44 // This is only intended for V1 compatibility mode where the bridge runs without
45 // feed/fetches on session create/extend.
46 struct TPUBridgeExecutorIslandOutlining
47 : public PassWrapper<TPUBridgeExecutorIslandOutlining,
48 OperationPass<ModuleOp>> {
getArgumentmlir::tf_executor::__anon2b36e2810111::TPUBridgeExecutorIslandOutlining49 StringRef getArgument() const final {
50 return "tf-executor-tpu-v1-island-outlining";
51 }
52
getDescriptionmlir::tf_executor::__anon2b36e2810111::TPUBridgeExecutorIslandOutlining53 StringRef getDescription() const final {
54 return "Outline TPU clusters from island into a nested module, so it can "
55 "be processed like a V2 module, intended for V1 compatibility mode";
56 }
57
58 void runOnOperation() override;
59 };
60
61 // Move FuncOp referenced by `symbol_ref` from one symbol table to another.
MoveFuncOp(FlatSymbolRefAttr & symbol_ref,SymbolTable & from,SymbolTable & to)62 void MoveFuncOp(FlatSymbolRefAttr &symbol_ref, SymbolTable &from,
63 SymbolTable &to) {
64 if (to.lookup<FuncOp>(symbol_ref.getValue())) return;
65 FuncOp callee = from.lookup<FuncOp>(symbol_ref.getValue());
66 callee.getOperation()->getBlock()->getOperations().remove(
67 callee.getOperation());
68 to.insert(callee);
69 }
70
runOnOperation()71 void TPUBridgeExecutorIslandOutlining::runOnOperation() {
72 MLIRContext *ctx = &getContext();
73
74 SymbolTable symbol_table(getOperation());
75 if (Operation *nested_module = symbol_table.lookup(kNestedModule)) {
76 nested_module->emitOpError("unexpected already present outlined module.");
77 return signalPassFailure();
78 }
79 ModuleOp outlined_module = ModuleOp::create(getOperation().getLoc());
80 outlined_module->setAttrs(getOperation()->getAttrDictionary());
81 outlined_module->setAttr(SymbolTable::getSymbolAttrName(),
82 StringAttr::get(ctx, kNestedModule));
83 symbol_table.insert(outlined_module);
84 SymbolTable outlined_symbol_table(outlined_module);
85
86 // Find every island that contains a TPUReplicateMetadata node and extract it
87 // in a new module to run the V1 bridge there.
88 SmallVector<IslandOp, 8> islands_to_outline;
89 getOperation().walk([&](TF::TPUReplicateMetadataOp replicate_op) {
90 auto island_op = cast<IslandOp>(replicate_op->getParentOp());
91 if (!island_op || island_op.WrapsSingleOp()) return;
92 islands_to_outline.push_back(island_op);
93 });
94 int prefix_id = 0;
95 for (IslandOp island_op : islands_to_outline) {
96 // Build the function signature.
97
98 // First the captured values in the island are function arguments
99 llvm::SetVector<Value> operands;
100 getUsedValuesDefinedAbove(island_op.body(), operands);
101
102 SmallVector<Type, 16> func_operand_types;
103 func_operand_types.reserve(operands.size());
104 for (Value operand : operands)
105 func_operand_types.push_back(operand.getType());
106
107 // Function results are the yield operands
108 SmallVector<Type, 16> func_result_types;
109 for (Value operand : island_op.GetYield().getOperands())
110 func_result_types.push_back(operand.getType());
111 FunctionType func_type =
112 FunctionType::get(ctx, func_operand_types, func_result_types);
113
114 // Create the outlined function
115 SmallString<32> name = kOutlinedFuncPrefix;
116 name += llvm::Twine(prefix_id++).str();
117 auto outlined_func =
118 OpBuilder(ctx).create<FuncOp>(island_op.getLoc(), name, func_type);
119 outlined_symbol_table.insert(outlined_func);
120 outlined_func.setNested();
121
122 // We will "steal" the body of the island and replace it with a call to the
123 // new function later.
124 {
125 YieldOp yield_op = island_op.GetYield();
126 outlined_func.getBody().takeBody(island_op.body());
127
128 // Replace the yield with a return
129 OpBuilder replacer(yield_op);
130 island_op.body().push_back(new Block);
131 replacer.create<ReturnOp>(yield_op.getLoc(), yield_op.getOperands());
132 yield_op.erase();
133 }
134
135 // Remap the captured operands in the (former) island block with newly
136 // created entry block arguments in the function body.
137 {
138 Block &entry_block = outlined_func.getBody().front();
139 for (Value operand : operands) {
140 BlockArgument newArg = entry_block.addArgument(operand.getType());
141 replaceAllUsesInRegionWith(operand, newArg, outlined_func.getBody());
142 }
143 }
144
145 // The function is in place in the nested module, create a call and yield in
146 // the original island.
147 OpBuilder builder = OpBuilder::atBlockEnd(&island_op.GetBody());
148 auto call_op = builder.create<mlir::TF::PartitionedCallOp>(
149 island_op.getLoc(), func_result_types, operands.getArrayRef(),
150 builder.getSymbolRefAttr(
151 kNestedModule, builder.getSymbolRefAttr(outlined_func.getName())),
152 /*config=*/builder.getStringAttr(""),
153 /*config_proto=*/builder.getStringAttr(""),
154 /*executor_type=*/builder.getStringAttr(""));
155 SmallVector<Value, 16> yield_operands(call_op.getResults());
156 builder.create<YieldOp>(island_op.getLoc(), yield_operands);
157 }
158
159 // Outlined all the transitively called functions by moving them in the
160 // outlined module.
161 for (FuncOp func : outlined_module.getOps<FuncOp>()) {
162 func.walk([&](Operation *op) {
163 for (NamedAttribute attr : op->getAttrs()) {
164 if (auto symbol_ref = attr.second.dyn_cast<FlatSymbolRefAttr>()) {
165 MoveFuncOp(symbol_ref, symbol_table, outlined_symbol_table);
166 continue;
167 }
168 if (auto array_attr = attr.second.dyn_cast<ArrayAttr>()) {
169 for (const Attribute &attribute : array_attr) {
170 auto symbol_ref = attribute.dyn_cast<FlatSymbolRefAttr>();
171 if (!symbol_ref) continue;
172 MoveFuncOp(symbol_ref, symbol_table, outlined_symbol_table);
173 }
174 }
175 }
176 });
177 }
178 }
179
180 PassRegistration<TPUBridgeExecutorIslandOutlining> tpu_pass;
181
182 } // namespace
183
184 std::unique_ptr<OperationPass<ModuleOp>>
CreateTFExecutorTPUV1IslandOutliningPass()185 CreateTFExecutorTPUV1IslandOutliningPass() {
186 return std::make_unique<TPUBridgeExecutorIslandOutlining>();
187 }
188
189 } // namespace tf_executor
190 } // namespace mlir
191