• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This pass optimizes tf_saved_model.global_tensor ops.
17 
18 #include <cstddef>
19 #include <map>
20 #include <set>
21 
22 #include "llvm/ADT/DenseMap.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
24 #include "mlir/IR/Builders.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
27 #include "mlir/IR/Operation.h"  // from @llvm-project
28 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
29 #include "mlir/IR/Types.h"  // from @llvm-project
30 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
31 #include "mlir/Pass/Pass.h"  // from @llvm-project
32 #include "mlir/Support/LLVM.h"  // from @llvm-project
33 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
34 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
35 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h"
36 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
39 
40 namespace mlir {
41 namespace tf_saved_model {
42 namespace {
43 struct OptimizeGlobalTensorsPass
44     : public PassWrapper<OptimizeGlobalTensorsPass, OperationPass<ModuleOp>> {
getArgumentmlir::tf_saved_model::__anon8eb751700111::OptimizeGlobalTensorsPass45   StringRef getArgument() const final {
46     return "tf-saved-model-optimize-global-tensors";
47   }
48 
getDescriptionmlir::tf_saved_model::__anon8eb751700111::OptimizeGlobalTensorsPass49   StringRef getDescription() const final {
50     return "Optimize tf_saved_model.global_tensor's.";
51   }
52 
53   void runOnOperation() override;
54 };
55 
56 // A global tensor is bound to arguments of multiple funcs.
57 // This struct tracks which funcs (and which argument to that func) the global
58 // tensor is bound to.
59 struct GlobalTensorUse {
60   mutable FuncOp func;
61   size_t arg_index;
62 };
63 
64 using GlobalTensorUsesMap =
65     std::map<GlobalTensorOp, std::vector<GlobalTensorUse>>;
66 
IsImmutable(GlobalTensorOp global_tensor,ArrayRef<GlobalTensorUse> global_tensor_uses,const TF::ResourceAnalyzer & resource_analyzer)67 bool IsImmutable(GlobalTensorOp global_tensor,
68                  ArrayRef<GlobalTensorUse> global_tensor_uses,
69                  const TF::ResourceAnalyzer& resource_analyzer) {
70   // Global tensor is already known to be immutable.
71   if (!global_tensor.is_mutable()) {
72     return false;
73   }
74   // An exported global tensor that is not already known to be immutable might
75   // be externally mutated.
76   if (IsExported(global_tensor)) {
77     return false;
78   }
79 
80   // A global tensor is immutable if the resource analyzer deems it so.
81   for (auto& global_tensor_use : global_tensor_uses) {
82     auto arg = global_tensor_use.func.getArgument(global_tensor_use.arg_index);
83     if (resource_analyzer.IsPotentiallyWritten(arg)) {
84       return false;
85     }
86   }
87   return true;
88 }
89 
CreateGlobalTensorUsesMap(ModuleOp module)90 GlobalTensorUsesMap CreateGlobalTensorUsesMap(ModuleOp module) {
91   GlobalTensorUsesMap global_tensor_uses;
92 
93   SymbolTable symbol_table(module);
94   for (auto func : module.getOps<FuncOp>()) {
95     for (size_t i = 0, e = func.getNumArguments(); i < e; i++) {
96       auto sym =
97           func.getArgAttrOfType<SymbolRefAttr>(i, "tf_saved_model.bound_input");
98       if (!sym) {
99         continue;
100       }
101       auto global_tensor = symbol_table.lookup<GlobalTensorOp>(
102           sym.cast<FlatSymbolRefAttr>().getValue());
103       if (!global_tensor) {
104         continue;
105       }
106       global_tensor_uses[global_tensor].push_back({func, i});
107     }
108   }
109 
110   return global_tensor_uses;
111 }
112 
113 // Removes `is_mutable` attribute from tf_saved_model.global_tensor ops where we
114 // can prove it is safe to do so.
MarkGlobalTensorsImmutable(ModuleOp module,const GlobalTensorUsesMap & global_tensor_uses_map,const TF::ResourceAnalyzer & resource_analyzer)115 void MarkGlobalTensorsImmutable(
116     ModuleOp module, const GlobalTensorUsesMap& global_tensor_uses_map,
117     const TF::ResourceAnalyzer& resource_analyzer) {
118   for (const auto& kv : global_tensor_uses_map) {
119     auto global_tensor = kv.first;
120     const auto& global_tensor_uses = kv.second;
121     if (IsImmutable(global_tensor, global_tensor_uses, resource_analyzer)) {
122       global_tensor->removeAttr("is_mutable");
123     }
124   }
125 }
126 
EraseUnusedGlobalTensors(ModuleOp module,const GlobalTensorUsesMap & global_tensor_uses)127 void EraseUnusedGlobalTensors(ModuleOp module,
128                               const GlobalTensorUsesMap& global_tensor_uses) {
129   for (auto global_tensor :
130        llvm::make_early_inc_range(module.getOps<GlobalTensorOp>())) {
131     // If the tensor is exported, then it is used.
132     if (IsExported(global_tensor)) {
133       continue;
134     }
135     // If the tensor is bound to an argument, then it is used.
136     if (global_tensor_uses.find(global_tensor) != global_tensor_uses.end()) {
137       continue;
138     }
139     // Erase it.
140     global_tensor.erase();
141   }
142 }
143 
EraseUnusedBoundInputs(ModuleOp module)144 void EraseUnusedBoundInputs(ModuleOp module) {
145   for (auto func : module.getOps<FuncOp>()) {
146     SmallVector<unsigned, 4> args_to_erase;
147     for (int i = 0, e = func.getNumArguments(); i < e; i++) {
148       if (func.getArgAttr(i, "tf_saved_model.bound_input") &&
149           func.getArgument(i).use_empty()) {
150         args_to_erase.push_back(i);
151       }
152     }
153     func.eraseArguments(args_to_erase);
154   }
155 }
156 
runOnOperation()157 void OptimizeGlobalTensorsPass::runOnOperation() {
158   auto module = getOperation();
159   if (!tf_saved_model::HasTfSavedModelSemantics(module)) {
160     return;
161   }
162 
163   EraseUnusedBoundInputs(module);
164 
165   TF::ResourceAnalyzer resource_analyzer(module);
166 
167   GlobalTensorUsesMap global_tensor_uses = CreateGlobalTensorUsesMap(module);
168 
169   MarkGlobalTensorsImmutable(module, global_tensor_uses, resource_analyzer);
170 
171   EraseUnusedGlobalTensors(module, global_tensor_uses);
172 }
173 
174 // For "opt" to pick up this pass.
175 PassRegistration<OptimizeGlobalTensorsPass> pass;
176 
177 }  // namespace
178 
CreateOptimizeGlobalTensorsPass()179 std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeGlobalTensorsPass() {
180   return std::make_unique<OptimizeGlobalTensorsPass>();
181 }
182 
183 }  // namespace tf_saved_model
184 }  // namespace mlir
185