• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h"
16 
17 #include <unordered_map>
18 #include <unordered_set>
19 
20 #include "absl/strings/str_cat.h"
21 #include "absl/types/optional.h"
22 #include "tensorflow/core/common_runtime/function.h"
23 #include "tensorflow/core/framework/function.h"
24 #include "tensorflow/core/framework/node_def_builder.h"
25 #include "tensorflow/core/graph/graph.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/status.h"
28 
29 namespace tensorflow {
30 namespace {
31 
IsFunctionCall(const Node & node)32 bool IsFunctionCall(const Node& node) {
33   // TODO(iga): Handle non-PCO functions when we add multi-device support
34   // to regular function calls. Also, the GetFunctionDefAndAttrs assumes that
35   // the function name is stored in the `f` attribute of the node. That code
36   // will need to change as well.
37   const string& op_type = node.op_def().name();
38   return op_type == "PartitionedCall" || op_type == "StatefulPartitionedCall";
39 }
40 
41 // Utility to set node's value in `cache` and `is_deep` to `value`.
Set(const Node & node,bool value,bool * is_deep,std::vector<absl::optional<bool>> * cache)42 Status Set(const Node& node, bool value, bool* is_deep,
43            std::vector<absl::optional<bool>>* cache) {
44   *is_deep = value;
45   (*cache)[node.id()] = value;
46   return Status::OK();
47 }
48 
49 }  // namespace
50 
PlacerInspectionRequiredOpChecker(const Graph * graph)51 PlacerInspectionRequiredOpChecker::PlacerInspectionRequiredOpChecker(
52     const Graph* graph)
53     : PlacerInspectionRequiredOpChecker(graph, &graph->flib_def()) {}
54 
PlacerInspectionRequiredOpChecker(const Graph * graph,const FunctionLibraryDefinition * flib_def)55 PlacerInspectionRequiredOpChecker::PlacerInspectionRequiredOpChecker(
56     const Graph* graph, const FunctionLibraryDefinition* flib_def)
57     : graph_(*graph), flib_def_(*flib_def) {
58   cache_.resize(graph_.num_node_ids());
59 }
60 
IsPlacerInspectionRequired(const Node & node,bool * is_deep)61 Status PlacerInspectionRequiredOpChecker::IsPlacerInspectionRequired(
62     const Node& node, bool* is_deep) {
63   if (cache_[node.id()].has_value()) {
64     *is_deep = cache_[node.id()].value();
65     return Status::OK();
66   }
67 
68   if (!IsFunctionCall(node)) {
69     return Set(node, false, is_deep, &cache_);
70   }
71   const FunctionDef* fdef;
72   NameAttrList func;
73   TF_RETURN_IF_ERROR(GetFunctionDefAndAttrs(flib_def_, node, &fdef, &func));
74   DataTypeVector types;
75   TF_RETURN_IF_ERROR(
76       OutputTypesForNode(AttrSlice(&func.attr()), fdef->signature(), &types));
77   for (DataType type : types) {
78     if (type == DT_RESOURCE) {
79       return Set(node, true, is_deep, &cache_);
80     }
81   }
82   return Set(node, false, is_deep, &cache_);
83 }
84 
GetFunctionDefAndAttrs(const FunctionLibraryDefinition & flib_def,const Node & node,const FunctionDef ** fdef,NameAttrList * func)85 Status GetFunctionDefAndAttrs(const FunctionLibraryDefinition& flib_def,
86                               const Node& node, const FunctionDef** fdef,
87                               NameAttrList* func) {
88   TF_RETURN_IF_ERROR(GetNodeAttr(node.def(), "f", func));
89   const string& function_name = func->name();
90   *fdef = flib_def.Find(function_name);
91   if (*fdef == nullptr) {
92     return errors::InvalidArgument(
93         "Failed to find function \"", function_name,
94         "\" in function library: ", flib_def.ToProto().DebugString());
95   }
96   return Status::OK();
97 }
98 
FunctionStack(const string & function_name)99 FunctionStack::FunctionStack(const string& function_name)
100     : current_function_name_(function_name) {}
101 
Push(const Node * node_in_current_function,const string & new_current_function) const102 FunctionStack FunctionStack::Push(const Node* node_in_current_function,
103                                   const string& new_current_function) const {
104   FunctionStack new_stack(new_current_function);
105   new_stack.frames_ = frames_;
106   new_stack.frames_.emplace_back(current_function_name_,
107                                  node_in_current_function);
108   return new_stack;
109 }
110 
HasFunction(const string & function_name) const111 bool FunctionStack::HasFunction(const string& function_name) const {
112   if (current_function_name_ == function_name) {
113     return true;
114   }
115   for (const Frame& frame : frames_) {
116     if (frame.function_name == function_name) {
117       return true;
118     }
119   }
120   return false;
121 }
122 
FormatForError() const123 string FunctionStack::FormatForError() const {
124   std::vector<string> msgs;
125   for (int i = 0; i < frames_.size(); ++i) {
126     if (frames_[i].function_name.empty()) {
127       // Empty function body should only happen at the top level, i.e. i = 0.
128       // All internal frames should have valid function names.
129       msgs.push_back(absl::StrCat("Graph contains node ",
130                                   FormatNodeForError(*frames_[i].node)));
131 
132     } else {
133       msgs.push_back(absl::StrCat(
134           "Function ", errors::FormatFunctionForError(frames_[i].function_name),
135           " contains node ", FormatNodeForError(*frames_[i].node)));
136     }
137     const string& fname = (i + 1 < frames_.size())
138                               ? frames_[i + 1].function_name
139                               : current_function_name_;
140     msgs.push_back(absl::StrCat("Node ", FormatNodeForError(*frames_[i].node),
141                                 " calls function ",
142                                 errors::FormatFunctionForError(fname)));
143   }
144   return absl::StrJoin(msgs, "\n  ");
145 }
146 
147 namespace {
148 
// Op type of the pass-through nodes inserted around isolated ops.
// (The former `OutputEdgeMap` alias was never used in this file and has been
// removed.)
constexpr char kIdentityOp[] = "Identity";
152 
// Returns `candidate_name` if it is not yet in `*node_names`. Otherwise
// returns the first of "<candidate_name>_0", "<candidate_name>_1", ... that
// is not in the set. The returned name is inserted into `*node_names`, so
// later calls will not hand it out again.
string Uniquify(const string& candidate_name,
                std::unordered_set<string>* node_names) {
  if (node_names->insert(candidate_name).second) {
    return candidate_name;
  }
  int suffix = 0;
  while (true) {
    string attempt = absl::StrCat(candidate_name, "_", suffix);
    if (node_names->insert(attempt).second) {
      return attempt;
    }
    ++suffix;
  }
}
168 
AddInputIdentity(Node * node,int input_idx,Graph * graph,std::unordered_set<string> * node_names)169 Status AddInputIdentity(Node* node, int input_idx, Graph* graph,
170                         std::unordered_set<string>* node_names) {
171   const Edge* edge;
172   TF_RETURN_IF_ERROR(node->input_edge(input_idx, &edge));
173 
174   string identity_name = Uniquify(
175       absl::StrCat(edge->src()->name(), "_", node->name()), node_names);
176 
177   NodeDefBuilder builder(identity_name, kIdentityOp);
178   builder.Attr("T", node->input_type(input_idx));
179   NodeDefBuilder::NodeOut input(edge->src()->name(), edge->src_output(),
180                                 node->input_type(input_idx));
181   builder.Input(input);
182   NodeDef identity_def;
183   TF_RETURN_IF_ERROR(builder.Finalize(&identity_def));
184   MergeDebugInfo(NodeDebugInfo(*node), &identity_def);
185 
186   VLOG(6) << "Adding identity into " << edge->src()->name() << ":"
187           << edge->src_output() << " -> " << edge->dst()->name() << ":"
188           << input_idx << " \n"
189           << identity_def.DebugString();
190 
191   Status status;
192   Node* identity_node = graph->AddNode(identity_def, &status);
193   if (!status.ok()) {
194     return status;
195   }
196   graph->AddEdge(edge->src(), edge->src_output(), identity_node, 0);
197 
198   // Replace node's `input_idx` input with the new identity's 0'th output
199   TF_RETURN_IF_ERROR(graph->UpdateEdge(identity_node, 0, node, input_idx));
200 
201   VLOG(6) << "Successfully inserted identity. Modified node: \n"
202           << node->DebugString();
203   return Status::OK();
204 }
205 
206 struct EdgePtrCompare {
operator ()tensorflow::__anon1096285a0211::EdgePtrCompare207   bool operator()(const Edge* lhs, const Edge* rhs) const {
208     return lhs->id() < rhs->id();
209   }
210 };
211 
AddOutputIdentities(Node * node,Graph * graph,std::unordered_set<string> * node_names)212 Status AddOutputIdentities(Node* node, Graph* graph,
213                            std::unordered_set<string>* node_names) {
214   auto add_identity = [&](int src_output, const string& identity_name,
215                           Node** identity_node) {
216     NodeDefBuilder builder(identity_name, kIdentityOp);
217     builder.Attr("T", node->output_type(src_output));
218     NodeDefBuilder::NodeOut input(node->name(), src_output,
219                                   node->output_type(src_output));
220     builder.Input(input);
221     NodeDef identity_def;
222     TF_RETURN_IF_ERROR(builder.Finalize(&identity_def));
223     MergeDebugInfo(NodeDebugInfo(*node), &identity_def);
224 
225     Status status;
226     *identity_node = graph->AddNode(identity_def, &status);
227     if (!status.ok()) {
228       return status;
229     }
230     graph->AddEdge(node, src_output, *identity_node, 0);
231     return Status::OK();
232   };
233 
234   // output_used[i] == true iff `node`'s i'th output is used
235   // in this graph
236   std::vector<bool> output_used(node->num_outputs(), false);
237   // Copy the set of edges since EdgeSet does not allow modifications
238   // to graph edges during iteration.
239   const EdgeSet& out_edges = node->out_edges();
240   std::vector<const Edge*> edge_vector(out_edges.begin(), out_edges.end());
241   std::sort(edge_vector.begin(), edge_vector.end(), EdgePtrCompare());
242   for (const Edge* edge : edge_vector) {
243     if (edge->IsControlEdge()) {
244       continue;
245     }
246     output_used[edge->src_output()] = true;
247 
248     Node* dst = edge->dst();
249     int dst_input = edge->dst_input();
250     int src_output = edge->src_output();
251     string identity_name =
252         Uniquify(absl::StrCat(node->name(), "_", dst->name()), node_names);
253     Node* identity_node;
254     TF_RETURN_IF_ERROR(add_identity(src_output, identity_name, &identity_node));
255     VLOG(6) << "Adding identity into " << node->name() << ":" << src_output
256             << " -> " << dst->name() << ":" << dst_input << " \n"
257             << identity_node->DebugString();
258 
259     // Make original dst node consume the new identity's output instead of
260     // `node`'s output.
261     TF_RETURN_IF_ERROR(graph->UpdateEdge(identity_node, 0, dst, dst_input));
262   }
263 
264   for (int output_idx = 0; output_idx < node->num_outputs(); ++output_idx) {
265     if (output_used[output_idx]) {
266       continue;
267     }
268     // The output is unused in the graph. Just add an identity
269     // consuming it.
270     string identity_name = Uniquify(node->name(), node_names);
271     Node* identity_node;
272     TF_RETURN_IF_ERROR(add_identity(output_idx, identity_name, &identity_node));
273     VLOG(6) << "Added identity into " << node->name() << ":" << output_idx
274             << " -> <no consumer>: \n"
275             << identity_node->DebugString();
276   }
277   return Status::OK();
278 }
279 
IsolateNode(Node * node,Graph * graph)280 Status IsolateNode(Node* node, Graph* graph) {
281   // We use `node_names` to make sure we pick unique names.
282   // We don't use graph->NewName() because it produces verbose names and
283   // does not actually ensure that they are unique (it assumes all names
284   // are generated using it, which is not true today).
285   std::unordered_set<string> node_names(graph->num_nodes());
286   for (Node* n : graph->nodes()) {
287     node_names.insert(n->name());
288   }
289 
290   for (int i = 0; i < node->num_inputs(); ++i) {
291     TF_RETURN_IF_ERROR(AddInputIdentity(node, i, graph, &node_names));
292   }
293   TF_RETURN_IF_ERROR(AddOutputIdentities(node, graph, &node_names));
294   return Status::OK();
295 }
296 
297 }  // namespace
298 
IsolatePlacerInspectionRequiredOps(const FunctionLibraryDefinition & flib_def,Graph * graph)299 Status IsolatePlacerInspectionRequiredOps(
300     const FunctionLibraryDefinition& flib_def, Graph* graph) {
301   PlacerInspectionRequiredOpChecker checker(graph, &flib_def);
302   // It is OK to add nodes to the graph during iteration.
303   // New nodes will get ids above current ids. The loop
304   // will loop over current nodes only because the op_nodes()
305   // iterator uses node ids to iterate.
306   // Because the new nodes will be higher ids, the caching in
307   // the checker will also work fine as new nodes are added.
308   for (Node* node : graph->op_nodes()) {
309     bool should_be_isolated = false;
310     TF_RETURN_IF_ERROR(
311         checker.IsPlacerInspectionRequired(*node, &should_be_isolated));
312     if (!should_be_isolated) {
313       continue;
314     }
315     TF_RETURN_IF_ERROR(IsolateNode(node, graph));
316   }
317 
318   return Status::OK();
319 }
320 
321 }  // namespace tensorflow
322