1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/STLExtras.h"
17 #include "mlir/IR/SymbolTable.h" // from @llvm-project
18 #include "mlir/Pass/Pass.h" // from @llvm-project
19 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
20 #include "mlir/Support/LLVM.h" // from @llvm-project
21 #include "mlir/Support/LogicalResult.h" // from @llvm-project
22 #include "mlir/Transforms/Utils.h" // from @llvm-project
23 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
24
25 namespace mlir {
26 namespace TF {
27
28 namespace {
29
30 // Clones FuncOp's until they have a single use only (or no users).
31 //
32 // The tf-shape-inference pass doesn't support functions that have more than
33 // a single use. But some real code from frontends does end up creating code
34 // like that. For example, the same LSTM cell function or loop body function
35 // will be reused.
36 //
37 // This pass clones functions as needed to establish the invariant that all
38 // functions have a single use. This can in principle cause exponential code
39 // size bloat, and should in general be guided by a proper cost model.
40 //
41 // There are two factors which should be considered by a principled replacement
42 // to this pass:
43 //
44 // 1. TF currently relies on "sufficiently good shape inference" for
45 // correctness so for now the cost of doing this seems acceptable since
46 // pathological cases haven't hit us yet.
47 //
48 // 2. Cloning functions can help by allowing code to be specialized (much as
49 // inlining does). In fact, tf-shape-inference attempts to do specialization
50 // of callees which is difficult if callees have multiple uses.
51 class GuaranteeAllFuncsOneUse
52 : public PassWrapper<GuaranteeAllFuncsOneUse, OperationPass<ModuleOp>> {
53 public:
runOnOperation()54 void runOnOperation() override {
55 if (failed(Run())) {
56 signalPassFailure();
57 }
58 }
59
getArgument() const60 StringRef getArgument() const final {
61 // This is the argument used to refer to the pass in
62 // the textual format (on the commandline for example).
63 return "tf-guarantee-all-funcs-one-use";
64 }
getDescription() const65 StringRef getDescription() const final {
66 // This is a brief description of the pass.
67 return "Guarantee all FuncOp's have only a single use.";
68 }
69
Run()70 LogicalResult Run() {
71 auto module = getOperation();
72
73 // Overall strategy:
74 // Fixed point iteration, iteratively applying a rule that clones
75 // any FuncOp with more than one use to eliminate its uses.
76 SymbolTableCollection symbol_table_collection;
77 SymbolTable &symbol_table = symbol_table_collection.getSymbolTable(module);
78 bool made_changes = false;
79
80 // This value needs to be low enough to actually stop compilation in a
81 // reasonable time, but not too low that it blocks real programs.
82 // This number was chosen semi-randomly.
83 // TODO(jpienaar): Switch to a more context aware heuristic.
84 const int kMaxClones = 10000;
85 int num_clones = 0;
86 do {
87 SymbolUserMap symbol_users(symbol_table_collection, module);
88
89 made_changes = false;
90 for (auto func : llvm::make_early_inc_range(module.getOps<FuncOp>())) {
91 ArrayRef<Operation *> users = symbol_users.getUsers(func);
92 if (users.size() <= 1) {
93 continue;
94 }
95
96 // At this point, we know we are going to change the module.
97 made_changes = true;
98 for (Operation *user : users.drop_front()) {
99 if (num_clones++ > kMaxClones) {
100 return func.emitError()
101 << "reached cloning limit (likely recursive call graph or "
102 "repeated diamond-like call structure "
103 "or just very large program)";
104 }
105 FuncOp new_func = func.clone();
106 symbol_table.insert(new_func);
107 new_func.setPrivate();
108 if (failed(SymbolTable::replaceAllSymbolUses(func, new_func.getName(),
109 user))) {
110 return func.emitError() << "could not replace symbol use";
111 }
112 }
113 }
114 } while (made_changes);
115
116 return success();
117 }
118 };
119
120 } // namespace
121
CreateGuaranteeAllFuncsOneUsePass()122 std::unique_ptr<OperationPass<ModuleOp>> CreateGuaranteeAllFuncsOneUsePass() {
123 return std::make_unique<GuaranteeAllFuncsOneUse>();
124 }
125
126 static PassRegistration<GuaranteeAllFuncsOneUse> pass;
127
128 } // namespace TF
129
130 } // namespace mlir
131