1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h"

#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
27
28 namespace tensorflow {
29 namespace {
30
IsFunctionCall(const Node & node)31 bool IsFunctionCall(const Node& node) {
32 // TODO(iga): Handle non-PCO functions when we add multi-device support
33 // to regular function calls. Also, the GetFunctionDefAndAttrs assumes that
34 // the function name is stored in the `f` attribute of the node. That code
35 // will need to change as well.
36 const string& op_type = node.op_def().name();
37 return op_type == "PartitionedCall" || op_type == "StatefulPartitionedCall";
38 }
39
40 // Utility to set node's value in `cache` and `is_deep` to `value`.
Set(const Node & node,bool value,bool * is_deep,std::vector<absl::optional<bool>> * cache)41 Status Set(const Node& node, bool value, bool* is_deep,
42 std::vector<absl::optional<bool>>* cache) {
43 *is_deep = value;
44 (*cache)[node.id()] = value;
45 return Status::OK();
46 }
47
48 } // namespace
49
PlacerInspectionRequiredOpChecker(const Graph * graph)50 PlacerInspectionRequiredOpChecker::PlacerInspectionRequiredOpChecker(
51 const Graph* graph)
52 : PlacerInspectionRequiredOpChecker(graph, &graph->flib_def()) {}
53
PlacerInspectionRequiredOpChecker(const Graph * graph,const FunctionLibraryDefinition * flib_def)54 PlacerInspectionRequiredOpChecker::PlacerInspectionRequiredOpChecker(
55 const Graph* graph, const FunctionLibraryDefinition* flib_def)
56 : graph_(*graph), flib_def_(*flib_def) {
57 cache_.resize(graph_.num_node_ids());
58 }
59
IsPlacerInspectionRequired(const Node & node,bool * is_deep)60 Status PlacerInspectionRequiredOpChecker::IsPlacerInspectionRequired(
61 const Node& node, bool* is_deep) {
62 if (cache_[node.id()].has_value()) {
63 *is_deep = cache_[node.id()].value();
64 return Status::OK();
65 }
66
67 if (!IsFunctionCall(node)) {
68 return Set(node, false, is_deep, &cache_);
69 }
70 const FunctionDef* fdef;
71 NameAttrList func;
72 TF_RETURN_IF_ERROR(GetFunctionDefAndAttrs(flib_def_, node, &fdef, &func));
73 DataTypeVector types;
74 TF_RETURN_IF_ERROR(
75 OutputTypesForNode(AttrSlice(&func.attr()), fdef->signature(), &types));
76 for (DataType type : types) {
77 if (type == DT_RESOURCE) {
78 return Set(node, true, is_deep, &cache_);
79 }
80 }
81 return Set(node, false, is_deep, &cache_);
82 }
83
GetFunctionDefAndAttrs(const FunctionLibraryDefinition & flib_def,const Node & node,const FunctionDef ** fdef,NameAttrList * func)84 Status GetFunctionDefAndAttrs(const FunctionLibraryDefinition& flib_def,
85 const Node& node, const FunctionDef** fdef,
86 NameAttrList* func) {
87 TF_RETURN_IF_ERROR(GetNodeAttr(node.def(), "f", func));
88 const string& function_name = func->name();
89 *fdef = flib_def.Find(function_name);
90 if (*fdef == nullptr) {
91 return errors::InvalidArgument(
92 "Failed to find function \"", function_name,
93 "\" in function library: ", flib_def.ToProto().DebugString());
94 }
95 return Status::OK();
96 }
97
FunctionStack(const string & function_name)98 FunctionStack::FunctionStack(const string& function_name)
99 : current_function_name_(function_name) {}
100
Push(const Node * node_in_current_function,const string & new_current_function) const101 FunctionStack FunctionStack::Push(const Node* node_in_current_function,
102 const string& new_current_function) const {
103 FunctionStack new_stack(new_current_function);
104 new_stack.frames_ = frames_;
105 new_stack.frames_.emplace_back(current_function_name_,
106 node_in_current_function);
107 return new_stack;
108 }
109
HasFunction(const string & function_name) const110 bool FunctionStack::HasFunction(const string& function_name) const {
111 if (current_function_name_ == function_name) {
112 return true;
113 }
114 for (const Frame& frame : frames_) {
115 if (frame.function_name == function_name) {
116 return true;
117 }
118 }
119 return false;
120 }
121
FormatForError() const122 string FunctionStack::FormatForError() const {
123 std::vector<string> msgs;
124 for (int i = 0; i < frames_.size(); ++i) {
125 if (frames_[i].function_name.empty()) {
126 // Empty function body should only happen at the top level, i.e. i = 0.
127 // All internal frames should have valid function names.
128 msgs.push_back(absl::StrCat("Graph contains node ",
129 FormatNodeForError(*frames_[i].node)));
130
131 } else {
132 msgs.push_back(absl::StrCat(
133 "Function ", errors::FormatFunctionForError(frames_[i].function_name),
134 " contains node ", FormatNodeForError(*frames_[i].node)));
135 }
136 const string& fname = (i + 1 < frames_.size())
137 ? frames_[i + 1].function_name
138 : current_function_name_;
139 msgs.push_back(absl::StrCat("Node ", FormatNodeForError(*frames_[i].node),
140 " calls function ",
141 errors::FormatFunctionForError(fname)));
142 }
143 return absl::StrJoin(msgs, "\n ");
144 }
145
146 namespace {
147
// Op type of the nodes inserted around ops that require placer inspection.
// (The former `OutputEdgeMap` alias was never referenced in this file and has
// been removed.)
constexpr char kIdentityOp[] = "Identity";
151
// Returns `candidate_name` if it is not already in `*node_names`; otherwise
// returns the first of "<candidate_name>_0", "<candidate_name>_1", ... that is
// not. The returned name is added to `*node_names`, so it is never handed out
// again by later calls.
string Uniquify(const string& candidate_name,
                std::unordered_set<string>* node_names) {
  // `insert(...).second` is true iff the name was absent. This performs the
  // membership test and the insertion with a single hash lookup instead of a
  // find() followed by a separate insert().
  if (node_names->insert(candidate_name).second) {
    return candidate_name;
  }

  for (int counter = 0;; ++counter) {
    string candidate = absl::StrCat(candidate_name, "_", counter);
    if (node_names->insert(candidate).second) {
      return candidate;
    }
  }
}
167
AddInputIdentity(Node * node,int input_idx,Graph * graph,std::unordered_set<string> * node_names)168 Status AddInputIdentity(Node* node, int input_idx, Graph* graph,
169 std::unordered_set<string>* node_names) {
170 const Edge* edge;
171 TF_RETURN_IF_ERROR(node->input_edge(input_idx, &edge));
172
173 string identity_name = Uniquify(
174 absl::StrCat(edge->src()->name(), "_", node->name()), node_names);
175
176 NodeDefBuilder builder(identity_name, kIdentityOp);
177 builder.Attr("T", node->input_type(input_idx));
178 NodeDefBuilder::NodeOut input(edge->src()->name(), edge->src_output(),
179 node->input_type(input_idx));
180 builder.Input(input);
181 NodeDef identity_def;
182 TF_RETURN_IF_ERROR(builder.Finalize(&identity_def));
183 MergeDebugInfo(NodeDebugInfo(*node), &identity_def);
184
185 VLOG(6) << "Adding identity into " << edge->src()->name() << ":"
186 << edge->src_output() << " -> " << edge->dst()->name() << ":"
187 << input_idx << " \n"
188 << identity_def.DebugString();
189
190 Status status;
191 Node* identity_node = graph->AddNode(identity_def, &status);
192 if (!status.ok()) {
193 return status;
194 }
195 graph->AddEdge(edge->src(), edge->src_output(), identity_node, 0);
196
197 // Replace node's `input_idx` input with the new identity's 0'th output
198 TF_RETURN_IF_ERROR(graph->UpdateEdge(identity_node, 0, node, input_idx));
199
200 VLOG(6) << "Successfully inserted identity. Modified node: \n"
201 << node->DebugString();
202 return Status::OK();
203 }
204
205 struct EdgePtrCompare {
operator ()tensorflow::__anon5ee8efb60211::EdgePtrCompare206 bool operator()(const Edge* lhs, const Edge* rhs) const {
207 return lhs->id() < rhs->id();
208 }
209 };
210
AddOutputIdentities(Node * node,Graph * graph,std::unordered_set<string> * node_names)211 Status AddOutputIdentities(Node* node, Graph* graph,
212 std::unordered_set<string>* node_names) {
213 auto add_identity = [&](int src_output, const string& identity_name,
214 Node** identity_node) {
215 NodeDefBuilder builder(identity_name, kIdentityOp);
216 builder.Attr("T", node->output_type(src_output));
217 NodeDefBuilder::NodeOut input(node->name(), src_output,
218 node->output_type(src_output));
219 builder.Input(input);
220 NodeDef identity_def;
221 TF_RETURN_IF_ERROR(builder.Finalize(&identity_def));
222 MergeDebugInfo(NodeDebugInfo(*node), &identity_def);
223
224 Status status;
225 *identity_node = graph->AddNode(identity_def, &status);
226 if (!status.ok()) {
227 return status;
228 }
229 graph->AddEdge(node, src_output, *identity_node, 0);
230 return Status::OK();
231 };
232
233 // output_used[i] == true iff `node`'s i'th output is used
234 // in this graph
235 std::vector<bool> output_used(node->num_outputs(), false);
236 // Copy the set of edges since EdgeSet does not allow modifications
237 // to graph edges during iteration.
238 const EdgeSet& out_edges = node->out_edges();
239 std::vector<const Edge*> edge_vector(out_edges.begin(), out_edges.end());
240 std::sort(edge_vector.begin(), edge_vector.end(), EdgePtrCompare());
241 for (const Edge* edge : edge_vector) {
242 if (edge->IsControlEdge()) {
243 continue;
244 }
245 output_used[edge->src_output()] = true;
246
247 Node* dst = edge->dst();
248 int dst_input = edge->dst_input();
249 int src_output = edge->src_output();
250 string identity_name =
251 Uniquify(absl::StrCat(node->name(), "_", dst->name()), node_names);
252 Node* identity_node;
253 TF_RETURN_IF_ERROR(add_identity(src_output, identity_name, &identity_node));
254 VLOG(6) << "Adding identity into " << node->name() << ":" << src_output
255 << " -> " << dst->name() << ":" << dst_input << " \n"
256 << identity_node->DebugString();
257
258 // Make original dst node consume the new identity's output instead of
259 // `node`'s output.
260 TF_RETURN_IF_ERROR(graph->UpdateEdge(identity_node, 0, dst, dst_input));
261 }
262
263 for (int output_idx = 0; output_idx < node->num_outputs(); ++output_idx) {
264 if (output_used[output_idx]) {
265 continue;
266 }
267 // The output is unused in the graph. Just add an identity
268 // consuming it.
269 string identity_name = Uniquify(node->name(), node_names);
270 Node* identity_node;
271 TF_RETURN_IF_ERROR(add_identity(output_idx, identity_name, &identity_node));
272 VLOG(6) << "Added identity into " << node->name() << ":" << output_idx
273 << " -> <no consumer>: \n"
274 << identity_node->DebugString();
275 }
276 return Status::OK();
277 }
278
IsolateNode(Node * node,Graph * graph)279 Status IsolateNode(Node* node, Graph* graph) {
280 // We use `node_names` to make sure we pick unique names.
281 // We don't use graph->NewName() because it produces verbose names and
282 // does not actually ensure that they are unique (it assumes all names
283 // are generated using it, which is not true today).
284 std::unordered_set<string> node_names(graph->num_nodes());
285 for (Node* n : graph->nodes()) {
286 node_names.insert(n->name());
287 }
288
289 for (int i = 0; i < node->num_inputs(); ++i) {
290 TF_RETURN_IF_ERROR(AddInputIdentity(node, i, graph, &node_names));
291 }
292 TF_RETURN_IF_ERROR(AddOutputIdentities(node, graph, &node_names));
293 return Status::OK();
294 }
295
296 } // namespace
297
IsolatePlacerInspectionRequiredOps(const FunctionLibraryDefinition & flib_def,Graph * graph)298 Status IsolatePlacerInspectionRequiredOps(
299 const FunctionLibraryDefinition& flib_def, Graph* graph) {
300 PlacerInspectionRequiredOpChecker checker(graph, &flib_def);
301 // It is OK to add nodes to the graph during iteration.
302 // New nodes will get ids above current ids. The loop
303 // will loop over current nodes only because the op_nodes()
304 // iterator uses node ids to iterate.
305 // Because the new nodes will be higher ids, the caching in
306 // the checker will also work fine as new nodes are added.
307 for (Node* node : graph->op_nodes()) {
308 bool should_be_isolated = false;
309 TF_RETURN_IF_ERROR(
310 checker.IsPlacerInspectionRequired(*node, &should_be_isolated));
311 if (!should_be_isolated) {
312 continue;
313 }
314 TF_RETURN_IF_ERROR(IsolateNode(node, graph));
315 }
316
317 return Status::OK();
318 }
319
320 } // namespace tensorflow
321