• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/SetVector.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/ADT/Twine.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
20 #include "mlir/IR/Attributes.h"  // from @llvm-project
21 #include "mlir/IR/Builders.h"  // from @llvm-project
22 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
23 #include "mlir/Pass/Pass.h"  // from @llvm-project
24 #include "mlir/Pass/PassManager.h"  // from @llvm-project
25 #include "mlir/Support/LLVM.h"  // from @llvm-project
26 #include "mlir/Transforms/Passes.h"  // from @llvm-project
27 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
30 #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
31 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
32 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
33 
34 namespace mlir {
35 namespace tf_executor {
36 
37 namespace {
38 constexpr llvm::StringRef kNestedModule = "_tpu_v1_compat_outlined";
39 constexpr llvm::StringRef kOutlinedFuncPrefix = "_tpu_v1_compat_outlined_func";
40 
41 // Extract the islands containing a TPU cluster computation into an outlined
42 // function in a nested module. This will allow to run the usual bridge on this
43 // nested module which exhibit a more friendly "V2-like" structure.
44 // This is only intended for V1 compatibility mode where the bridge runs without
45 // feed/fetches on session create/extend.
46 struct TPUBridgeExecutorIslandOutlining
47     : public PassWrapper<TPUBridgeExecutorIslandOutlining,
48                          OperationPass<ModuleOp>> {
getArgumentmlir::tf_executor::__anon2b36e2810111::TPUBridgeExecutorIslandOutlining49   StringRef getArgument() const final {
50     return "tf-executor-tpu-v1-island-outlining";
51   }
52 
getDescriptionmlir::tf_executor::__anon2b36e2810111::TPUBridgeExecutorIslandOutlining53   StringRef getDescription() const final {
54     return "Outline TPU clusters from island into a nested module, so it can "
55            "be processed like a V2 module, intended for V1 compatibility mode";
56   }
57 
58   void runOnOperation() override;
59 };
60 
61 // Move FuncOp referenced by `symbol_ref` from one symbol table to another.
MoveFuncOp(FlatSymbolRefAttr & symbol_ref,SymbolTable & from,SymbolTable & to)62 void MoveFuncOp(FlatSymbolRefAttr &symbol_ref, SymbolTable &from,
63                 SymbolTable &to) {
64   if (to.lookup<FuncOp>(symbol_ref.getValue())) return;
65   FuncOp callee = from.lookup<FuncOp>(symbol_ref.getValue());
66   callee.getOperation()->getBlock()->getOperations().remove(
67       callee.getOperation());
68   to.insert(callee);
69 }
70 
runOnOperation()71 void TPUBridgeExecutorIslandOutlining::runOnOperation() {
72   MLIRContext *ctx = &getContext();
73 
74   SymbolTable symbol_table(getOperation());
75   if (Operation *nested_module = symbol_table.lookup(kNestedModule)) {
76     nested_module->emitOpError("unexpected already present outlined module.");
77     return signalPassFailure();
78   }
79   ModuleOp outlined_module = ModuleOp::create(getOperation().getLoc());
80   outlined_module->setAttrs(getOperation()->getAttrDictionary());
81   outlined_module->setAttr(SymbolTable::getSymbolAttrName(),
82                            StringAttr::get(ctx, kNestedModule));
83   symbol_table.insert(outlined_module);
84   SymbolTable outlined_symbol_table(outlined_module);
85 
86   // Find every island that contains a TPUReplicateMetadata node and extract it
87   // in a new module to run the V1 bridge there.
88   SmallVector<IslandOp, 8> islands_to_outline;
89   getOperation().walk([&](TF::TPUReplicateMetadataOp replicate_op) {
90     auto island_op = cast<IslandOp>(replicate_op->getParentOp());
91     if (!island_op || island_op.WrapsSingleOp()) return;
92     islands_to_outline.push_back(island_op);
93   });
94   int prefix_id = 0;
95   for (IslandOp island_op : islands_to_outline) {
96     // Build the function signature.
97 
98     // First the captured values in the island are function arguments
99     llvm::SetVector<Value> operands;
100     getUsedValuesDefinedAbove(island_op.body(), operands);
101 
102     SmallVector<Type, 16> func_operand_types;
103     func_operand_types.reserve(operands.size());
104     for (Value operand : operands)
105       func_operand_types.push_back(operand.getType());
106 
107     // Function results are the yield operands
108     SmallVector<Type, 16> func_result_types;
109     for (Value operand : island_op.GetYield().getOperands())
110       func_result_types.push_back(operand.getType());
111     FunctionType func_type =
112         FunctionType::get(ctx, func_operand_types, func_result_types);
113 
114     // Create the outlined function
115     SmallString<32> name = kOutlinedFuncPrefix;
116     name += llvm::Twine(prefix_id++).str();
117     auto outlined_func =
118         OpBuilder(ctx).create<FuncOp>(island_op.getLoc(), name, func_type);
119     outlined_symbol_table.insert(outlined_func);
120     outlined_func.setNested();
121 
122     // We will "steal" the body of the island and replace it with a call to the
123     // new function later.
124     {
125       YieldOp yield_op = island_op.GetYield();
126       outlined_func.getBody().takeBody(island_op.body());
127 
128       // Replace the yield with a return
129       OpBuilder replacer(yield_op);
130       island_op.body().push_back(new Block);
131       replacer.create<ReturnOp>(yield_op.getLoc(), yield_op.getOperands());
132       yield_op.erase();
133     }
134 
135     // Remap the captured operands in the (former) island block with newly
136     // created entry block arguments in the function body.
137     {
138       Block &entry_block = outlined_func.getBody().front();
139       for (Value operand : operands) {
140         BlockArgument newArg = entry_block.addArgument(operand.getType());
141         replaceAllUsesInRegionWith(operand, newArg, outlined_func.getBody());
142       }
143     }
144 
145     // The function is in place in the nested module, create a call and yield in
146     // the original island.
147     OpBuilder builder = OpBuilder::atBlockEnd(&island_op.GetBody());
148     auto call_op = builder.create<mlir::TF::PartitionedCallOp>(
149         island_op.getLoc(), func_result_types, operands.getArrayRef(),
150         builder.getSymbolRefAttr(
151             kNestedModule, builder.getSymbolRefAttr(outlined_func.getName())),
152         /*config=*/builder.getStringAttr(""),
153         /*config_proto=*/builder.getStringAttr(""),
154         /*executor_type=*/builder.getStringAttr(""));
155     SmallVector<Value, 16> yield_operands(call_op.getResults());
156     builder.create<YieldOp>(island_op.getLoc(), yield_operands);
157   }
158 
159   // Outlined all the transitively called functions by moving them in the
160   // outlined module.
161   for (FuncOp func : outlined_module.getOps<FuncOp>()) {
162     func.walk([&](Operation *op) {
163       for (NamedAttribute attr : op->getAttrs()) {
164         if (auto symbol_ref = attr.second.dyn_cast<FlatSymbolRefAttr>()) {
165           MoveFuncOp(symbol_ref, symbol_table, outlined_symbol_table);
166           continue;
167         }
168         if (auto array_attr = attr.second.dyn_cast<ArrayAttr>()) {
169           for (const Attribute &attribute : array_attr) {
170             auto symbol_ref = attribute.dyn_cast<FlatSymbolRefAttr>();
171             if (!symbol_ref) continue;
172             MoveFuncOp(symbol_ref, symbol_table, outlined_symbol_table);
173           }
174         }
175       }
176     });
177   }
178 }
179 
180 PassRegistration<TPUBridgeExecutorIslandOutlining> tpu_pass;
181 
182 }  // namespace
183 
184 std::unique_ptr<OperationPass<ModuleOp>>
CreateTFExecutorTPUV1IslandOutliningPass()185 CreateTFExecutorTPUV1IslandOutliningPass() {
186   return std::make_unique<TPUBridgeExecutorIslandOutlining>();
187 }
188 
189 }  // namespace tf_executor
190 }  // namespace mlir
191