1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h"
16
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
28
29 namespace tensorflow {
30 namespace {
31
IsFunctionCall(const Node & node)32 bool IsFunctionCall(const Node& node) {
33 // TODO(iga): Handle non-PCO functions when we add multi-device support
34 // to regular function calls. Also, the GetFunctionDefAndAttrs assumes that
35 // the function name is stored in the `f` attribute of the node. That code
36 // will need to change as well.
37 const string& op_type = node.op_def().name();
38 return op_type == "PartitionedCall" || op_type == "StatefulPartitionedCall";
39 }
40
41 // Utility to set node's value in `cache` and `is_deep` to `value`.
Set(const Node & node,bool value,bool * is_deep,std::vector<absl::optional<bool>> * cache)42 Status Set(const Node& node, bool value, bool* is_deep,
43 std::vector<absl::optional<bool>>* cache) {
44 *is_deep = value;
45 (*cache)[node.id()] = value;
46 return Status::OK();
47 }
48
49 } // namespace
50
PlacerInspectionRequiredOpChecker(const Graph * graph)51 PlacerInspectionRequiredOpChecker::PlacerInspectionRequiredOpChecker(
52 const Graph* graph)
53 : PlacerInspectionRequiredOpChecker(graph, &graph->flib_def()) {}
54
PlacerInspectionRequiredOpChecker(const Graph * graph,const FunctionLibraryDefinition * flib_def)55 PlacerInspectionRequiredOpChecker::PlacerInspectionRequiredOpChecker(
56 const Graph* graph, const FunctionLibraryDefinition* flib_def)
57 : graph_(*graph), flib_def_(*flib_def) {
58 cache_.resize(graph_.num_node_ids());
59 }
60
IsPlacerInspectionRequired(const Node & node,bool * is_deep)61 Status PlacerInspectionRequiredOpChecker::IsPlacerInspectionRequired(
62 const Node& node, bool* is_deep) {
63 if (cache_[node.id()].has_value()) {
64 *is_deep = cache_[node.id()].value();
65 return Status::OK();
66 }
67
68 if (!IsFunctionCall(node)) {
69 return Set(node, false, is_deep, &cache_);
70 }
71 const FunctionDef* fdef;
72 NameAttrList func;
73 TF_RETURN_IF_ERROR(GetFunctionDefAndAttrs(flib_def_, node, &fdef, &func));
74 DataTypeVector types;
75 TF_RETURN_IF_ERROR(
76 OutputTypesForNode(AttrSlice(&func.attr()), fdef->signature(), &types));
77 for (DataType type : types) {
78 if (type == DT_RESOURCE) {
79 return Set(node, true, is_deep, &cache_);
80 }
81 }
82 return Set(node, false, is_deep, &cache_);
83 }
84
GetFunctionDefAndAttrs(const FunctionLibraryDefinition & flib_def,const Node & node,const FunctionDef ** fdef,NameAttrList * func)85 Status GetFunctionDefAndAttrs(const FunctionLibraryDefinition& flib_def,
86 const Node& node, const FunctionDef** fdef,
87 NameAttrList* func) {
88 TF_RETURN_IF_ERROR(GetNodeAttr(node.def(), "f", func));
89 const string& function_name = func->name();
90 *fdef = flib_def.Find(function_name);
91 if (*fdef == nullptr) {
92 return errors::InvalidArgument(
93 "Failed to find function \"", function_name,
94 "\" in function library: ", flib_def.ToProto().DebugString());
95 }
96 return Status::OK();
97 }
98
FunctionStack(const string & function_name)99 FunctionStack::FunctionStack(const string& function_name)
100 : current_function_name_(function_name) {}
101
Push(const Node * node_in_current_function,const string & new_current_function) const102 FunctionStack FunctionStack::Push(const Node* node_in_current_function,
103 const string& new_current_function) const {
104 FunctionStack new_stack(new_current_function);
105 new_stack.frames_ = frames_;
106 new_stack.frames_.emplace_back(current_function_name_,
107 node_in_current_function);
108 return new_stack;
109 }
110
HasFunction(const string & function_name) const111 bool FunctionStack::HasFunction(const string& function_name) const {
112 if (current_function_name_ == function_name) {
113 return true;
114 }
115 for (const Frame& frame : frames_) {
116 if (frame.function_name == function_name) {
117 return true;
118 }
119 }
120 return false;
121 }
122
FormatForError() const123 string FunctionStack::FormatForError() const {
124 std::vector<string> msgs;
125 for (int i = 0; i < frames_.size(); ++i) {
126 if (frames_[i].function_name.empty()) {
127 // Empty function body should only happen at the top level, i.e. i = 0.
128 // All internal frames should have valid function names.
129 msgs.push_back(absl::StrCat("Graph contains node ",
130 FormatNodeForError(*frames_[i].node)));
131
132 } else {
133 msgs.push_back(absl::StrCat(
134 "Function ", errors::FormatFunctionForError(frames_[i].function_name),
135 " contains node ", FormatNodeForError(*frames_[i].node)));
136 }
137 const string& fname = (i + 1 < frames_.size())
138 ? frames_[i + 1].function_name
139 : current_function_name_;
140 msgs.push_back(absl::StrCat("Node ", FormatNodeForError(*frames_[i].node),
141 " calls function ",
142 errors::FormatFunctionForError(fname)));
143 }
144 return absl::StrJoin(msgs, "\n ");
145 }
146
147 namespace {
148
// Op type of the nodes inserted around placer-inspection-required ops to
// isolate them (see AddInputIdentity / AddOutputIdentities).
// The former `OutputEdgeMap` alias was never used in this file and has been
// removed.
constexpr char kIdentityOp[] = "Identity";
152
// Returns a name that is not yet in `*node_names` and records it there.
// `candidate_name` itself is preferred. If it is taken, the first free name of
// the form "<candidate_name>_<k>" for k = 0, 1, 2, ... is used.
string Uniquify(const string& candidate_name,
                std::unordered_set<string>* node_names) {
  // insert().second is true exactly when the name was not already present.
  if (node_names->insert(candidate_name).second) {
    return candidate_name;
  }
  int suffix = 0;
  while (true) {
    string attempt = absl::StrCat(candidate_name, "_", suffix);
    ++suffix;
    if (node_names->insert(attempt).second) {
      return attempt;
    }
  }
}
168
AddInputIdentity(Node * node,int input_idx,Graph * graph,std::unordered_set<string> * node_names)169 Status AddInputIdentity(Node* node, int input_idx, Graph* graph,
170 std::unordered_set<string>* node_names) {
171 const Edge* edge;
172 TF_RETURN_IF_ERROR(node->input_edge(input_idx, &edge));
173
174 string identity_name = Uniquify(
175 absl::StrCat(edge->src()->name(), "_", node->name()), node_names);
176
177 NodeDefBuilder builder(identity_name, kIdentityOp);
178 builder.Attr("T", node->input_type(input_idx));
179 NodeDefBuilder::NodeOut input(edge->src()->name(), edge->src_output(),
180 node->input_type(input_idx));
181 builder.Input(input);
182 NodeDef identity_def;
183 TF_RETURN_IF_ERROR(builder.Finalize(&identity_def));
184 MergeDebugInfo(NodeDebugInfo(*node), &identity_def);
185
186 VLOG(6) << "Adding identity into " << edge->src()->name() << ":"
187 << edge->src_output() << " -> " << edge->dst()->name() << ":"
188 << input_idx << " \n"
189 << identity_def.DebugString();
190
191 Status status;
192 Node* identity_node = graph->AddNode(identity_def, &status);
193 if (!status.ok()) {
194 return status;
195 }
196 graph->AddEdge(edge->src(), edge->src_output(), identity_node, 0);
197
198 // Replace node's `input_idx` input with the new identity's 0'th output
199 TF_RETURN_IF_ERROR(graph->UpdateEdge(identity_node, 0, node, input_idx));
200
201 VLOG(6) << "Successfully inserted identity. Modified node: \n"
202 << node->DebugString();
203 return Status::OK();
204 }
205
206 struct EdgePtrCompare {
operator ()tensorflow::__anon1096285a0211::EdgePtrCompare207 bool operator()(const Edge* lhs, const Edge* rhs) const {
208 return lhs->id() < rhs->id();
209 }
210 };
211
AddOutputIdentities(Node * node,Graph * graph,std::unordered_set<string> * node_names)212 Status AddOutputIdentities(Node* node, Graph* graph,
213 std::unordered_set<string>* node_names) {
214 auto add_identity = [&](int src_output, const string& identity_name,
215 Node** identity_node) {
216 NodeDefBuilder builder(identity_name, kIdentityOp);
217 builder.Attr("T", node->output_type(src_output));
218 NodeDefBuilder::NodeOut input(node->name(), src_output,
219 node->output_type(src_output));
220 builder.Input(input);
221 NodeDef identity_def;
222 TF_RETURN_IF_ERROR(builder.Finalize(&identity_def));
223 MergeDebugInfo(NodeDebugInfo(*node), &identity_def);
224
225 Status status;
226 *identity_node = graph->AddNode(identity_def, &status);
227 if (!status.ok()) {
228 return status;
229 }
230 graph->AddEdge(node, src_output, *identity_node, 0);
231 return Status::OK();
232 };
233
234 // output_used[i] == true iff `node`'s i'th output is used
235 // in this graph
236 std::vector<bool> output_used(node->num_outputs(), false);
237 // Copy the set of edges since EdgeSet does not allow modifications
238 // to graph edges during iteration.
239 const EdgeSet& out_edges = node->out_edges();
240 std::vector<const Edge*> edge_vector(out_edges.begin(), out_edges.end());
241 std::sort(edge_vector.begin(), edge_vector.end(), EdgePtrCompare());
242 for (const Edge* edge : edge_vector) {
243 if (edge->IsControlEdge()) {
244 continue;
245 }
246 output_used[edge->src_output()] = true;
247
248 Node* dst = edge->dst();
249 int dst_input = edge->dst_input();
250 int src_output = edge->src_output();
251 string identity_name =
252 Uniquify(absl::StrCat(node->name(), "_", dst->name()), node_names);
253 Node* identity_node;
254 TF_RETURN_IF_ERROR(add_identity(src_output, identity_name, &identity_node));
255 VLOG(6) << "Adding identity into " << node->name() << ":" << src_output
256 << " -> " << dst->name() << ":" << dst_input << " \n"
257 << identity_node->DebugString();
258
259 // Make original dst node consume the new identity's output instead of
260 // `node`'s output.
261 TF_RETURN_IF_ERROR(graph->UpdateEdge(identity_node, 0, dst, dst_input));
262 }
263
264 for (int output_idx = 0; output_idx < node->num_outputs(); ++output_idx) {
265 if (output_used[output_idx]) {
266 continue;
267 }
268 // The output is unused in the graph. Just add an identity
269 // consuming it.
270 string identity_name = Uniquify(node->name(), node_names);
271 Node* identity_node;
272 TF_RETURN_IF_ERROR(add_identity(output_idx, identity_name, &identity_node));
273 VLOG(6) << "Added identity into " << node->name() << ":" << output_idx
274 << " -> <no consumer>: \n"
275 << identity_node->DebugString();
276 }
277 return Status::OK();
278 }
279
IsolateNode(Node * node,Graph * graph)280 Status IsolateNode(Node* node, Graph* graph) {
281 // We use `node_names` to make sure we pick unique names.
282 // We don't use graph->NewName() because it produces verbose names and
283 // does not actually ensure that they are unique (it assumes all names
284 // are generated using it, which is not true today).
285 std::unordered_set<string> node_names(graph->num_nodes());
286 for (Node* n : graph->nodes()) {
287 node_names.insert(n->name());
288 }
289
290 for (int i = 0; i < node->num_inputs(); ++i) {
291 TF_RETURN_IF_ERROR(AddInputIdentity(node, i, graph, &node_names));
292 }
293 TF_RETURN_IF_ERROR(AddOutputIdentities(node, graph, &node_names));
294 return Status::OK();
295 }
296
297 } // namespace
298
IsolatePlacerInspectionRequiredOps(const FunctionLibraryDefinition & flib_def,Graph * graph)299 Status IsolatePlacerInspectionRequiredOps(
300 const FunctionLibraryDefinition& flib_def, Graph* graph) {
301 PlacerInspectionRequiredOpChecker checker(graph, &flib_def);
302 // It is OK to add nodes to the graph during iteration.
303 // New nodes will get ids above current ids. The loop
304 // will loop over current nodes only because the op_nodes()
305 // iterator uses node ids to iterate.
306 // Because the new nodes will be higher ids, the caching in
307 // the checker will also work fine as new nodes are added.
308 for (Node* node : graph->op_nodes()) {
309 bool should_be_isolated = false;
310 TF_RETURN_IF_ERROR(
311 checker.IsPlacerInspectionRequired(*node, &should_be_isolated));
312 if (!should_be_isolated) {
313 continue;
314 }
315 TF_RETURN_IF_ERROR(IsolateNode(node, graph));
316 }
317
318 return Status::OK();
319 }
320
321 } // namespace tensorflow
322