1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/graph/subgraph.h"
17
18 #include <algorithm>
19 #include <deque>
20 #include <string>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <vector>
24
25 #include "tensorflow/core/framework/graph.pb.h"
26 #include "tensorflow/core/framework/node_def_util.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/graph/algorithm.h"
29 #include "tensorflow/core/graph/graph.h"
30 #include "tensorflow/core/graph/graph_constructor.h"
31 #include "tensorflow/core/graph/tensor_id.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/lib/core/status.h"
34 #include "tensorflow/core/platform/logging.h"
35
36 namespace tensorflow {
37 namespace subgraph {
38
39 // ----------------------------------------------------------------------------
40 // Subgraph construction-related routines
41 // ----------------------------------------------------------------------------
42 // TODO(vrv): Profile the unordered_set and unordered_map use in this file to
43 // see if we should use an alternative implementation.
44
45 namespace {
46
47 typedef std::unordered_map<StringPiece, Node*, StringPieceHasher> NameIndex;
48
49 // Rewrite graph by replacing the output tensors specified in
50 // "fed_outputs" with special feed nodes for each specified output
51 // tensor, and removing any nodes that are now disconnected from the
52 // part of the graph that reaches the sink node. The set of special
53 // feed nodes added to the graph are returned in "*feed_nodes".
54 //
55 // Return true on success. On error, return false and sets *error to
56 // an appropriate error message (and *g is left in an indeterminate
57 // state).
FeedInputs(Graph * g,const std::vector<std::unique_ptr<PruneRewrite>> & feed_rewrites,NameIndex * name_index,DataTypeVector * out_feed_types)58 Status FeedInputs(
59 Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& feed_rewrites,
60 NameIndex* name_index, DataTypeVector* out_feed_types) {
61 out_feed_types->clear();
62 out_feed_types->reserve(feed_rewrites.size());
63 for (size_t i = 0; i < feed_rewrites.size(); ++i) {
64 const string& t = feed_rewrites[i]->endpoint_name();
65 TensorId id(ParseTensorName(t));
66
67 auto iter = name_index->find(id.first);
68 if (iter == name_index->end()) {
69 return errors::NotFound("FeedInputs: unable to find feed output ", t);
70 }
71 Node* n = iter->second;
72 DCHECK_EQ(n->name(), id.first);
73 if (id.second >= n->num_outputs()) {
74 return errors::InvalidArgument(
75 "FeedInputs: ", t, " should have output index < ", n->num_outputs());
76 }
77
78 Node* feed_node;
79 TF_RETURN_IF_ERROR(
80 feed_rewrites[i]->AddNode(g, {n, id.second}, &feed_node));
81
82 // Update name_index
83 (*name_index)[feed_node->name()] = feed_node;
84 // Duplicate control edges aren't allowed, but feed_node was *just* created
85 // so there's no need to check for a duplicate.
86 g->AddControlEdge(g->source_node(), feed_node, true);
87
88 // Look through edges coming out of "n" for edges whose src_output() index
89 // matches "output_index". If found, replace the edges with a connection
90 // from the special feed node.
91 std::vector<const Edge*> to_remove;
92 for (const Edge* e : n->out_edges()) {
93 if (e->src_output() == id.second) {
94 to_remove.emplace_back(e);
95 } else if (e->src_output() == Graph::kControlSlot &&
96 (n->type_string() == "Placeholder" ||
97 n->type_string() == "PlaceholderV2")) {
98 // When feeding a Placeholder node, any outgoing control edges
99 // will be replaced with a control edge from the replacement
100 // feed_node.
101 // TODO(josh11b,mrry): Come up with a more elegant way of addressing
102 // the general version of this problem.
103 to_remove.emplace_back(e);
104 }
105 }
106
107 for (const Edge* e : to_remove) {
108 if (e->src_output() == id.second) {
109 g->AddEdge(feed_node, 0, e->dst(), e->dst_input());
110 } else {
111 CHECK_EQ(Graph::kControlSlot, e->src_output());
112 // Duplicate control edges aren't allowed, but feed_node was *just*
113 // created so there's no need to check for a duplicate.
114 g->AddControlEdge(feed_node, e->dst(), true);
115 }
116 g->RemoveEdge(e);
117 }
118 out_feed_types->push_back(BaseType(n->output_type(id.second)));
119 }
120 return Status::OK();
121 }
122
FetchOutputs(Graph * g,const std::vector<std::unique_ptr<PruneRewrite>> & fetch_rewrites,NameIndex * name_index,std::vector<Node * > * out_fetch_nodes,DataTypeVector * out_fetch_types)123 Status FetchOutputs(
124 Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& fetch_rewrites,
125 NameIndex* name_index, std::vector<Node*>* out_fetch_nodes,
126 DataTypeVector* out_fetch_types) {
127 out_fetch_nodes->clear();
128 out_fetch_nodes->reserve(fetch_rewrites.size());
129 for (size_t i = 0; i < fetch_rewrites.size(); ++i) {
130 const string& t = fetch_rewrites[i]->endpoint_name();
131
132 // Parse t into node_name and output_index.
133 TensorId id(ParseTensorName(t));
134
135 // Find node in graph with that name.
136 auto iter = name_index->find(id.first);
137 if (iter == name_index->end()) {
138 return errors::NotFound("FetchOutputs node ", t, ": not found");
139 }
140 Node* n = iter->second;
141 DCHECK_EQ(n->name(), id.first);
142 VLOG(2) << "Found fetch node for " << t;
143
144 // Validate output_index
145 if (n->num_outputs() == 0) {
146 return errors::InvalidArgument(
147 "Tried to fetch data for '", t,
148 "', which produces no output. To run to a node but not fetch any "
149 "data, pass '",
150 t,
151 "' as an argument to the 'target_node_names' argument of the "
152 "Session::Run API.");
153 } else if (id.second >= n->num_outputs()) {
154 return errors::InvalidArgument("FetchOutputs ", t,
155 ": output index too large, must be < ",
156 n->num_outputs());
157 }
158
159 // Create the fetch Node and connect it up
160 Node* fetch_node;
161 TF_RETURN_IF_ERROR(
162 fetch_rewrites[i]->AddNode(g, {n, id.second}, &fetch_node));
163
164 // Update the index.
165 (*name_index)[fetch_node->name()] = fetch_node;
166
167 // Duplicate control edges aren't allowed, but fetch_node was *just* created
168 // so there's no need to check for a duplicate.
169 g->AddControlEdge(fetch_node, g->sink_node(), true);
170 out_fetch_nodes->push_back(fetch_node);
171 out_fetch_types->push_back(BaseType(n->output_type(id.second)));
172 }
173
174 return Status::OK();
175 }
176
AddNodeToTargets(const string & node_or_tensor_name,const NameIndex & name_index,std::unordered_set<const Node * > * targets)177 bool AddNodeToTargets(const string& node_or_tensor_name,
178 const NameIndex& name_index,
179 std::unordered_set<const Node*>* targets) {
180 TensorId id = ParseTensorName(node_or_tensor_name);
181 auto iter = name_index.find(id.first);
182 if (iter == name_index.end()) {
183 return false;
184 }
185 const Node* n = iter->second;
186 CHECK_EQ(n->name(), id.first);
187 targets->insert(n);
188 return true;
189 }
190
PruneForTargets(Graph * g,const NameIndex & name_index,const std::vector<Node * > & fetch_nodes,const gtl::ArraySlice<string> & target_nodes)191 Status PruneForTargets(Graph* g, const NameIndex& name_index,
192 const std::vector<Node*>& fetch_nodes,
193 const gtl::ArraySlice<string>& target_nodes) {
194 string not_found;
195 std::unordered_set<const Node*> targets;
196 for (Node* n : fetch_nodes) {
197 if (!AddNodeToTargets(n->name(), name_index, &targets)) {
198 strings::StrAppend(¬_found, n->name(), " ");
199 }
200 }
201 for (const string& s : target_nodes) {
202 if (!AddNodeToTargets(s, name_index, &targets)) {
203 strings::StrAppend(¬_found, s, " ");
204 }
205 }
206 if (!not_found.empty()) {
207 return errors::NotFound("PruneForTargets: Some target nodes not found: ",
208 not_found);
209 }
210 PruneForReverseReachability(g, targets);
211
212 // Reconnect nodes with no outgoing edges to the sink node
213 FixupSourceAndSinkEdges(g);
214
215 return Status::OK();
216 }
217
218 } // namespace
219
AddNode(Graph * g,NodeBuilder::NodeOut feed_tensor,Node ** out_node)220 Status ArgFeedRewrite::AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
221 Node** out_node) {
222 // NOTE(mrry): We must include the index as part of the node
223 // name, because _Arg is a "stateful" kernel and therefore
224 // its name must uniquely identify a kernel instance across all
225 // graphs in the same session.
226 TF_RETURN_IF_ERROR(
227 NodeBuilder(strings::StrCat("_arg_", feed_tensor.node->name(), "_",
228 feed_tensor.index, "_", arg_index_),
229 "_Arg")
230 .Attr("T", BaseType(feed_tensor.node->output_type(feed_tensor.index)))
231 .Attr("index", arg_index_)
232 .Finalize(g, out_node));
233 (*out_node)->set_assigned_device_name(device_info().name());
234 return Status::OK();
235 }
236
AddNode(Graph * g,NodeBuilder::NodeOut feed_tensor,Node ** out_node)237 Status RecvFeedRewrite::AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
238 Node** out_node) {
239 TF_RETURN_IF_ERROR(
240 NodeBuilder(strings::StrCat("_recv_", feed_tensor.node->name(), "_",
241 feed_tensor.index),
242 "_Recv")
243 .Attr("tensor_type",
244 BaseType(feed_tensor.node->output_type(feed_tensor.index)))
245 .Attr("tensor_name", endpoint_name())
246 .Attr("send_device", device_info().name())
247 .Attr("recv_device", device_info().name())
248 .Attr("send_device_incarnation",
249 static_cast<int64>(device_info().incarnation()))
250 .Attr("client_terminated", true)
251 .Finalize(g, out_node));
252
253 (*out_node)->set_assigned_device_name(device_info().name());
254 return Status::OK();
255 }
256
AddNode(Graph * g,NodeBuilder::NodeOut fetch_tensor,Node ** out_node)257 Status RetvalFetchRewrite::AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor,
258 Node** out_node) {
259 // NOTE(mrry): We must include the index as part of the node
260 // name, because _Retval is a "stateful" kernel and therefore
261 // its name must uniquely identify a kernel instance across all
262 // graphs in the same session.
263 TF_RETURN_IF_ERROR(
264 NodeBuilder(strings::StrCat("_retval_", fetch_tensor.node->name(), "_",
265 fetch_tensor.index, "_", retval_index_),
266 "_Retval")
267 .Input(fetch_tensor.node, fetch_tensor.index)
268 .Attr("T",
269 BaseType(fetch_tensor.node->output_type(fetch_tensor.index)))
270 .Attr("index", retval_index_)
271 .Finalize(g, out_node));
272 (*out_node)->set_assigned_device_name(device_info().name());
273 return Status::OK();
274 }
275
AddNode(Graph * g,NodeBuilder::NodeOut fetch_tensor,Node ** out_node)276 Status SendFetchRewrite::AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor,
277 Node** out_node) {
278 TF_RETURN_IF_ERROR(
279 NodeBuilder(strings::StrCat("_send_", fetch_tensor.node->name(), "_",
280 fetch_tensor.index),
281 "_Send")
282 .Input(fetch_tensor.node, fetch_tensor.index)
283 .Attr("tensor_name", endpoint_name())
284 .Attr("send_device", device_info().name())
285 .Attr("recv_device", device_info().name())
286 .Attr("send_device_incarnation",
287 static_cast<int64>(device_info().incarnation()))
288 .Attr("client_terminated", true)
289 .Finalize(g, out_node));
290 (*out_node)->set_assigned_device_name(device_info().name());
291 return Status::OK();
292 }
293
RewriteGraphForExecution(Graph * g,const gtl::ArraySlice<string> & fed_outputs,const gtl::ArraySlice<string> & fetch_outputs,const gtl::ArraySlice<string> & target_node_names,const DeviceAttributes & device_info,bool use_function_convention,RewriteGraphMetadata * out_metadata)294 Status RewriteGraphForExecution(
295 Graph* g, const gtl::ArraySlice<string>& fed_outputs,
296 const gtl::ArraySlice<string>& fetch_outputs,
297 const gtl::ArraySlice<string>& target_node_names,
298 const DeviceAttributes& device_info, bool use_function_convention,
299 RewriteGraphMetadata* out_metadata) {
300 std::vector<std::unique_ptr<PruneRewrite>> feed_rewrites;
301 feed_rewrites.reserve(fed_outputs.size());
302 if (use_function_convention) {
303 for (size_t i = 0; i < fed_outputs.size(); ++i) {
304 feed_rewrites.emplace_back(new ArgFeedRewrite(
305 &fed_outputs[i], &device_info, static_cast<int32>(i)));
306 }
307 } else {
308 for (const string& fed_output : fed_outputs) {
309 feed_rewrites.emplace_back(
310 new RecvFeedRewrite(&fed_output, &device_info));
311 }
312 }
313
314 std::vector<std::unique_ptr<PruneRewrite>> fetch_rewrites;
315 fetch_rewrites.reserve(fetch_outputs.size());
316 if (use_function_convention) {
317 for (size_t i = 0; i < fetch_outputs.size(); ++i) {
318 fetch_rewrites.emplace_back(new RetvalFetchRewrite(
319 &fetch_outputs[i], &device_info, static_cast<int32>(i)));
320 }
321 } else {
322 for (const string& fetch_output : fetch_outputs) {
323 fetch_rewrites.emplace_back(
324 new SendFetchRewrite(&fetch_output, &device_info));
325 }
326 }
327
328 return RewriteGraphForExecution(g, feed_rewrites, fetch_rewrites,
329 target_node_names, out_metadata);
330 }
331
namespace {
// Copies the elements of any iterable container of strings (anything exposing
// begin()/end() over values convertible to string) into a new vector, in
// iteration order. The container is taken by const reference so that it is
// not copied before its elements are.
template <typename StringContainer>
std::vector<string> ConvertToVector(const StringContainer& field) {
  return std::vector<string>(field.begin(), field.end());
}
}  // namespace
338
RewriteGraphForExecution(Graph * g,const std::vector<std::unique_ptr<PruneRewrite>> & feed_rewrites,const std::vector<std::unique_ptr<PruneRewrite>> & fetch_rewrites,const gtl::ArraySlice<string> & target_node_names,RewriteGraphMetadata * out_metadata)339 Status RewriteGraphForExecution(
340 Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& feed_rewrites,
341 const std::vector<std::unique_ptr<PruneRewrite>>& fetch_rewrites,
342 const gtl::ArraySlice<string>& target_node_names,
343 RewriteGraphMetadata* out_metadata) {
344 if (fetch_rewrites.empty() && target_node_names.empty()) {
345 return errors::InvalidArgument(
346 "Must specify at least one target to fetch or execute.");
347 }
348
349 std::unordered_set<string> endpoints;
350 for (const auto& feed_rewrite : feed_rewrites) {
351 auto result = endpoints.insert(feed_rewrite->endpoint_name());
352 if (!result.second) {
353 return errors::InvalidArgument("Endpoint \"",
354 feed_rewrite->endpoint_name(),
355 "\" fed more than once.");
356 }
357 }
358
359 for (const auto& fetch_rewrite : fetch_rewrites) {
360 if (endpoints.count(fetch_rewrite->endpoint_name()) > 0) {
361 return errors::InvalidArgument(fetch_rewrite->endpoint_name(),
362 " is both fed and fetched.");
363 }
364 }
365
366 // A separate index mapping name to Node*, for use by FeedInputs,
367 // FetchOutputs, and PruneForTargets
368 NameIndex name_index;
369 name_index.reserve(g->num_nodes());
370 for (Node* n : g->nodes()) {
371 name_index[n->name()] = n;
372 }
373
374 // Add the feeds. This may replace nodes in the graph, including the nodes
375 // currently listed in "fetch_rewrites". We pass "name_index" so the index is
376 // kept up to date.
377 if (!feed_rewrites.empty()) {
378 TF_RETURN_IF_ERROR(
379 FeedInputs(g, feed_rewrites, &name_index, &out_metadata->feed_types));
380 }
381
382 // Add the fetch nodes, also updating "name_index".
383 std::vector<Node*> fetch_nodes;
384 if (!fetch_rewrites.empty()) {
385 TF_RETURN_IF_ERROR(FetchOutputs(g, fetch_rewrites, &name_index,
386 &fetch_nodes, &out_metadata->fetch_types));
387 }
388
389 // Prune the graph to only compute what is needed for the fetch nodes and the
390 // target nodes.
391 if (!fetch_nodes.empty() || !target_node_names.empty()) {
392 TF_RETURN_IF_ERROR(
393 PruneForTargets(g, name_index, fetch_nodes, target_node_names));
394 }
395
396 return Status::OK();
397 }
398
399 } // namespace subgraph
400
401 } // namespace tensorflow
402