1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
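// This pass optimizes tf_saved_model.global_tensor ops. It erases unused
// bound-input function arguments, drops the `is_mutable` attribute from global
// tensors that can be proven never to be written, and erases global tensors
// that are neither exported nor bound to any function argument.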
17
18 #include <cstddef>
19 #include <map>
20 #include <set>
21
22 #include "llvm/ADT/DenseMap.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
24 #include "mlir/IR/Builders.h" // from @llvm-project
25 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
27 #include "mlir/IR/Operation.h" // from @llvm-project
28 #include "mlir/IR/SymbolTable.h" // from @llvm-project
29 #include "mlir/IR/Types.h" // from @llvm-project
30 #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
31 #include "mlir/Pass/Pass.h" // from @llvm-project
32 #include "mlir/Support/LLVM.h" // from @llvm-project
33 #include "mlir/Support/LogicalResult.h" // from @llvm-project
34 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
35 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h"
36 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
39
40 namespace mlir {
41 namespace tf_saved_model {
42 namespace {
43 struct OptimizeGlobalTensorsPass
44 : public PassWrapper<OptimizeGlobalTensorsPass, OperationPass<ModuleOp>> {
getArgumentmlir::tf_saved_model::__anon8eb751700111::OptimizeGlobalTensorsPass45 StringRef getArgument() const final {
46 return "tf-saved-model-optimize-global-tensors";
47 }
48
getDescriptionmlir::tf_saved_model::__anon8eb751700111::OptimizeGlobalTensorsPass49 StringRef getDescription() const final {
50 return "Optimize tf_saved_model.global_tensor's.";
51 }
52
53 void runOnOperation() override;
54 };
55
56 // A global tensor is bound to arguments of multiple funcs.
57 // This struct tracks which funcs (and which argument to that func) the global
58 // tensor is bound to.
59 struct GlobalTensorUse {
60 mutable FuncOp func;
61 size_t arg_index;
62 };
63
64 using GlobalTensorUsesMap =
65 std::map<GlobalTensorOp, std::vector<GlobalTensorUse>>;
66
IsImmutable(GlobalTensorOp global_tensor,ArrayRef<GlobalTensorUse> global_tensor_uses,const TF::ResourceAnalyzer & resource_analyzer)67 bool IsImmutable(GlobalTensorOp global_tensor,
68 ArrayRef<GlobalTensorUse> global_tensor_uses,
69 const TF::ResourceAnalyzer& resource_analyzer) {
70 // Global tensor is already known to be immutable.
71 if (!global_tensor.is_mutable()) {
72 return false;
73 }
74 // An exported global tensor that is not already known to be immutable might
75 // be externally mutated.
76 if (IsExported(global_tensor)) {
77 return false;
78 }
79
80 // A global tensor is immutable if the resource analyzer deems it so.
81 for (auto& global_tensor_use : global_tensor_uses) {
82 auto arg = global_tensor_use.func.getArgument(global_tensor_use.arg_index);
83 if (resource_analyzer.IsPotentiallyWritten(arg)) {
84 return false;
85 }
86 }
87 return true;
88 }
89
CreateGlobalTensorUsesMap(ModuleOp module)90 GlobalTensorUsesMap CreateGlobalTensorUsesMap(ModuleOp module) {
91 GlobalTensorUsesMap global_tensor_uses;
92
93 SymbolTable symbol_table(module);
94 for (auto func : module.getOps<FuncOp>()) {
95 for (size_t i = 0, e = func.getNumArguments(); i < e; i++) {
96 auto sym =
97 func.getArgAttrOfType<SymbolRefAttr>(i, "tf_saved_model.bound_input");
98 if (!sym) {
99 continue;
100 }
101 auto global_tensor = symbol_table.lookup<GlobalTensorOp>(
102 sym.cast<FlatSymbolRefAttr>().getValue());
103 if (!global_tensor) {
104 continue;
105 }
106 global_tensor_uses[global_tensor].push_back({func, i});
107 }
108 }
109
110 return global_tensor_uses;
111 }
112
113 // Removes `is_mutable` attribute from tf_saved_model.global_tensor ops where we
114 // can prove it is safe to do so.
MarkGlobalTensorsImmutable(ModuleOp module,const GlobalTensorUsesMap & global_tensor_uses_map,const TF::ResourceAnalyzer & resource_analyzer)115 void MarkGlobalTensorsImmutable(
116 ModuleOp module, const GlobalTensorUsesMap& global_tensor_uses_map,
117 const TF::ResourceAnalyzer& resource_analyzer) {
118 for (const auto& kv : global_tensor_uses_map) {
119 auto global_tensor = kv.first;
120 const auto& global_tensor_uses = kv.second;
121 if (IsImmutable(global_tensor, global_tensor_uses, resource_analyzer)) {
122 global_tensor->removeAttr("is_mutable");
123 }
124 }
125 }
126
EraseUnusedGlobalTensors(ModuleOp module,const GlobalTensorUsesMap & global_tensor_uses)127 void EraseUnusedGlobalTensors(ModuleOp module,
128 const GlobalTensorUsesMap& global_tensor_uses) {
129 for (auto global_tensor :
130 llvm::make_early_inc_range(module.getOps<GlobalTensorOp>())) {
131 // If the tensor is exported, then it is used.
132 if (IsExported(global_tensor)) {
133 continue;
134 }
135 // If the tensor is bound to an argument, then it is used.
136 if (global_tensor_uses.find(global_tensor) != global_tensor_uses.end()) {
137 continue;
138 }
139 // Erase it.
140 global_tensor.erase();
141 }
142 }
143
EraseUnusedBoundInputs(ModuleOp module)144 void EraseUnusedBoundInputs(ModuleOp module) {
145 for (auto func : module.getOps<FuncOp>()) {
146 SmallVector<unsigned, 4> args_to_erase;
147 for (int i = 0, e = func.getNumArguments(); i < e; i++) {
148 if (func.getArgAttr(i, "tf_saved_model.bound_input") &&
149 func.getArgument(i).use_empty()) {
150 args_to_erase.push_back(i);
151 }
152 }
153 func.eraseArguments(args_to_erase);
154 }
155 }
156
runOnOperation()157 void OptimizeGlobalTensorsPass::runOnOperation() {
158 auto module = getOperation();
159 if (!tf_saved_model::HasTfSavedModelSemantics(module)) {
160 return;
161 }
162
163 EraseUnusedBoundInputs(module);
164
165 TF::ResourceAnalyzer resource_analyzer(module);
166
167 GlobalTensorUsesMap global_tensor_uses = CreateGlobalTensorUsesMap(module);
168
169 MarkGlobalTensorsImmutable(module, global_tensor_uses, resource_analyzer);
170
171 EraseUnusedGlobalTensors(module, global_tensor_uses);
172 }
173
174 // For "opt" to pick up this pass.
175 PassRegistration<OptimizeGlobalTensorsPass> pass;
176
177 } // namespace
178
CreateOptimizeGlobalTensorsPass()179 std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeGlobalTensorsPass() {
180 return std::make_unique<OptimizeGlobalTensorsPass>();
181 }
182
183 } // namespace tf_saved_model
184 } // namespace mlir
185