• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h"
17 
18 #include "llvm/Support/Casting.h"
19 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
20 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
21 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
22 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
23 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
24 
25 namespace mlir {
26 namespace TF {
27 namespace {
IsResourceType(Type type)28 bool IsResourceType(Type type) {
29   if (auto tensor_type = type.dyn_cast<TensorType>()) {
30     return tensor_type.getElementType().isa<TF::ResourceType>();
31   }
32   return false;
33 }
34 
IsResource(Value value)35 bool IsResource(Value value) { return IsResourceType(value.getType()); }
36 
37 // Helper that returns the FuncOp that is the SessionInit function which
38 // will be called to initialize all resources.
39 // Returns nullptr if no function is found.
GetSessionInitializerFunc(ModuleOp module)40 FuncOp GetSessionInitializerFunc(ModuleOp module) {
41   auto session_init_op = tf_saved_model::GetSessionInitializerOp(module);
42   if (session_init_op && !session_init_op.initializers().empty()) {
43     SymbolTable symbol_table(module);
44     FuncOp init_func_op = symbol_table.lookup<mlir::FuncOp>(
45         session_init_op.initializers()[0].cast<FlatSymbolRefAttr>().getValue());
46     return init_func_op;
47   }
48   return nullptr;
49 }
50 
51 // Returns ID for identifying a resource.
GetResourceKey(Operation * op)52 std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef> GetResourceKey(
53     Operation* op) {
54   llvm::StringRef device;
55   if (auto attr = op->getAttrOfType<mlir::StringAttr>("device")) {
56     device = attr.getValue();
57   }
58 
59   llvm::StringRef container;
60   if (auto attr = op->getAttrOfType<mlir::StringAttr>("container")) {
61     container = attr.getValue();
62   }
63 
64   llvm::StringRef shared_name;
65   if (auto attr = op->getAttrOfType<mlir::StringAttr>("shared_name")) {
66     shared_name = attr.getValue();
67   }
68 
69   return std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef>{
70       device, container, shared_name};
71 }
72 }  // namespace
ResourceAnalyzer(ModuleOp module,bool skip_session_init)73 ResourceAnalyzer::ResourceAnalyzer(ModuleOp module, bool skip_session_init) {
74   auto session_init_func = GetSessionInitializerFunc(module);
75   for (auto func : module.getOps<FuncOp>()) {
76     if (skip_session_init && func == session_init_func) continue;
77     (void)AnalyzeRegion(func.getRegion());
78   }
79 }
80 
SetPotentiallyWritten(Value resource)81 void ResourceAnalyzer::SetPotentiallyWritten(Value resource) {
82   assert(IsResource(resource));
83   resource_infos_[resource].potentially_written = true;
84   auto* operation = resource.getDefiningOp();
85   if (operation && llvm::isa<TF::VarHandleOp>(operation)) {
86     mutable_variables_.insert(GetResourceKey(operation));
87   }
88 }
89 
IsPotentiallyWritten(Value resource) const90 bool ResourceAnalyzer::IsPotentiallyWritten(Value resource) const {
91   assert(IsResource(resource));
92   auto* operation = resource.getDefiningOp();
93   if (operation && llvm::isa<TF::VarHandleOp>(operation))
94     return mutable_variables_.contains(GetResourceKey(operation));
95   auto it = resource_infos_.find(resource);
96   if (it == resource_infos_.end()) {
97     return false;
98   }
99   return it->second.potentially_written;
100 }
101 
102 // Analyze the specified region for resource mutating operations, namely
103 // TF::AssignVariableOp, if so, set the resource associated as "potentially
104 // written". Do this recursively across the chain of regions via call or
105 // control flow ops.
106 // TODO(ashwinm): Move to iterative traversal.
AnalyzeRegion(Region & region)107 LogicalResult ResourceAnalyzer::AnalyzeRegion(Region& region) {
108   // Avoid infinite recursion.
109   if (!discovered_.insert(&region).second) {
110     return success();
111   }
112 
113   region.walk([&](Operation* op) {
114     if (isa<TF::ReadVariableOp, ReturnOp, YieldOp>(op)) {
115       return;
116     }
117     if (auto assign_variable = dyn_cast<TF::AssignVariableOp>(op)) {
118       SetPotentiallyWritten(assign_variable.resource());
119       return;
120     }
121     if (auto call = dyn_cast<CallOpInterface>(op)) {
122       if (auto func = dyn_cast<FuncOp>(call.resolveCallable())) {
123         PropagatePotentiallyWrittenUpFromCallee(func.getRegion(),
124                                                 call.getArgOperands());
125       }
126       return;
127     }
128     if (auto if_op = dyn_cast<TF::IfOp>(op)) {
129       for (auto callee : {if_op.then_function(), if_op.else_function()}) {
130         PropagatePotentiallyWrittenUpFromCallee(callee.getRegion(),
131                                                 if_op.input());
132       }
133       return;
134     }
135     if (auto if_op = dyn_cast<TF::IfRegionOp>(op)) {
136       PropagatePotentiallyWrittenUpFromCallee(if_op.then_branch(),
137                                               if_op.getODSOperands(1));
138       PropagatePotentiallyWrittenUpFromCallee(if_op.else_branch(),
139                                               if_op.getODSOperands(1));
140       return;
141     }
142     if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
143       for (auto callee : {while_op.cond_function(), while_op.body_function()}) {
144         PropagatePotentiallyWrittenUpFromCallee(callee.getRegion(),
145                                                 while_op.input());
146       }
147       return;
148     }
149     if (auto while_op = dyn_cast<TF::WhileRegionOp>(op)) {
150       PropagatePotentiallyWrittenUpFromCallee(while_op.cond(),
151                                               while_op.input());
152       PropagatePotentiallyWrittenUpFromCallee(while_op.body(),
153                                               while_op.input());
154       return;
155     }
156     // For all other ops, we assume it mutates all resources it uses, so
157     // this errs on the side of being conservative. We should improve
158     // this by using either a property or a trait that clearly
159     // identifies ops with resource mutating behavior.
160     PropagatePotentiallyWrittenWithinUnhandledOp(op);
161   });
162   return success();
163 }
164 
PropagatePotentiallyWrittenWithinUnhandledOp(Operation * op)165 void ResourceAnalyzer::PropagatePotentiallyWrittenWithinUnhandledOp(
166     Operation* op) {
167   for (auto operand : op->getOperands()) {
168     if (IsResource(operand)) {
169       SetPotentiallyWritten(operand);
170     }
171   }
172   visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand* operand) {
173     if (IsResource(operand->get())) {
174       SetPotentiallyWritten(operand->get());
175     }
176   });
177 }
178 
PropagatePotentiallyWrittenUpFromCallee(Region & region,Operation::operand_range propagate_to)179 void ResourceAnalyzer::PropagatePotentiallyWrittenUpFromCallee(
180     Region& region, Operation::operand_range propagate_to) {
181   (void)AnalyzeRegion(region);
182   for (auto t : llvm::zip(region.getArguments(), propagate_to)) {
183     if (!IsResource(std::get<0>(t))) {
184       continue;
185     }
186     if (IsPotentiallyWritten(std::get<0>(t))) {
187       SetPotentiallyWritten(std::get<1>(t));
188     }
189   }
190 }
191 }  // namespace TF
192 }  // namespace mlir
193