/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This pass optimizes tf_saved_model.global_tensor ops.
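//
// As a rough, purely illustrative sketch (not exact IR syntax): a global
// tensor such as
//
//   "tf_saved_model.global_tensor"() {is_mutable, sym_name = "v", ...} : () -> ()
//
// has its `is_mutable` attribute removed when analysis shows that none of its
// bound uses are ever written, and is erased entirely when it is neither
// exported nor bound to any func argument.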

#include <cstddef>
#include <map>
#include <set>

#include "llvm/ADT/DenseMap.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"

namespace mlir {
namespace tf_saved_model {
namespace {
struct OptimizeGlobalTensorsPass
    : public PassWrapper<OptimizeGlobalTensorsPass, OperationPass<ModuleOp>> {
  void runOnOperation() override;
};

// A global tensor may be bound to arguments of multiple funcs.
// This struct tracks one such binding: the func and the index of the func
// argument that the global tensor is bound to.
struct GlobalTensorUse {
  mutable FuncOp func;
  size_t arg_index;
};

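// For illustration (hypothetical symbols): if global tensor @v is bound to
// argument 0 of func @f and argument 1 of func @g via the
// "tf_saved_model.bound_input" arg attribute, the uses map below would
// contain the entry @v -> [{@f, 0}, {@g, 1}].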
using GlobalTensorUsesMap =
    std::map<GlobalTensorOp, std::vector<GlobalTensorUse>>;

bool IsResourceType(Type type) {
  if (auto tensor_type = type.dyn_cast<TensorType>()) {
    return tensor_type.getElementType().isa<TF::ResourceType>();
  }
  return false;
}

bool IsResource(Value value) { return IsResourceType(value.getType()); }

class ResourceAnalyzer {
 public:
  explicit ResourceAnalyzer(ModuleOp module) {
    for (auto func : module.getOps<FuncOp>()) {
      (void)AnalyzeFunc(func);
    }
  }

  bool IsPotentiallyWritten(Value resource) const {
    assert(IsResource(resource));
    auto it = resource_infos_.find(resource);
    if (it == resource_infos_.end()) {
      return false;
    }
    return it->second.potentially_written;
  }

 private:
  // Analyzes the specified func for resource-mutating operations (namely
  // TF::AssignVariableOp) and, if one is found, marks the associated resource
  // as "potentially written". This is done recursively across the chain of
  // funcs reached via call or control flow ops.
  // TODO(ashwinm): Move to iterative traversal.
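  //
  // Small illustrative example (hypothetical IR, not exact syntax): if a
  // callee @write_var contains
  //   "tf.AssignVariableOp"(%arg0, %cst)
  // then %arg0 of @write_var is marked potentially written, and a call such
  // as
  //   "tf.StatefulPartitionedCall"(%v) {f = @write_var}
  // causes the caller's %v operand to be marked potentially written as well.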
  LogicalResult AnalyzeFunc(FuncOp func) {
    // Avoid infinite recursion.
    if (!discovered_.insert(func).second) {
      return success();
    }

    func.walk([&](Operation* op) {
      if (isa<TF::ReadVariableOp, ReturnOp>(op)) {
        return;
      }
      if (auto assign_variable = dyn_cast<TF::AssignVariableOp>(op)) {
        SetPotentiallyWritten(assign_variable.resource());
        return;
      }
      if (auto call = dyn_cast<CallOpInterface>(op)) {
        if (auto func = dyn_cast<FuncOp>(call.resolveCallable())) {
          PropagatePotentiallyWrittenUpFromCallee(func, call.getArgOperands());
        }
        return;
      }
      if (auto if_op = dyn_cast<TF::IfOp>(op)) {
        for (auto callee : {if_op.then_function(), if_op.else_function()}) {
          PropagatePotentiallyWrittenUpFromCallee(callee, if_op.input());
        }
        return;
      }
      if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
        for (auto callee :
             {while_op.cond_function(), while_op.body_function()}) {
          PropagatePotentiallyWrittenUpFromCallee(callee, while_op.input());
        }
        return;
      }
      // For all other ops, conservatively assume that the op mutates every
      // resource it uses. We should improve this by using either a property
      // or a trait that clearly identifies ops with resource-mutating
      // behavior.
      PropagatePotentiallyWrittenWithinUnhandledOp(op);
    });
    return success();
  }

  // If an op is not one of the handled ones, we assume all resource usages
  // within its purview are mutating in nature.
  void PropagatePotentiallyWrittenWithinUnhandledOp(Operation* op) {
    for (auto operand : op->getOperands()) {
      if (IsResource(operand)) {
        SetPotentiallyWritten(operand);
      }
    }
    visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand* operand) {
      if (IsResource(operand->get())) {
        SetPotentiallyWritten(operand->get());
      }
    });
  }

  // Given the callee's FuncOp and the operands of the corresponding call op,
  // propagates the "potentially written" decision to the call op's operands
  // whenever the matching func arguments are potentially written resources.
  void PropagatePotentiallyWrittenUpFromCallee(
      FuncOp func, Operation::operand_range propagate_to) {
    (void)AnalyzeFunc(func);
    for (auto t : llvm::zip(func.getArguments(), propagate_to)) {
      if (!IsResource(std::get<0>(t))) {
        continue;
      }
      if (IsPotentiallyWritten(std::get<0>(t))) {
        SetPotentiallyWritten(std::get<1>(t));
      }
    }
  }

  void SetPotentiallyWritten(Value resource) {
    assert(IsResource(resource));
    resource_infos_[resource].potentially_written = true;
  }

  struct ResourceInfo {
    bool potentially_written = false;
  };
  // Key: resource Value.
  // Value: information we know about that Value.
  // Note that these Values are in general in different functions.
  DenseMap<Value, ResourceInfo> resource_infos_;
  // The set of funcs we have already discovered.
  DenseSet<FuncOp> discovered_;
};

bool IsImmutable(GlobalTensorOp global_tensor,
                 ArrayRef<GlobalTensorUse> global_tensor_uses,
                 const ResourceAnalyzer& resource_analyzer) {
  // Global tensor is already known to be immutable.
  if (!global_tensor.is_mutable()) {
    return false;
  }
  // An exported global tensor that is not already known to be immutable might
  // be externally mutated.
  if (IsExported(global_tensor)) {
    return false;
  }

  // A global tensor is immutable if the resource analyzer deems it so.
  for (auto& global_tensor_use : global_tensor_uses) {
    auto arg = global_tensor_use.func.getArgument(global_tensor_use.arg_index);
    if (resource_analyzer.IsPotentiallyWritten(arg)) {
      return false;
    }
  }
  return true;
}

GlobalTensorUsesMap CreateGlobalTensorUsesMap(ModuleOp module) {
  GlobalTensorUsesMap global_tensor_uses;

  SymbolTable symbol_table(module);
  for (auto func : module.getOps<FuncOp>()) {
    for (size_t i = 0, e = func.getNumArguments(); i < e; i++) {
      auto sym =
          func.getArgAttrOfType<SymbolRefAttr>(i, "tf_saved_model.bound_input");
      if (!sym) {
        continue;
      }
      auto global_tensor = symbol_table.lookup<GlobalTensorOp>(
          sym.cast<FlatSymbolRefAttr>().getValue());
      if (!global_tensor) {
        continue;
      }
      global_tensor_uses[global_tensor].push_back({func, i});
    }
  }

  return global_tensor_uses;
}

// Removes the `is_mutable` attribute from tf_saved_model.global_tensor ops
// where we can prove it is safe to do so.
void MarkGlobalTensorsImmutable(
    ModuleOp module, const GlobalTensorUsesMap& global_tensor_uses_map,
    const ResourceAnalyzer& resource_analyzer) {
  for (const auto& kv : global_tensor_uses_map) {
    auto global_tensor = kv.first;
    const auto& global_tensor_uses = kv.second;
    if (IsImmutable(global_tensor, global_tensor_uses, resource_analyzer)) {
      global_tensor.removeAttr("is_mutable");
    }
  }
}

void EraseUnusedGlobalTensors(ModuleOp module,
                              const GlobalTensorUsesMap& global_tensor_uses) {
  for (auto global_tensor :
       llvm::make_early_inc_range(module.getOps<GlobalTensorOp>())) {
    // If the tensor is exported, then it is used.
    if (IsExported(global_tensor)) {
      continue;
    }
    // If the tensor is bound to an argument, then it is used.
    if (global_tensor_uses.find(global_tensor) != global_tensor_uses.end()) {
      continue;
    }
    // Erase it.
    global_tensor.erase();
  }
}

void EraseUnusedBoundInputs(ModuleOp module) {
  for (auto func : module.getOps<FuncOp>()) {
    SmallVector<unsigned, 4> args_to_erase;
    for (int i = 0, e = func.getNumArguments(); i < e; i++) {
      if (func.getArgAttr(i, "tf_saved_model.bound_input") &&
          func.getArgument(i).use_empty()) {
        args_to_erase.push_back(i);
      }
    }
    func.eraseArguments(args_to_erase);
  }
}

void OptimizeGlobalTensorsPass::runOnOperation() {
  auto module = getOperation();
  if (!tf_saved_model::HasTfSavedModelSemantics(module)) {
    return;
  }

  EraseUnusedBoundInputs(module);

  ResourceAnalyzer resource_analyzer(module);

  GlobalTensorUsesMap global_tensor_uses = CreateGlobalTensorUsesMap(module);

  MarkGlobalTensorsImmutable(module, global_tensor_uses, resource_analyzer);

  EraseUnusedGlobalTensors(module, global_tensor_uses);
}

// For "opt" to pick up this pass.
PassRegistration<OptimizeGlobalTensorsPass> pass(
    "tf-saved-model-optimize-global-tensors",
    "Optimize tf_saved_model.global_tensor ops.");

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeGlobalTensorsPass() {
  return std::make_unique<OptimizeGlobalTensorsPass>();
}

}  // namespace tf_saved_model
}  // namespace mlir