• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h"
16 
17 #include <unordered_map>
18 #include <unordered_set>
19 
20 #include "absl/strings/str_cat.h"
21 #include "absl/types/optional.h"
22 #include "tensorflow/core/framework/function.h"
23 #include "tensorflow/core/framework/node_def_builder.h"
24 #include "tensorflow/core/graph/graph.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/core/status.h"
27 
28 namespace tensorflow {
29 namespace {
30 
IsFunctionCall(const Node & node)31 bool IsFunctionCall(const Node& node) {
32   // TODO(iga): Handle non-PCO functions when we add multi-device support
33   // to regular function calls. Also, the GetFunctionDefAndAttrs assumes that
34   // the function name is stored in the `f` attribute of the node. That code
35   // will need to change as well.
36   const string& op_type = node.op_def().name();
37   return op_type == "PartitionedCall" || op_type == "StatefulPartitionedCall";
38 }
39 
40 // Utility to set node's value in `cache` and `is_deep` to `value`.
Set(const Node & node,bool value,bool * is_deep,std::vector<absl::optional<bool>> * cache)41 Status Set(const Node& node, bool value, bool* is_deep,
42            std::vector<absl::optional<bool>>* cache) {
43   *is_deep = value;
44   (*cache)[node.id()] = value;
45   return Status::OK();
46 }
47 
48 }  // namespace
49 
PlacerInspectionRequiredOpChecker(const Graph * graph)50 PlacerInspectionRequiredOpChecker::PlacerInspectionRequiredOpChecker(
51     const Graph* graph)
52     : PlacerInspectionRequiredOpChecker(graph, &graph->flib_def()) {}
53 
PlacerInspectionRequiredOpChecker(const Graph * graph,const FunctionLibraryDefinition * flib_def)54 PlacerInspectionRequiredOpChecker::PlacerInspectionRequiredOpChecker(
55     const Graph* graph, const FunctionLibraryDefinition* flib_def)
56     : graph_(*graph), flib_def_(*flib_def) {
57   cache_.resize(graph_.num_node_ids());
58 }
59 
IsPlacerInspectionRequired(const Node & node,bool * is_deep)60 Status PlacerInspectionRequiredOpChecker::IsPlacerInspectionRequired(
61     const Node& node, bool* is_deep) {
62   if (cache_[node.id()].has_value()) {
63     *is_deep = cache_[node.id()].value();
64     return Status::OK();
65   }
66 
67   if (!IsFunctionCall(node)) {
68     return Set(node, false, is_deep, &cache_);
69   }
70   const FunctionDef* fdef;
71   NameAttrList func;
72   TF_RETURN_IF_ERROR(GetFunctionDefAndAttrs(flib_def_, node, &fdef, &func));
73   DataTypeVector types;
74   TF_RETURN_IF_ERROR(
75       OutputTypesForNode(AttrSlice(&func.attr()), fdef->signature(), &types));
76   for (DataType type : types) {
77     if (type == DT_RESOURCE) {
78       return Set(node, true, is_deep, &cache_);
79     }
80   }
81   return Set(node, false, is_deep, &cache_);
82 }
83 
GetFunctionDefAndAttrs(const FunctionLibraryDefinition & flib_def,const Node & node,const FunctionDef ** fdef,NameAttrList * func)84 Status GetFunctionDefAndAttrs(const FunctionLibraryDefinition& flib_def,
85                               const Node& node, const FunctionDef** fdef,
86                               NameAttrList* func) {
87   TF_RETURN_IF_ERROR(GetNodeAttr(node.def(), "f", func));
88   const string& function_name = func->name();
89   *fdef = flib_def.Find(function_name);
90   if (*fdef == nullptr) {
91     return errors::InvalidArgument(
92         "Failed to find function \"", function_name,
93         "\" in function library: ", flib_def.ToProto().DebugString());
94   }
95   return Status::OK();
96 }
97 
FunctionStack(const string & function_name)98 FunctionStack::FunctionStack(const string& function_name)
99     : current_function_name_(function_name) {}
100 
Push(const Node * node_in_current_function,const string & new_current_function) const101 FunctionStack FunctionStack::Push(const Node* node_in_current_function,
102                                   const string& new_current_function) const {
103   FunctionStack new_stack(new_current_function);
104   new_stack.frames_ = frames_;
105   new_stack.frames_.emplace_back(current_function_name_,
106                                  node_in_current_function);
107   return new_stack;
108 }
109 
HasFunction(const string & function_name) const110 bool FunctionStack::HasFunction(const string& function_name) const {
111   if (current_function_name_ == function_name) {
112     return true;
113   }
114   for (const Frame& frame : frames_) {
115     if (frame.function_name == function_name) {
116       return true;
117     }
118   }
119   return false;
120 }
121 
FormatForError() const122 string FunctionStack::FormatForError() const {
123   std::vector<string> msgs;
124   for (int i = 0; i < frames_.size(); ++i) {
125     if (frames_[i].function_name.empty()) {
126       // Empty function body should only happen at the top level, i.e. i = 0.
127       // All internal frames should have valid function names.
128       msgs.push_back(absl::StrCat("Graph contains node ",
129                                   FormatNodeForError(*frames_[i].node)));
130 
131     } else {
132       msgs.push_back(absl::StrCat(
133           "Function ", errors::FormatFunctionForError(frames_[i].function_name),
134           " contains node ", FormatNodeForError(*frames_[i].node)));
135     }
136     const string& fname = (i + 1 < frames_.size())
137                               ? frames_[i + 1].function_name
138                               : current_function_name_;
139     msgs.push_back(absl::StrCat("Node ", FormatNodeForError(*frames_[i].node),
140                                 " calls function ",
141                                 errors::FormatFunctionForError(fname)));
142   }
143   return absl::StrJoin(msgs, "\n  ");
144 }
145 
146 namespace {
147 
// Op type of the pass-through nodes that IsolateNode inserts around a node.
constexpr char kIdentityOp[] = "Identity";
151 
// Returns `candidate_name` if it is not yet in `*node_names`. Otherwise
// returns the first of "<candidate_name>_0", "<candidate_name>_1", ... that is
// not in the set. The returned name is always added to `*node_names`, so later
// calls will not hand it out again.
string Uniquify(const string& candidate_name,
                std::unordered_set<string>* node_names) {
  // insert(...).second is true exactly when the name was not present yet.
  if (node_names->insert(candidate_name).second) {
    return candidate_name;
  }
  int suffix = 0;
  while (true) {
    string candidate = absl::StrCat(candidate_name, "_", suffix++);
    if (node_names->insert(candidate).second) {
      return candidate;
    }
  }
}
167 
AddInputIdentity(Node * node,int input_idx,Graph * graph,std::unordered_set<string> * node_names)168 Status AddInputIdentity(Node* node, int input_idx, Graph* graph,
169                         std::unordered_set<string>* node_names) {
170   const Edge* edge;
171   TF_RETURN_IF_ERROR(node->input_edge(input_idx, &edge));
172 
173   string identity_name = Uniquify(
174       absl::StrCat(edge->src()->name(), "_", node->name()), node_names);
175 
176   NodeDefBuilder builder(identity_name, kIdentityOp);
177   builder.Attr("T", node->input_type(input_idx));
178   NodeDefBuilder::NodeOut input(edge->src()->name(), edge->src_output(),
179                                 node->input_type(input_idx));
180   builder.Input(input);
181   NodeDef identity_def;
182   TF_RETURN_IF_ERROR(builder.Finalize(&identity_def));
183   MergeDebugInfo(NodeDebugInfo(*node), &identity_def);
184 
185   VLOG(6) << "Adding identity into " << edge->src()->name() << ":"
186           << edge->src_output() << " -> " << edge->dst()->name() << ":"
187           << input_idx << " \n"
188           << identity_def.DebugString();
189 
190   Status status;
191   Node* identity_node = graph->AddNode(identity_def, &status);
192   if (!status.ok()) {
193     return status;
194   }
195   graph->AddEdge(edge->src(), edge->src_output(), identity_node, 0);
196 
197   // Replace node's `input_idx` input with the new identity's 0'th output
198   TF_RETURN_IF_ERROR(graph->UpdateEdge(identity_node, 0, node, input_idx));
199 
200   VLOG(6) << "Successfully inserted identity. Modified node: \n"
201           << node->DebugString();
202   return Status::OK();
203 }
204 
205 struct EdgePtrCompare {
operator ()tensorflow::__anon5ee8efb60211::EdgePtrCompare206   bool operator()(const Edge* lhs, const Edge* rhs) const {
207     return lhs->id() < rhs->id();
208   }
209 };
210 
AddOutputIdentities(Node * node,Graph * graph,std::unordered_set<string> * node_names)211 Status AddOutputIdentities(Node* node, Graph* graph,
212                            std::unordered_set<string>* node_names) {
213   auto add_identity = [&](int src_output, const string& identity_name,
214                           Node** identity_node) {
215     NodeDefBuilder builder(identity_name, kIdentityOp);
216     builder.Attr("T", node->output_type(src_output));
217     NodeDefBuilder::NodeOut input(node->name(), src_output,
218                                   node->output_type(src_output));
219     builder.Input(input);
220     NodeDef identity_def;
221     TF_RETURN_IF_ERROR(builder.Finalize(&identity_def));
222     MergeDebugInfo(NodeDebugInfo(*node), &identity_def);
223 
224     Status status;
225     *identity_node = graph->AddNode(identity_def, &status);
226     if (!status.ok()) {
227       return status;
228     }
229     graph->AddEdge(node, src_output, *identity_node, 0);
230     return Status::OK();
231   };
232 
233   // output_used[i] == true iff `node`'s i'th output is used
234   // in this graph
235   std::vector<bool> output_used(node->num_outputs(), false);
236   // Copy the set of edges since EdgeSet does not allow modifications
237   // to graph edges during iteration.
238   const EdgeSet& out_edges = node->out_edges();
239   std::vector<const Edge*> edge_vector(out_edges.begin(), out_edges.end());
240   std::sort(edge_vector.begin(), edge_vector.end(), EdgePtrCompare());
241   for (const Edge* edge : edge_vector) {
242     if (edge->IsControlEdge()) {
243       continue;
244     }
245     output_used[edge->src_output()] = true;
246 
247     Node* dst = edge->dst();
248     int dst_input = edge->dst_input();
249     int src_output = edge->src_output();
250     string identity_name =
251         Uniquify(absl::StrCat(node->name(), "_", dst->name()), node_names);
252     Node* identity_node;
253     TF_RETURN_IF_ERROR(add_identity(src_output, identity_name, &identity_node));
254     VLOG(6) << "Adding identity into " << node->name() << ":" << src_output
255             << " -> " << dst->name() << ":" << dst_input << " \n"
256             << identity_node->DebugString();
257 
258     // Make original dst node consume the new identity's output instead of
259     // `node`'s output.
260     TF_RETURN_IF_ERROR(graph->UpdateEdge(identity_node, 0, dst, dst_input));
261   }
262 
263   for (int output_idx = 0; output_idx < node->num_outputs(); ++output_idx) {
264     if (output_used[output_idx]) {
265       continue;
266     }
267     // The output is unused in the graph. Just add an identity
268     // consuming it.
269     string identity_name = Uniquify(node->name(), node_names);
270     Node* identity_node;
271     TF_RETURN_IF_ERROR(add_identity(output_idx, identity_name, &identity_node));
272     VLOG(6) << "Added identity into " << node->name() << ":" << output_idx
273             << " -> <no consumer>: \n"
274             << identity_node->DebugString();
275   }
276   return Status::OK();
277 }
278 
IsolateNode(Node * node,Graph * graph)279 Status IsolateNode(Node* node, Graph* graph) {
280   // We use `node_names` to make sure we pick unique names.
281   // We don't use graph->NewName() because it produces verbose names and
282   // does not actually ensure that they are unique (it assumes all names
283   // are generated using it, which is not true today).
284   std::unordered_set<string> node_names(graph->num_nodes());
285   for (Node* n : graph->nodes()) {
286     node_names.insert(n->name());
287   }
288 
289   for (int i = 0; i < node->num_inputs(); ++i) {
290     TF_RETURN_IF_ERROR(AddInputIdentity(node, i, graph, &node_names));
291   }
292   TF_RETURN_IF_ERROR(AddOutputIdentities(node, graph, &node_names));
293   return Status::OK();
294 }
295 
296 }  // namespace
297 
IsolatePlacerInspectionRequiredOps(const FunctionLibraryDefinition & flib_def,Graph * graph)298 Status IsolatePlacerInspectionRequiredOps(
299     const FunctionLibraryDefinition& flib_def, Graph* graph) {
300   PlacerInspectionRequiredOpChecker checker(graph, &flib_def);
301   // It is OK to add nodes to the graph during iteration.
302   // New nodes will get ids above current ids. The loop
303   // will loop over current nodes only because the op_nodes()
304   // iterator uses node ids to iterate.
305   // Because the new nodes will be higher ids, the caching in
306   // the checker will also work fine as new nodes are added.
307   for (Node* node : graph->op_nodes()) {
308     bool should_be_isolated = false;
309     TF_RETURN_IF_ERROR(
310         checker.IsPlacerInspectionRequired(*node, &should_be_isolated));
311     if (!should_be_isolated) {
312       continue;
313     }
314     TF_RETURN_IF_ERROR(IsolateNode(node, graph));
315   }
316 
317   return Status::OK();
318 }
319 
320 }  // namespace tensorflow
321