1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h"
17
18 #include "llvm/Support/Casting.h"
19 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
20 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
21 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
22 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
23 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
24
25 namespace mlir {
26 namespace TF {
27 namespace {
IsResourceType(Type type)28 bool IsResourceType(Type type) {
29 if (auto tensor_type = type.dyn_cast<TensorType>()) {
30 return tensor_type.getElementType().isa<TF::ResourceType>();
31 }
32 return false;
33 }
34
IsResource(Value value)35 bool IsResource(Value value) { return IsResourceType(value.getType()); }
36
37 // Helper that returns the FuncOp that is the SessionInit function which
38 // will be called to initialize all resources.
39 // Returns nullptr if no function is found.
GetSessionInitializerFunc(ModuleOp module)40 FuncOp GetSessionInitializerFunc(ModuleOp module) {
41 auto session_init_op = tf_saved_model::GetSessionInitializerOp(module);
42 if (session_init_op && !session_init_op.initializers().empty()) {
43 SymbolTable symbol_table(module);
44 FuncOp init_func_op = symbol_table.lookup<mlir::FuncOp>(
45 session_init_op.initializers()[0].cast<FlatSymbolRefAttr>().getValue());
46 return init_func_op;
47 }
48 return nullptr;
49 }
50
51 // Returns ID for identifying a resource.
GetResourceKey(Operation * op)52 std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef> GetResourceKey(
53 Operation* op) {
54 llvm::StringRef device;
55 if (auto attr = op->getAttrOfType<mlir::StringAttr>("device")) {
56 device = attr.getValue();
57 }
58
59 llvm::StringRef container;
60 if (auto attr = op->getAttrOfType<mlir::StringAttr>("container")) {
61 container = attr.getValue();
62 }
63
64 llvm::StringRef shared_name;
65 if (auto attr = op->getAttrOfType<mlir::StringAttr>("shared_name")) {
66 shared_name = attr.getValue();
67 }
68
69 return std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef>{
70 device, container, shared_name};
71 }
72 } // namespace
ResourceAnalyzer(ModuleOp module,bool skip_session_init)73 ResourceAnalyzer::ResourceAnalyzer(ModuleOp module, bool skip_session_init) {
74 auto session_init_func = GetSessionInitializerFunc(module);
75 for (auto func : module.getOps<FuncOp>()) {
76 if (skip_session_init && func == session_init_func) continue;
77 (void)AnalyzeRegion(func.getRegion());
78 }
79 }
80
SetPotentiallyWritten(Value resource)81 void ResourceAnalyzer::SetPotentiallyWritten(Value resource) {
82 assert(IsResource(resource));
83 resource_infos_[resource].potentially_written = true;
84 auto* operation = resource.getDefiningOp();
85 if (operation && llvm::isa<TF::VarHandleOp>(operation)) {
86 mutable_variables_.insert(GetResourceKey(operation));
87 }
88 }
89
IsPotentiallyWritten(Value resource) const90 bool ResourceAnalyzer::IsPotentiallyWritten(Value resource) const {
91 assert(IsResource(resource));
92 auto* operation = resource.getDefiningOp();
93 if (operation && llvm::isa<TF::VarHandleOp>(operation))
94 return mutable_variables_.contains(GetResourceKey(operation));
95 auto it = resource_infos_.find(resource);
96 if (it == resource_infos_.end()) {
97 return false;
98 }
99 return it->second.potentially_written;
100 }
101
102 // Analyze the specified region for resource mutating operations, namely
103 // TF::AssignVariableOp, if so, set the resource associated as "potentially
104 // written". Do this recursively across the chain of regions via call or
105 // control flow ops.
106 // TODO(ashwinm): Move to iterative traversal.
AnalyzeRegion(Region & region)107 LogicalResult ResourceAnalyzer::AnalyzeRegion(Region& region) {
108 // Avoid infinite recursion.
109 if (!discovered_.insert(®ion).second) {
110 return success();
111 }
112
113 region.walk([&](Operation* op) {
114 if (isa<TF::ReadVariableOp, ReturnOp, YieldOp>(op)) {
115 return;
116 }
117 if (auto assign_variable = dyn_cast<TF::AssignVariableOp>(op)) {
118 SetPotentiallyWritten(assign_variable.resource());
119 return;
120 }
121 if (auto call = dyn_cast<CallOpInterface>(op)) {
122 if (auto func = dyn_cast<FuncOp>(call.resolveCallable())) {
123 PropagatePotentiallyWrittenUpFromCallee(func.getRegion(),
124 call.getArgOperands());
125 }
126 return;
127 }
128 if (auto if_op = dyn_cast<TF::IfOp>(op)) {
129 for (auto callee : {if_op.then_function(), if_op.else_function()}) {
130 PropagatePotentiallyWrittenUpFromCallee(callee.getRegion(),
131 if_op.input());
132 }
133 return;
134 }
135 if (auto if_op = dyn_cast<TF::IfRegionOp>(op)) {
136 PropagatePotentiallyWrittenUpFromCallee(if_op.then_branch(),
137 if_op.getODSOperands(1));
138 PropagatePotentiallyWrittenUpFromCallee(if_op.else_branch(),
139 if_op.getODSOperands(1));
140 return;
141 }
142 if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
143 for (auto callee : {while_op.cond_function(), while_op.body_function()}) {
144 PropagatePotentiallyWrittenUpFromCallee(callee.getRegion(),
145 while_op.input());
146 }
147 return;
148 }
149 if (auto while_op = dyn_cast<TF::WhileRegionOp>(op)) {
150 PropagatePotentiallyWrittenUpFromCallee(while_op.cond(),
151 while_op.input());
152 PropagatePotentiallyWrittenUpFromCallee(while_op.body(),
153 while_op.input());
154 return;
155 }
156 // For all other ops, we assume it mutates all resources it uses, so
157 // this errs on the side of being conservative. We should improve
158 // this by using either a property or a trait that clearly
159 // identifies ops with resource mutating behavior.
160 PropagatePotentiallyWrittenWithinUnhandledOp(op);
161 });
162 return success();
163 }
164
PropagatePotentiallyWrittenWithinUnhandledOp(Operation * op)165 void ResourceAnalyzer::PropagatePotentiallyWrittenWithinUnhandledOp(
166 Operation* op) {
167 for (auto operand : op->getOperands()) {
168 if (IsResource(operand)) {
169 SetPotentiallyWritten(operand);
170 }
171 }
172 visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand* operand) {
173 if (IsResource(operand->get())) {
174 SetPotentiallyWritten(operand->get());
175 }
176 });
177 }
178
PropagatePotentiallyWrittenUpFromCallee(Region & region,Operation::operand_range propagate_to)179 void ResourceAnalyzer::PropagatePotentiallyWrittenUpFromCallee(
180 Region& region, Operation::operand_range propagate_to) {
181 (void)AnalyzeRegion(region);
182 for (auto t : llvm::zip(region.getArguments(), propagate_to)) {
183 if (!IsResource(std::get<0>(t))) {
184 continue;
185 }
186 if (IsPotentiallyWritten(std::get<0>(t))) {
187 SetPotentiallyWritten(std::get<1>(t));
188 }
189 }
190 }
191 } // namespace TF
192 } // namespace mlir
193