/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h"

#include <unordered_map>
#include <unordered_set>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
namespace {

bool IsFunctionCall(const Node& node) {
  // TODO(iga): Handle non-PCO functions when we add multi-device support
  // to regular function calls. Also, GetFunctionDefAndAttrs assumes that
  // the function name is stored in the `f` attribute of the node. That code
  // will need to change as well.
  const string& op_type = node.op_def().name();
  return op_type == "PartitionedCall" || op_type == "StatefulPartitionedCall";
}

// Utility to set both `node`'s entry in `cache` and `*is_deep` to `value`.
Status Set(const Node& node, bool value, bool* is_deep,
           std::vector<absl::optional<bool>>* cache) {
  *is_deep = value;
  (*cache)[node.id()] = value;
  return Status::OK();
}

}  // namespace

PlacerInspectionRequiredOpChecker::PlacerInspectionRequiredOpChecker(
    const Graph* graph, const FunctionLibraryDefinition* flib_def)
    : graph_(*graph), flib_def_(*flib_def) {
  cache_.resize(graph_.num_node_ids());
}

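// A node requires placer inspection ("is deep") iff it is a
// PartitionedCall/StatefulPartitionedCall whose function signature declares at
// least one DT_RESOURCE output. Results are memoized in `cache_` by node id.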
Status PlacerInspectionRequiredOpChecker::IsPlacerInspectionRequired(
    const Node& node, bool* is_deep) {
  if (cache_[node.id()].has_value()) {
    *is_deep = cache_[node.id()].value();
    return Status::OK();
  }

  if (!IsFunctionCall(node)) {
    return Set(node, false, is_deep, &cache_);
  }
  const FunctionDef* fdef;
  NameAttrList func;
  TF_RETURN_IF_ERROR(GetFunctionDefAndAttrs(flib_def_, node, &fdef, &func));
  DataTypeVector types;
  TF_RETURN_IF_ERROR(
      OutputTypesForNode(AttrSlice(&func.attr()), fdef->signature(), &types));
  for (DataType type : types) {
    if (type == DT_RESOURCE) {
      return Set(node, true, is_deep, &cache_);
    }
  }
  return Set(node, false, is_deep, &cache_);
}

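// Looks up the function named by `node`'s `f` attribute in `flib_def` and
// returns its FunctionDef and attrs; returns an InvalidArgument error if the
// function is not found in the library.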
Status GetFunctionDefAndAttrs(const FunctionLibraryDefinition& flib_def,
                              const Node& node, const FunctionDef** fdef,
                              NameAttrList* func) {
  TF_RETURN_IF_ERROR(GetNodeAttr(node.def(), "f", func));
  const string& function_name = func->name();
  *fdef = flib_def.Find(function_name);
  if (*fdef == nullptr) {
    return errors::InvalidArgument(
        "Failed to find function \"", function_name,
        "\" in function library: ", flib_def.ToProto().DebugString());
  }
  return Status::OK();
}

FunctionStack::FunctionStack(const string& function_name)
    : current_function_name_(function_name) {}

FunctionStack FunctionStack::Push(const Node* node_in_current_function,
                                  const string& new_current_function) const {
  FunctionStack new_stack(new_current_function);
  new_stack.frames_ = frames_;
  new_stack.frames_.emplace_back(current_function_name_,
                                 node_in_current_function);
  return new_stack;
}

bool FunctionStack::HasFunction(const string& function_name) const {
  if (current_function_name_ == function_name) {
    return true;
  }
  for (const Frame& frame : frames_) {
    if (frame.function_name == function_name) {
      return true;
    }
  }
  return false;
}

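// Renders the call stack as one message per hop. For example, for a stack
// with a top-level graph frame and one function frame the output looks like:
//   Graph contains node <n1>
//   Node <n1> calls function <f1>
//   Function <f1> contains node <n2>
//   Node <n2> calls function <f2>
// Node and function names are formatted with FormatNodeForError and
// errors::FormatFunctionForError; messages are joined with "\n  ".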
string FunctionStack::FormatForError() const {
  std::vector<string> msgs;
  for (int i = 0; i < frames_.size(); ++i) {
    if (frames_[i].function_name.empty()) {
      // An empty function name should only happen at the top level, i.e. i = 0.
      // All internal frames should have valid function names.
      msgs.push_back(absl::StrCat("Graph contains node ",
                                  FormatNodeForError(*frames_[i].node)));
    } else {
      msgs.push_back(absl::StrCat(
          "Function ", errors::FormatFunctionForError(frames_[i].function_name),
          " contains node ", FormatNodeForError(*frames_[i].node)));
    }
    const string& fname = (i + 1 < frames_.size())
                              ? frames_[i + 1].function_name
                              : current_function_name_;
    msgs.push_back(absl::StrCat("Node ", FormatNodeForError(*frames_[i].node),
                                " calls function ",
                                errors::FormatFunctionForError(fname)));
  }
  return absl::StrJoin(msgs, "\n  ");
}

namespace {

using OutputEdgeMap = std::vector<std::vector<const Edge*>>;

constexpr char kIdentityOp[] = "Identity";

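// Returns `candidate_name` if it is not already in `node_names`; otherwise
// appends "_0", "_1", ... until a free name is found. The chosen name is
// inserted into `node_names`. E.g., if "a" and "a_0" are taken,
// Uniquify("a", ...) returns "a_1".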
string Uniquify(const string& candidate_name,
                std::unordered_set<string>* node_names) {
  if (node_names->find(candidate_name) == node_names->end()) {
    node_names->insert(candidate_name);
    return candidate_name;
  }

  for (int counter = 0;; ++counter) {
    string candidate = absl::StrCat(candidate_name, "_", counter);
    if (node_names->find(candidate) == node_names->end()) {
      node_names->insert(candidate);
      return candidate;
    }
  }
}

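// Inserts an Identity node on `node`'s `input_idx` input:
//   src:src_output --> node:input_idx
// becomes
//   src:src_output --> identity:0 --> node:input_idx
// The Identity gets the input's dtype, inherits `node`'s debug info, and its
// name is uniquified against `node_names`.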
Status AddInputIdentity(Node* node, int input_idx, Graph* graph,
                        std::unordered_set<string>* node_names) {
  const Edge* edge;
  TF_RETURN_IF_ERROR(node->input_edge(input_idx, &edge));

  string identity_name = Uniquify(
      absl::StrCat(edge->src()->name(), "_", node->name()), node_names);

  NodeDefBuilder builder(identity_name, kIdentityOp);
  builder.Attr("T", node->input_type(input_idx));
  NodeDefBuilder::NodeOut input(edge->src()->name(), edge->src_output(),
                                node->input_type(input_idx));
  builder.Input(input);
  NodeDef identity_def;
  TF_RETURN_IF_ERROR(builder.Finalize(&identity_def));
  MergeDebugInfo(NodeDebugInfo(*node), &identity_def);

  VLOG(6) << "Adding identity into " << edge->src()->name() << ":"
          << edge->src_output() << " -> " << edge->dst()->name() << ":"
          << input_idx << " \n"
          << identity_def.DebugString();

  Status status;
  Node* identity_node = graph->AddNode(identity_def, &status);
  if (!status.ok()) {
    return status;
  }
  graph->AddEdge(edge->src(), edge->src_output(), identity_node, 0);

  // Replace node's `input_idx` input with the new identity's 0'th output.
  TF_RETURN_IF_ERROR(graph->UpdateEdge(identity_node, 0, node, input_idx));

  VLOG(6) << "Successfully inserted identity. Modified node: \n"
          << node->DebugString();
  return Status::OK();
}

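// Orders edges by id; used below to process output edges in a deterministic
// order.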
struct EdgePtrCompare {
  bool operator()(const Edge* lhs, const Edge* rhs) const {
    return lhs->id() < rhs->id();
  }
};

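// Inserts an Identity node on every non-control output edge of `node`:
//   node:i --> dst:j   becomes   node:i --> identity:0 --> dst:j
// Outputs that have no consumers get a dangling Identity of their own, so
// every output ends up consumed by at least one Identity.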
Status AddOutputIdentities(Node* node, Graph* graph,
                           std::unordered_set<string>* node_names) {
  auto add_identity = [&](int src_output, const string& identity_name,
                          Node** identity_node) {
    NodeDefBuilder builder(identity_name, kIdentityOp);
    builder.Attr("T", node->output_type(src_output));
    NodeDefBuilder::NodeOut input(node->name(), src_output,
                                  node->output_type(src_output));
    builder.Input(input);
    NodeDef identity_def;
    TF_RETURN_IF_ERROR(builder.Finalize(&identity_def));
    MergeDebugInfo(NodeDebugInfo(*node), &identity_def);

    Status status;
    *identity_node = graph->AddNode(identity_def, &status);
    if (!status.ok()) {
      return status;
    }
    graph->AddEdge(node, src_output, *identity_node, 0);
    return Status::OK();
  };

  // output_used[i] == true iff `node`'s i'th output is used
  // in this graph
  std::vector<bool> output_used(node->num_outputs(), false);
  // Copy the set of edges since EdgeSet does not allow modifications
  // to graph edges during iteration.
  const EdgeSet& out_edges = node->out_edges();
  std::vector<const Edge*> edge_vector(out_edges.begin(), out_edges.end());
  std::sort(edge_vector.begin(), edge_vector.end(), EdgePtrCompare());
  for (const Edge* edge : edge_vector) {
    if (edge->IsControlEdge()) {
      continue;
    }
    output_used[edge->src_output()] = true;

    Node* dst = edge->dst();
    int dst_input = edge->dst_input();
    int src_output = edge->src_output();
    string identity_name =
        Uniquify(absl::StrCat(node->name(), "_", dst->name()), node_names);
    Node* identity_node;
    TF_RETURN_IF_ERROR(add_identity(src_output, identity_name, &identity_node));
    VLOG(6) << "Adding identity into " << node->name() << ":" << src_output
            << " -> " << dst->name() << ":" << dst_input << " \n"
            << identity_node->DebugString();

    // Make original dst node consume the new identity's output instead of
    // `node`'s output.
    TF_RETURN_IF_ERROR(graph->UpdateEdge(identity_node, 0, dst, dst_input));
  }

  for (int output_idx = 0; output_idx < node->num_outputs(); ++output_idx) {
    if (output_used[output_idx]) {
      continue;
    }
    // The output is unused in the graph. Just add an identity
    // consuming it.
    string identity_name = Uniquify(node->name(), node_names);
    Node* identity_node;
    TF_RETURN_IF_ERROR(add_identity(output_idx, identity_name, &identity_node));
    VLOG(6) << "Added identity into " << node->name() << ":" << output_idx
            << " -> <no consumer>: \n"
            << identity_node->DebugString();
  }
  return Status::OK();
}

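// Surrounds `node` with Identity nodes: one on every input (AddInputIdentity)
// and at least one on every output (AddOutputIdentities).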
Status IsolateNode(Node* node, Graph* graph) {
  // We use `node_names` to make sure we pick unique names.
  // We don't use graph->NewName() because it produces verbose names and
  // does not actually ensure that they are unique (it assumes all names
  // are generated using it, which is not true today).
  std::unordered_set<string> node_names(graph->num_nodes());
  for (Node* n : graph->nodes()) {
    node_names.insert(n->name());
  }

  for (int i = 0; i < node->num_inputs(); ++i) {
    TF_RETURN_IF_ERROR(AddInputIdentity(node, i, graph, &node_names));
  }
  TF_RETURN_IF_ERROR(AddOutputIdentities(node, graph, &node_names));
  return Status::OK();
}

}  // namespace

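// Wraps every op node that requires placer inspection (per
// PlacerInspectionRequiredOpChecker) in Identity nodes, so that each input and
// output of such a call flows through its own Identity.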
Status IsolatePlacerInspectionRequiredOps(
    const FunctionLibraryDefinition& flib_def, Graph* graph) {
  PlacerInspectionRequiredOpChecker checker(graph, &flib_def);
  // It is OK to add nodes to the graph during iteration.
  // New nodes will get ids above the current ids. The loop
  // will visit only the current nodes because the op_nodes()
  // iterator uses node ids to iterate.
  // Because the new nodes will have higher ids, the caching in
  // the checker will also work fine as new nodes are added.
  for (Node* node : graph->op_nodes()) {
    bool should_be_isolated = false;
    TF_RETURN_IF_ERROR(
        checker.IsPlacerInspectionRequired(*node, &should_be_isolated));
    if (!should_be_isolated) {
      continue;
    }
    TF_RETURN_IF_ERROR(IsolateNode(node, graph));
  }

  return Status::OK();
}

}  // namespace tensorflow