• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/function_utils.h"
17 
18 #include "tensorflow/core/common_runtime/function_body.h"
19 #include "tensorflow/core/framework/function.h"
20 #include "tensorflow/core/framework/graph.pb.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/node_def_util.h"
23 #include "tensorflow/core/framework/op_def.pb.h"
24 #include "tensorflow/core/framework/versions.pb.h"
25 #include "tensorflow/core/graph/algorithm.h"
26 #include "tensorflow/core/graph/control_flow.h"
27 #include "tensorflow/core/graph/graph.h"
28 
29 namespace tensorflow {
30 
// Name prefix for every node created by the Add* helpers in this file.
static constexpr char kNodeLabel[] = "Func";
32 
33 // Represents the index-th output of a node.
34 struct Endpoint {
35   Node* node;
36   int index;
37 
38   // Returns the string name represents this endpoint.
nametensorflow::Endpoint39   string name() const {
40     if (index == 0) {
41       return node->name();
42     } else {
43       return strings::StrCat(node->name(), ":", index);
44     }
45   }
46 
dtypetensorflow::Endpoint47   DataType dtype() const { return node->output_type(index); }
48 };
49 
50 // The following Add* routines are used to add a few graph nodes while
51 // functions are transformed.
AddNoOp(StringPiece name,Graph * g)52 static Node* AddNoOp(StringPiece name, Graph* g) {
53   NodeDef ndef;
54   ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
55   ndef.set_op("NoOp");
56   Status s;
57   Node* ret = g->AddNode(ndef, &s);
58   TF_CHECK_OK(s);
59   return ret;
60 }
61 
AddIdentity(StringPiece name,Graph * g,Endpoint input)62 static Node* AddIdentity(StringPiece name, Graph* g, Endpoint input) {
63   DCHECK_LT(0, input.dtype());
64   NodeDef ndef;
65   ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
66   ndef.set_op("Identity");
67   ndef.add_input(input.name());
68   AddNodeAttr("T", BaseType(input.dtype()), &ndef);
69   Status s;
70   Node* ret = g->AddNode(ndef, &s);
71   TF_CHECK_OK(s);
72   g->AddEdge(input.node, input.index, ret, 0);
73   return ret;
74 }
75 
DumpGraph(StringPiece label,const Graph * g)76 void DumpGraph(StringPiece label, const Graph* g) {
77   // TODO(zhifengc): Change Graph to record #nodes.
78   VLOG(2) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges "
79           << g->num_edges();
80   if (VLOG_IS_ON(5)) {
81     for (const auto& line : str_util::Split(DebugString(g), '\n')) {
82       VLOG(5) << "|| " << line;
83     }
84   }
85 }
86 
RemoveDeadNodes(Graph * g)87 bool RemoveDeadNodes(Graph* g) {
88   VLOG(2) << "Removing dead nodes";
89   std::unordered_set<const Node*> nodes;
90   for (auto n : g->nodes()) {
91     if (n->IsSource() || n->IsSink() || n->IsControlFlow() ||
92         n->op_def().is_stateful()) {
93       nodes.insert(n);
94     }
95   }
96   return PruneForReverseReachability(g, std::move(nodes));
97 }
98 
99 namespace {
100 // If 'edges' contains only 1 non-control edge, returns it. Otherwise,
101 // returns a nullptr.
GetTheOnlyDataEdge(const EdgeSet & edges)102 const Edge* GetTheOnlyDataEdge(const EdgeSet& edges) {
103   const Edge* ret = nullptr;
104   for (const Edge* e : edges) {
105     if (e->IsControlEdge() || ret) {
106       // Don't touch it if there is a control edge.
107       return nullptr;
108     }
109     if (IsRefType(e->src()->output_type(e->src_output()))) {
110       // Don't touch it if the identity node is effectively de-reffing
111       // a ref.
112       return nullptr;
113     }
114     if (IsRecv(e->src()) || IsSwitch(e->src())) {
115       // Don't touch it if the identity is introduced for control flow.
116       // Recv disables all its successors if it receives a dead signal.
117       // When Recv has an outgoing control edge, the current executor
118       // would not disable the destination. The current solution (see
119       // graph_partition.cc) is to add an identity after Recv and change
120       // the control edge to be from this identity node. So the identity
121       // can't be removed.
122       return nullptr;
123     }
124     ret = e;
125   }
126   return ret;
127 }
128 }  // end namespace
129 
RemoveIdentityNodes(Graph * g)130 bool RemoveIdentityNodes(Graph* g) {
131   VLOG(2) << "Removing identity nodes";
132   bool removed_any = false;
133   gtl::InlinedVector<Node*, 8> matches;
134   for (Node* n : g->nodes()) {
135     if (!n->IsIdentity()) continue;
136     if (!GetTheOnlyDataEdge(n->in_edges())) continue;
137 
138     // Some identity nodes are used as sink nodes to give names to output
139     // tensors. These nodes are not going to be executed unless they are in the
140     // fetch set. But if they are in the fetch set we don't want to remove them.
141     if (n->out_edges().empty()) continue;
142 
143     matches.push_back(n);
144   }
145   if (!matches.empty()) {
146     for (Node* n : matches) {
147       const Edge* in = GetTheOnlyDataEdge(n->in_edges());
148       for (const Edge* out : n->out_edges()) {
149         if (out->IsControlEdge()) {
150           g->AddControlEdge(in->src(), out->dst());
151         } else {
152           g->AddEdge(in->src(), in->src_output(), out->dst(), out->dst_input());
153         }
154       }
155       VLOG(2) << "Remove Identity: " << n->DebugString();
156       g->RemoveNode(n);
157       removed_any = true;
158     }
159   }
160   return removed_any;
161 }
162 
RemoveListArrayConverter(Graph * g)163 bool RemoveListArrayConverter(Graph* g) {
164   VLOG(2) << "Removing list array converter";
165   gtl::InlinedVector<Node*, 8> matches;
166   for (Node* n : g->nodes()) {
167     if ((n->type_string() == "_ListToArray") ||
168         (n->type_string() == "_ArrayToList")) {
169       matches.push_back(n);
170     }
171   }
172   bool removed_any = false;
173   if (!matches.empty()) {
174     for (Node* n : matches) {
175       if (n->num_inputs() != n->num_outputs()) {
176         continue;  // Not expected. Skip.
177       }
178       gtl::InlinedVector<Node*, 8> identity_nodes(n->num_inputs(), nullptr);
179 
180       const auto no_op = [&](StringPiece name) -> Node* {
181         return AddNoOp(absl::StrCat(n->name(), "/", name), g);
182       };
183 
184       const auto identity = [&](StringPiece name, Endpoint input) -> Node* {
185         Node* node = AddIdentity(absl::StrCat(n->name(), "/", name), g, input);
186         node->set_requested_device(input.node->def().device());
187         return node;
188       };
189 
190       // Process input edges first.
191       Node* input_control_node = nullptr;
192       for (const Edge* e : n->in_edges()) {
193         if (e->IsControlEdge()) {
194           if (input_control_node == nullptr) {
195             // If node "n" has any control dependencies, adds a no-op
196             // node (input_control_node) which the additional Identity
197             // nodes depends on and the input_control_node depends on
198             // the node "n"s control dependencies.
199             input_control_node = no_op("input_control_node");
200           }
201           g->AddControlEdge(e->src(), input_control_node);
202         } else {
203           const int index = e->dst_input();
204           Node** id_node = &identity_nodes[index];
205           if (*id_node != nullptr) {
206             LOG(ERROR)
207                 << "RemoveListArrayConverter unexpected duplicated input: "
208                 << e->dst_input();
209             return removed_any;
210           }
211           *id_node = identity("input", {e->src(), e->src_output()});
212         }
213       }
214 
215       // If node "n" has any control dependencies, the added identity
216       // nodes should have control dependencies on input_control_node.
217       if (input_control_node != nullptr) {
218         for (Node* id : identity_nodes) {
219           g->AddControlEdge(input_control_node, id);
220         }
221       }
222 
223       Node* output_control_node = nullptr;
224       for (const Edge* e : n->out_edges()) {
225         if (e->IsControlEdge()) {
226           if (output_control_node == nullptr) {
227             // If node "n" is control-depended upon by other nodes,
228             // adds a no-op node (output_control_node) which those
229             // nodes will depend on and output_control_node depends on
230             // all Identity nodes.
231             output_control_node = no_op("output_control_node");
232           }
233           g->AddControlEdge(output_control_node, e->dst());
234         } else {
235           Node* id_node = identity_nodes[e->src_output()];
236           if (id_node == nullptr) {
237             LOG(ERROR) << "RemoveListArrayConverter unexpected missing input: "
238                        << e->src_output();
239             return removed_any;
240           }
241           CHECK(id_node);
242           g->AddEdge(id_node, 0, e->dst(), e->dst_input());
243         }
244       }
245 
246       // If any nodes have control dependencies on node "n", those
247       // nodes should have control dependencies on
248       // output_control_node.
249       if (output_control_node != nullptr) {
250         for (Node* id : identity_nodes) {
251           g->AddControlEdge(id, output_control_node);
252         }
253       }
254 
255       g->RemoveNode(n);
256       removed_any = true;
257     }
258   }
259   return removed_any;
260 }
261 
NameAndAttrsFromFunctionCall(const NodeDef & call_def,NameAttrList * function)262 Status NameAndAttrsFromFunctionCall(const NodeDef& call_def,
263                                     NameAttrList* function) {
264   if (call_def.op() == "PartitionedCall" ||
265       call_def.op() == "StatefulPartitionedCall") {
266     TF_RETURN_IF_ERROR(GetNodeAttr(call_def, "f", function));
267   } else {
268     function->set_name(call_def.op());
269     *function->mutable_attr() = call_def.attr();
270   }
271   return Status::OK();
272 }
273 
InstantiateFunctionCall(const NodeDef & call_def,FunctionLibraryRuntime * flr,FunctionLibraryRuntime::Handle * handle)274 Status InstantiateFunctionCall(const NodeDef& call_def,
275                                FunctionLibraryRuntime* flr,
276                                FunctionLibraryRuntime::Handle* handle) {
277   NameAttrList function;
278   TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(call_def, &function));
279   return flr->Instantiate(function.name(), AttrSlice(&function.attr()), handle);
280 }
281 
IsFunctionCall(const FunctionLibraryDefinition & lib_def,const Node & node)282 bool IsFunctionCall(const FunctionLibraryDefinition& lib_def,
283                     const Node& node) {
284   return node.IsFunctionCall();
285 }
286 
NewName(const Node * n,bool pretty)287 string NewName(const Node* n, bool pretty) {
288   if (pretty) {
289     return strings::StrCat(n->type_string(), n->id());
290   } else {
291     return strings::StrCat("n", n->id());
292   }
293 }
294 
295 // TODO(zhifengc): Maybe this should be the default Graph::AsGraphDef.
296 // and stash the original NodeDef name as an attr for documentation
297 // purpose.
ToGraphDef(const Graph * g,GraphDef * gdef,bool pretty)298 void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) {
299   // We visit nodes in forward topological sort order, which is a
300   // possible execution order of the graph.
301   gtl::InlinedVector<const Edge*, 4> inputs;
302   gdef->Clear();
303   *gdef->mutable_versions() = g->versions();
304 
305   std::vector<Node*> start_nodes;
306   for (Node* n : g->nodes()) {
307     if (n->out_edges().empty()) {
308       start_nodes.push_back(n);
309     }
310   }
311 
312   ReverseDFSFrom(*g, start_nodes, nullptr, [gdef, pretty, &inputs](Node* n) {
313     if (!n->IsOp()) return;
314     NodeDef* ndef = gdef->add_node();
315     ndef->set_name(NewName(n, pretty));
316     ndef->set_op(n->type_string());
317     for (const auto& attr : n->attrs()) {
318       (*ndef->mutable_attr())[attr.first] = attr.second;
319     }
320 
321     if (!n->assigned_device_name().empty()) {
322       ndef->set_device(n->assigned_device_name());
323     } else {
324       ndef->set_device(n->requested_device());
325     }
326 
327     inputs.clear();
328     inputs.resize(n->num_inputs());
329     for (const Edge* e : n->in_edges()) {
330       if (e->IsControlEdge()) {
331         inputs.push_back(e);
332       } else {
333         if (inputs[e->dst_input()] == nullptr) {
334           inputs[e->dst_input()] = e;
335         } else {
336           LOG(WARNING) << "Malformed graph node. multiple input edges: "
337                        << n->DebugString();
338         }
339       }
340     }
341     // node->name() is merely NodeDef::name, which are not guaranteed
342     // to be unique and stable after optimization rewrites. Therefore,
343     // we use "n<node id>" instead.
344     for (const Edge* e : inputs) {
345       if (e == nullptr) {
346         ndef->add_input("unknown");
347         continue;
348       }
349       const string srcname = NewName(e->src(), pretty);
350       if (!e->src()->IsOp()) {
351       } else if (e->IsControlEdge()) {
352         ndef->add_input(strings::StrCat("^", srcname));
353       } else if (e->src_output() == 0) {
354         ndef->add_input(srcname);
355       } else {
356         ndef->add_input(strings::StrCat(srcname, ":", e->src_output()));
357       }
358     }
359   });
360 }
361 
DebugString(const Graph * g)362 string DebugString(const Graph* g) {
363   GraphDef gdef;
364   ToGraphDef(g, &gdef);
365   return DebugString(gdef);
366 }
367 
368 }  // end namespace tensorflow
369