1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/function_utils.h"
17
18 #include "tensorflow/core/common_runtime/function_body.h"
19 #include "tensorflow/core/framework/function.h"
20 #include "tensorflow/core/framework/graph.pb.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/node_def_util.h"
23 #include "tensorflow/core/framework/op_def.pb.h"
24 #include "tensorflow/core/framework/versions.pb.h"
25 #include "tensorflow/core/graph/algorithm.h"
26 #include "tensorflow/core/graph/control_flow.h"
27 #include "tensorflow/core/graph/graph.h"
28
29 namespace tensorflow {
30
// Prefix used for the names of every helper node (NoOp / Identity) that the
// Add* routines in this file insert into a graph.
static constexpr char kNodeLabel[] = "Func";
32
33 // Represents the index-th output of a node.
34 struct Endpoint {
35 Node* node;
36 int index;
37
38 // Returns the string name represents this endpoint.
nametensorflow::Endpoint39 string name() const {
40 if (index == 0) {
41 return node->name();
42 } else {
43 return strings::StrCat(node->name(), ":", index);
44 }
45 }
46
dtypetensorflow::Endpoint47 DataType dtype() const { return node->output_type(index); }
48 };
49
50 // The following Add* routines are used to add a few graph nodes while
51 // functions are transformed.
AddNoOp(StringPiece name,Graph * g)52 static Node* AddNoOp(StringPiece name, Graph* g) {
53 NodeDef ndef;
54 ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
55 ndef.set_op("NoOp");
56 Status s;
57 Node* ret = g->AddNode(ndef, &s);
58 TF_CHECK_OK(s);
59 return ret;
60 }
61
AddIdentity(StringPiece name,Graph * g,Endpoint input)62 static Node* AddIdentity(StringPiece name, Graph* g, Endpoint input) {
63 DCHECK_LT(0, input.dtype());
64 NodeDef ndef;
65 ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
66 ndef.set_op("Identity");
67 ndef.add_input(input.name());
68 AddNodeAttr("T", BaseType(input.dtype()), &ndef);
69 Status s;
70 Node* ret = g->AddNode(ndef, &s);
71 TF_CHECK_OK(s);
72 g->AddEdge(input.node, input.index, ret, 0);
73 return ret;
74 }
75
DumpGraph(StringPiece label,const Graph * g)76 void DumpGraph(StringPiece label, const Graph* g) {
77 // TODO(zhifengc): Change Graph to record #nodes.
78 VLOG(2) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges "
79 << g->num_edges();
80 if (VLOG_IS_ON(5)) {
81 for (const auto& line : str_util::Split(DebugString(g), '\n')) {
82 VLOG(5) << "|| " << line;
83 }
84 }
85 }
86
RemoveDeadNodes(Graph * g)87 bool RemoveDeadNodes(Graph* g) {
88 VLOG(2) << "Removing dead nodes";
89 std::unordered_set<const Node*> nodes;
90 for (auto n : g->nodes()) {
91 if (n->IsSource() || n->IsSink() || n->IsControlFlow() ||
92 n->op_def().is_stateful()) {
93 nodes.insert(n);
94 }
95 }
96 return PruneForReverseReachability(g, std::move(nodes));
97 }
98
99 namespace {
100 // If 'edges' contains only 1 non-control edge, returns it. Otherwise,
101 // returns a nullptr.
GetTheOnlyDataEdge(const EdgeSet & edges)102 const Edge* GetTheOnlyDataEdge(const EdgeSet& edges) {
103 const Edge* ret = nullptr;
104 for (const Edge* e : edges) {
105 if (e->IsControlEdge() || ret) {
106 // Don't touch it if there is a control edge.
107 return nullptr;
108 }
109 if (IsRefType(e->src()->output_type(e->src_output()))) {
110 // Don't touch it if the identity node is effectively de-reffing
111 // a ref.
112 return nullptr;
113 }
114 if (IsRecv(e->src()) || IsSwitch(e->src())) {
115 // Don't touch it if the identity is introduced for control flow.
116 // Recv disables all its successors if it receives a dead signal.
117 // When Recv has an outgoing control edge, the current executor
118 // would not disable the destination. The current solution (see
119 // graph_partition.cc) is to add an identity after Recv and change
120 // the control edge to be from this identity node. So the identity
121 // can't be removed.
122 return nullptr;
123 }
124 ret = e;
125 }
126 return ret;
127 }
128 } // end namespace
129
RemoveIdentityNodes(Graph * g)130 bool RemoveIdentityNodes(Graph* g) {
131 VLOG(2) << "Removing identity nodes";
132 bool removed_any = false;
133 gtl::InlinedVector<Node*, 8> matches;
134 for (Node* n : g->nodes()) {
135 if (!n->IsIdentity()) continue;
136 if (!GetTheOnlyDataEdge(n->in_edges())) continue;
137
138 // Some identity nodes are used as sink nodes to give names to output
139 // tensors. These nodes are not going to be executed unless they are in the
140 // fetch set. But if they are in the fetch set we don't want to remove them.
141 if (n->out_edges().empty()) continue;
142
143 matches.push_back(n);
144 }
145 if (!matches.empty()) {
146 for (Node* n : matches) {
147 const Edge* in = GetTheOnlyDataEdge(n->in_edges());
148 for (const Edge* out : n->out_edges()) {
149 if (out->IsControlEdge()) {
150 g->AddControlEdge(in->src(), out->dst());
151 } else {
152 g->AddEdge(in->src(), in->src_output(), out->dst(), out->dst_input());
153 }
154 }
155 VLOG(2) << "Remove Identity: " << n->DebugString();
156 g->RemoveNode(n);
157 removed_any = true;
158 }
159 }
160 return removed_any;
161 }
162
RemoveListArrayConverter(Graph * g)163 bool RemoveListArrayConverter(Graph* g) {
164 VLOG(2) << "Removing list array converter";
165 gtl::InlinedVector<Node*, 8> matches;
166 for (Node* n : g->nodes()) {
167 if ((n->type_string() == "_ListToArray") ||
168 (n->type_string() == "_ArrayToList")) {
169 matches.push_back(n);
170 }
171 }
172 bool removed_any = false;
173 if (!matches.empty()) {
174 for (Node* n : matches) {
175 if (n->num_inputs() != n->num_outputs()) {
176 continue; // Not expected. Skip.
177 }
178 gtl::InlinedVector<Node*, 8> identity_nodes(n->num_inputs(), nullptr);
179
180 const auto no_op = [&](StringPiece name) -> Node* {
181 return AddNoOp(absl::StrCat(n->name(), "/", name), g);
182 };
183
184 const auto identity = [&](StringPiece name, Endpoint input) -> Node* {
185 Node* node = AddIdentity(absl::StrCat(n->name(), "/", name), g, input);
186 node->set_requested_device(input.node->def().device());
187 return node;
188 };
189
190 // Process input edges first.
191 Node* input_control_node = nullptr;
192 for (const Edge* e : n->in_edges()) {
193 if (e->IsControlEdge()) {
194 if (input_control_node == nullptr) {
195 // If node "n" has any control dependencies, adds a no-op
196 // node (input_control_node) which the additional Identity
197 // nodes depends on and the input_control_node depends on
198 // the node "n"s control dependencies.
199 input_control_node = no_op("input_control_node");
200 }
201 g->AddControlEdge(e->src(), input_control_node);
202 } else {
203 const int index = e->dst_input();
204 Node** id_node = &identity_nodes[index];
205 if (*id_node != nullptr) {
206 LOG(ERROR)
207 << "RemoveListArrayConverter unexpected duplicated input: "
208 << e->dst_input();
209 return removed_any;
210 }
211 *id_node = identity("input", {e->src(), e->src_output()});
212 }
213 }
214
215 // If node "n" has any control dependencies, the added identity
216 // nodes should have control dependencies on input_control_node.
217 if (input_control_node != nullptr) {
218 for (Node* id : identity_nodes) {
219 g->AddControlEdge(input_control_node, id);
220 }
221 }
222
223 Node* output_control_node = nullptr;
224 for (const Edge* e : n->out_edges()) {
225 if (e->IsControlEdge()) {
226 if (output_control_node == nullptr) {
227 // If node "n" is control-depended upon by other nodes,
228 // adds a no-op node (output_control_node) which those
229 // nodes will depend on and output_control_node depends on
230 // all Identity nodes.
231 output_control_node = no_op("output_control_node");
232 }
233 g->AddControlEdge(output_control_node, e->dst());
234 } else {
235 Node* id_node = identity_nodes[e->src_output()];
236 if (id_node == nullptr) {
237 LOG(ERROR) << "RemoveListArrayConverter unexpected missing input: "
238 << e->src_output();
239 return removed_any;
240 }
241 CHECK(id_node);
242 g->AddEdge(id_node, 0, e->dst(), e->dst_input());
243 }
244 }
245
246 // If any nodes have control dependencies on node "n", those
247 // nodes should have control dependencies on
248 // output_control_node.
249 if (output_control_node != nullptr) {
250 for (Node* id : identity_nodes) {
251 g->AddControlEdge(id, output_control_node);
252 }
253 }
254
255 g->RemoveNode(n);
256 removed_any = true;
257 }
258 }
259 return removed_any;
260 }
261
NameAndAttrsFromFunctionCall(const NodeDef & call_def,NameAttrList * function)262 Status NameAndAttrsFromFunctionCall(const NodeDef& call_def,
263 NameAttrList* function) {
264 if (call_def.op() == "PartitionedCall" ||
265 call_def.op() == "StatefulPartitionedCall") {
266 TF_RETURN_IF_ERROR(GetNodeAttr(call_def, "f", function));
267 } else {
268 function->set_name(call_def.op());
269 *function->mutable_attr() = call_def.attr();
270 }
271 return Status::OK();
272 }
273
InstantiateFunctionCall(const NodeDef & call_def,FunctionLibraryRuntime * flr,FunctionLibraryRuntime::Handle * handle)274 Status InstantiateFunctionCall(const NodeDef& call_def,
275 FunctionLibraryRuntime* flr,
276 FunctionLibraryRuntime::Handle* handle) {
277 NameAttrList function;
278 TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(call_def, &function));
279 return flr->Instantiate(function.name(), AttrSlice(&function.attr()), handle);
280 }
281
IsFunctionCall(const FunctionLibraryDefinition & lib_def,const Node & node)282 bool IsFunctionCall(const FunctionLibraryDefinition& lib_def,
283 const Node& node) {
284 return node.IsFunctionCall();
285 }
286
NewName(const Node * n,bool pretty)287 string NewName(const Node* n, bool pretty) {
288 if (pretty) {
289 return strings::StrCat(n->type_string(), n->id());
290 } else {
291 return strings::StrCat("n", n->id());
292 }
293 }
294
295 // TODO(zhifengc): Maybe this should be the default Graph::AsGraphDef.
296 // and stash the original NodeDef name as an attr for documentation
297 // purpose.
ToGraphDef(const Graph * g,GraphDef * gdef,bool pretty)298 void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) {
299 // We visit nodes in forward topological sort order, which is a
300 // possible execution order of the graph.
301 gtl::InlinedVector<const Edge*, 4> inputs;
302 gdef->Clear();
303 *gdef->mutable_versions() = g->versions();
304
305 std::vector<Node*> start_nodes;
306 for (Node* n : g->nodes()) {
307 if (n->out_edges().empty()) {
308 start_nodes.push_back(n);
309 }
310 }
311
312 ReverseDFSFrom(*g, start_nodes, nullptr, [gdef, pretty, &inputs](Node* n) {
313 if (!n->IsOp()) return;
314 NodeDef* ndef = gdef->add_node();
315 ndef->set_name(NewName(n, pretty));
316 ndef->set_op(n->type_string());
317 for (const auto& attr : n->attrs()) {
318 (*ndef->mutable_attr())[attr.first] = attr.second;
319 }
320
321 if (!n->assigned_device_name().empty()) {
322 ndef->set_device(n->assigned_device_name());
323 } else {
324 ndef->set_device(n->requested_device());
325 }
326
327 inputs.clear();
328 inputs.resize(n->num_inputs());
329 for (const Edge* e : n->in_edges()) {
330 if (e->IsControlEdge()) {
331 inputs.push_back(e);
332 } else {
333 if (inputs[e->dst_input()] == nullptr) {
334 inputs[e->dst_input()] = e;
335 } else {
336 LOG(WARNING) << "Malformed graph node. multiple input edges: "
337 << n->DebugString();
338 }
339 }
340 }
341 // node->name() is merely NodeDef::name, which are not guaranteed
342 // to be unique and stable after optimization rewrites. Therefore,
343 // we use "n<node id>" instead.
344 for (const Edge* e : inputs) {
345 if (e == nullptr) {
346 ndef->add_input("unknown");
347 continue;
348 }
349 const string srcname = NewName(e->src(), pretty);
350 if (!e->src()->IsOp()) {
351 } else if (e->IsControlEdge()) {
352 ndef->add_input(strings::StrCat("^", srcname));
353 } else if (e->src_output() == 0) {
354 ndef->add_input(srcname);
355 } else {
356 ndef->add_input(strings::StrCat(srcname, ":", e->src_output()));
357 }
358 }
359 });
360 }
361
DebugString(const Graph * g)362 string DebugString(const Graph* g) {
363 GraphDef gdef;
364 ToGraphDef(g, &gdef);
365 return DebugString(gdef);
366 }
367
368 } // end namespace tensorflow
369