• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/subgraph.h"
17 
18 #include <algorithm>
19 #include <deque>
20 #include <string>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <vector>
24 
25 #include "tensorflow/core/framework/graph.pb.h"
26 #include "tensorflow/core/framework/node_def_util.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/graph/algorithm.h"
29 #include "tensorflow/core/graph/graph.h"
30 #include "tensorflow/core/graph/tensor_id.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/status.h"
33 #include "tensorflow/core/platform/logging.h"
34 
35 namespace tensorflow {
36 namespace subgraph {
37 
38 // ----------------------------------------------------------------------------
39 // Subgraph construction-related routines
40 // ----------------------------------------------------------------------------
41 // TODO(vrv): Profile the unordered_set and unordered_map use in this file to
42 // see if we should use an alternative implementation.
43 
44 namespace {
45 
46 typedef std::unordered_map<StringPiece, Node*, StringPieceHasher> NameIndex;
47 
48 // Rewrite graph by replacing the output tensors specified in
49 // "fed_outputs" with special feed nodes for each specified output
50 // tensor, and removing any nodes that are now disconnected from the
51 // part of the graph that reaches the sink node.  The set of special
52 // feed nodes added to the graph are returned in "*feed_nodes".
53 //
54 // Return true on success.  On error, return false and sets *error to
55 // an appropriate error message (and *g is left in an indeterminate
56 // state).
FeedInputs(Graph * g,const std::vector<std::unique_ptr<PruneRewrite>> & feed_rewrites,NameIndex * name_index,DataTypeVector * out_feed_types)57 Status FeedInputs(
58     Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& feed_rewrites,
59     NameIndex* name_index, DataTypeVector* out_feed_types) {
60   out_feed_types->clear();
61   out_feed_types->reserve(feed_rewrites.size());
62   for (size_t i = 0; i < feed_rewrites.size(); ++i) {
63     const string& t = feed_rewrites[i]->endpoint_name();
64     TensorId id(ParseTensorName(t));
65 
66     auto iter = name_index->find(id.first);
67     if (iter == name_index->end()) {
68       return errors::NotFound("FeedInputs: unable to find feed output ", t);
69     }
70     Node* n = iter->second;
71     DCHECK_EQ(n->name(), id.first);
72     if (id.second >= n->num_outputs()) {
73       return errors::InvalidArgument(
74           "FeedInputs: ", t, " should have output index < ", n->num_outputs());
75     }
76 
77     Node* feed_node;
78     TF_RETURN_IF_ERROR(
79         feed_rewrites[i]->AddNode(g, {n, id.second}, &feed_node));
80 
81     // Update name_index
82     (*name_index)[feed_node->name()] = feed_node;
83     // Duplicate control edges aren't allowed, but feed_node was *just* created
84     // so there's no need to check for a duplicate.
85     g->AddControlEdge(g->source_node(), feed_node, true);
86 
87     // Look through edges coming out of "n" for edges whose src_output() index
88     // matches "output_index".  If found, replace the edges with a connection
89     // from the special feed node.
90     std::vector<const Edge*> to_remove;
91     for (const Edge* e : n->out_edges()) {
92       if (e->src_output() == id.second) {
93         to_remove.emplace_back(e);
94       } else if (e->src_output() == Graph::kControlSlot &&
95                  (n->type_string() == "Placeholder" ||
96                   n->type_string() == "PlaceholderV2")) {
97         // When feeding a Placeholder node, any outgoing control edges
98         // will be replaced with a control edge from the replacement
99         // feed_node.
100         // TODO(josh11b,mrry): Come up with a more elegant way of addressing
101         // the general version of this problem.
102         to_remove.emplace_back(e);
103       }
104     }
105 
106     for (const Edge* e : to_remove) {
107       if (e->src_output() == id.second) {
108         g->AddEdge(feed_node, 0, e->dst(), e->dst_input());
109       } else {
110         CHECK_EQ(Graph::kControlSlot, e->src_output());
111         // Duplicate control edges aren't allowed, but feed_node was *just*
112         // created so there's no need to check for a duplicate.
113         g->AddControlEdge(feed_node, e->dst(), true);
114       }
115       g->RemoveEdge(e);
116     }
117     out_feed_types->push_back(BaseType(n->output_type(id.second)));
118   }
119   return OkStatus();
120 }
121 
FetchOutputs(Graph * g,const std::vector<std::unique_ptr<PruneRewrite>> & fetch_rewrites,NameIndex * name_index,std::vector<Node * > * out_fetch_nodes,DataTypeVector * out_fetch_types)122 Status FetchOutputs(
123     Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& fetch_rewrites,
124     NameIndex* name_index, std::vector<Node*>* out_fetch_nodes,
125     DataTypeVector* out_fetch_types) {
126   out_fetch_nodes->clear();
127   out_fetch_nodes->reserve(fetch_rewrites.size());
128   for (size_t i = 0; i < fetch_rewrites.size(); ++i) {
129     const string& t = fetch_rewrites[i]->endpoint_name();
130 
131     // Parse t into node_name and output_index.
132     TensorId id(ParseTensorName(t));
133 
134     // Find node in graph with that name.
135     auto iter = name_index->find(id.first);
136     if (iter == name_index->end()) {
137       return errors::NotFound("FetchOutputs node ", t, ": not found");
138     }
139     Node* n = iter->second;
140     DCHECK_EQ(n->name(), id.first);
141     VLOG(2) << "Found fetch node for " << t;
142 
143     // Validate output_index
144     if (n->num_outputs() == 0) {
145       return errors::InvalidArgument(
146           "Tried to fetch data for '", t,
147           "', which produces no output.  To run to a node but not fetch any "
148           "data, pass '",
149           t,
150           "' as an argument to the 'target_node_names' argument of the "
151           "Session::Run API.");
152     } else if (id.second >= n->num_outputs()) {
153       return errors::InvalidArgument("FetchOutputs ", t,
154                                      ": output index too large, must be < ",
155                                      n->num_outputs());
156     }
157 
158     // Create the fetch Node and connect it up
159     Node* fetch_node;
160     TF_RETURN_IF_ERROR(
161         fetch_rewrites[i]->AddNode(g, {n, id.second}, &fetch_node));
162 
163     // Update the index.
164     (*name_index)[fetch_node->name()] = fetch_node;
165 
166     // Duplicate control edges aren't allowed, but fetch_node was *just* created
167     // so there's no need to check for a duplicate.
168     g->AddControlEdge(fetch_node, g->sink_node(), true);
169     out_fetch_nodes->push_back(fetch_node);
170     out_fetch_types->push_back(BaseType(n->output_type(id.second)));
171   }
172 
173   return OkStatus();
174 }
175 
AddNodeToTargets(const string & node_or_tensor_name,const NameIndex & name_index,std::unordered_set<const Node * > * targets)176 bool AddNodeToTargets(const string& node_or_tensor_name,
177                       const NameIndex& name_index,
178                       std::unordered_set<const Node*>* targets) {
179   TensorId id = ParseTensorName(node_or_tensor_name);
180   auto iter = name_index.find(id.first);
181   if (iter == name_index.end()) {
182     return false;
183   }
184   const Node* n = iter->second;
185   CHECK_EQ(n->name(), id.first);
186   targets->insert(n);
187   return true;
188 }
189 
PruneForTargets(Graph * g,const NameIndex & name_index,const std::vector<Node * > & fetch_nodes,const gtl::ArraySlice<string> & target_nodes)190 Status PruneForTargets(Graph* g, const NameIndex& name_index,
191                        const std::vector<Node*>& fetch_nodes,
192                        const gtl::ArraySlice<string>& target_nodes) {
193   string not_found;
194   std::unordered_set<const Node*> targets;
195   for (Node* n : fetch_nodes) {
196     if (!AddNodeToTargets(n->name(), name_index, &targets)) {
197       strings::StrAppend(&not_found, n->name(), " ");
198     }
199   }
200   for (const string& s : target_nodes) {
201     if (!AddNodeToTargets(s, name_index, &targets)) {
202       strings::StrAppend(&not_found, s, " ");
203     }
204   }
205   if (!not_found.empty()) {
206     return errors::NotFound("PruneForTargets: Some target nodes not found: ",
207                             not_found);
208   }
209   PruneForReverseReachability(g, std::move(targets));
210 
211   // Reconnect nodes with no outgoing edges to the sink node
212   FixupSourceAndSinkEdges(g);
213 
214   return OkStatus();
215 }
216 
217 }  // namespace
218 
AddNode(Graph * g,NodeBuilder::NodeOut feed_tensor,Node ** out_node)219 Status ArgFeedRewrite::AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
220                                Node** out_node) {
221   // NOTE(mrry): We must include the index as part of the node
222   // name, because _Arg is a "stateful" kernel and therefore
223   // its name must uniquely identify a kernel instance across all
224   // graphs in the same session.
225   TF_RETURN_IF_ERROR(
226       NodeBuilder(strings::StrCat("_arg_", feed_tensor.node->name(), "_",
227                                   feed_tensor.index, "_", arg_index_),
228                   "_Arg")
229           .Attr("T", BaseType(feed_tensor.node->output_type(feed_tensor.index)))
230           .Attr("index", arg_index_)
231           .Finalize(g, out_node, /*consume=*/true));
232   (*out_node)->set_assigned_device_name(device_info().name());
233   return OkStatus();
234 }
235 
AddNode(Graph * g,NodeBuilder::NodeOut feed_tensor,Node ** out_node)236 Status RecvFeedRewrite::AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
237                                 Node** out_node) {
238   TF_RETURN_IF_ERROR(
239       NodeBuilder(strings::StrCat("_recv_", feed_tensor.node->name(), "_",
240                                   feed_tensor.index),
241                   "_Recv")
242           .Attr("tensor_type",
243                 BaseType(feed_tensor.node->output_type(feed_tensor.index)))
244           .Attr("tensor_name", endpoint_name())
245           .Attr("send_device", device_info().name())
246           .Attr("recv_device", device_info().name())
247           .Attr("send_device_incarnation",
248                 static_cast<int64_t>(device_info().incarnation()))
249           .Attr("client_terminated", true)
250           .Finalize(g, out_node, /*consume=*/true));
251 
252   (*out_node)->set_assigned_device_name(device_info().name());
253   return OkStatus();
254 }
255 
AddNode(Graph * g,NodeBuilder::NodeOut fetch_tensor,Node ** out_node)256 Status RetvalFetchRewrite::AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor,
257                                    Node** out_node) {
258   // NOTE(mrry): We must include the index as part of the node
259   // name, because _Retval is a "stateful" kernel and therefore
260   // its name must uniquely identify a kernel instance across all
261   // graphs in the same session.
262   TF_RETURN_IF_ERROR(
263       NodeBuilder(strings::StrCat("_retval_", fetch_tensor.node->name(), "_",
264                                   fetch_tensor.index, "_", retval_index_),
265                   "_Retval")
266           .Input(fetch_tensor.node, fetch_tensor.index)
267           .Attr("T",
268                 BaseType(fetch_tensor.node->output_type(fetch_tensor.index)))
269           .Attr("index", retval_index_)
270           .Finalize(g, out_node, /*consume=*/true));
271   (*out_node)->set_assigned_device_name(device_info().name());
272   return OkStatus();
273 }
274 
AddNode(Graph * g,NodeBuilder::NodeOut fetch_tensor,Node ** out_node)275 Status SendFetchRewrite::AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor,
276                                  Node** out_node) {
277   TF_RETURN_IF_ERROR(
278       NodeBuilder(strings::StrCat("_send_", fetch_tensor.node->name(), "_",
279                                   fetch_tensor.index),
280                   "_Send")
281           .Input(fetch_tensor.node, fetch_tensor.index)
282           .Attr("tensor_name", endpoint_name())
283           .Attr("send_device", device_info().name())
284           .Attr("recv_device", device_info().name())
285           .Attr("send_device_incarnation",
286                 static_cast<int64_t>(device_info().incarnation()))
287           .Attr("client_terminated", true)
288           .Finalize(g, out_node, /*consume=*/true));
289   (*out_node)->set_assigned_device_name(device_info().name());
290   return OkStatus();
291 }
292 
RewriteGraphForExecution(Graph * g,const gtl::ArraySlice<string> & fed_outputs,const gtl::ArraySlice<string> & fetch_outputs,const gtl::ArraySlice<string> & target_node_names,const DeviceAttributes & device_info,bool use_function_convention,RewriteGraphMetadata * out_metadata)293 Status RewriteGraphForExecution(
294     Graph* g, const gtl::ArraySlice<string>& fed_outputs,
295     const gtl::ArraySlice<string>& fetch_outputs,
296     const gtl::ArraySlice<string>& target_node_names,
297     const DeviceAttributes& device_info, bool use_function_convention,
298     RewriteGraphMetadata* out_metadata) {
299   std::vector<std::unique_ptr<PruneRewrite>> feed_rewrites;
300   feed_rewrites.reserve(fed_outputs.size());
301   if (use_function_convention) {
302     for (size_t i = 0; i < fed_outputs.size(); ++i) {
303       feed_rewrites.emplace_back(new ArgFeedRewrite(
304           &fed_outputs[i], &device_info, static_cast<int32>(i)));
305     }
306   } else {
307     for (const string& fed_output : fed_outputs) {
308       feed_rewrites.emplace_back(
309           new RecvFeedRewrite(&fed_output, &device_info));
310     }
311   }
312 
313   std::vector<std::unique_ptr<PruneRewrite>> fetch_rewrites;
314   fetch_rewrites.reserve(fetch_outputs.size());
315   if (use_function_convention) {
316     for (size_t i = 0; i < fetch_outputs.size(); ++i) {
317       fetch_rewrites.emplace_back(new RetvalFetchRewrite(
318           &fetch_outputs[i], &device_info, static_cast<int32>(i)));
319     }
320   } else {
321     for (const string& fetch_output : fetch_outputs) {
322       fetch_rewrites.emplace_back(
323           new SendFetchRewrite(&fetch_output, &device_info));
324     }
325   }
326 
327   return RewriteGraphForExecution(g, feed_rewrites, fetch_rewrites,
328                                   target_node_names, out_metadata);
329 }
330 
331 namespace {
// Copies every element of "field" into a newly built std::vector<string>,
// preserving order.  Each element must be usable to construct a string.
// NOTE(review): no caller of this helper is visible in this translation unit;
// it appears to be unused here.
template <typename StringContainer>
std::vector<string> ConvertToVector(StringContainer field) {
  std::vector<string> copied(field.begin(), field.end());
  return copied;
}
336 }  // namespace
337 
RewriteGraphForExecution(Graph * g,const std::vector<std::unique_ptr<PruneRewrite>> & feed_rewrites,const std::vector<std::unique_ptr<PruneRewrite>> & fetch_rewrites,const gtl::ArraySlice<string> & target_node_names,RewriteGraphMetadata * out_metadata)338 Status RewriteGraphForExecution(
339     Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& feed_rewrites,
340     const std::vector<std::unique_ptr<PruneRewrite>>& fetch_rewrites,
341     const gtl::ArraySlice<string>& target_node_names,
342     RewriteGraphMetadata* out_metadata) {
343   if (fetch_rewrites.empty() && target_node_names.empty()) {
344     return errors::InvalidArgument(
345         "Must specify at least one target to fetch or execute.");
346   }
347 
348   std::unordered_set<string> endpoints;
349   for (const auto& feed_rewrite : feed_rewrites) {
350     auto result = endpoints.insert(feed_rewrite->endpoint_name());
351     if (!result.second) {
352       return errors::InvalidArgument("Endpoint \"",
353                                      feed_rewrite->endpoint_name(),
354                                      "\" fed more than once.");
355     }
356   }
357 
358   for (const auto& fetch_rewrite : fetch_rewrites) {
359     if (endpoints.count(fetch_rewrite->endpoint_name()) > 0) {
360       return errors::InvalidArgument(fetch_rewrite->endpoint_name(),
361                                      " is both fed and fetched.");
362     }
363   }
364 
365   // A separate index mapping name to Node*, for use by FeedInputs,
366   // FetchOutputs, and PruneForTargets
367   NameIndex name_index;
368   name_index.reserve(g->num_nodes());
369   for (Node* n : g->nodes()) {
370     name_index[n->name()] = n;
371   }
372 
373   // Add the feeds.  This may replace nodes in the graph, including the nodes
374   // currently listed in "fetch_rewrites".  We pass "name_index" so the index is
375   // kept up to date.
376   if (!feed_rewrites.empty()) {
377     TF_RETURN_IF_ERROR(
378         FeedInputs(g, feed_rewrites, &name_index, &out_metadata->feed_types));
379   }
380 
381   // Add the fetch nodes, also updating "name_index".
382   std::vector<Node*> fetch_nodes;
383   if (!fetch_rewrites.empty()) {
384     TF_RETURN_IF_ERROR(FetchOutputs(g, fetch_rewrites, &name_index,
385                                     &fetch_nodes, &out_metadata->fetch_types));
386   }
387 
388   // Prune the graph to only compute what is needed for the fetch nodes and the
389   // target nodes.
390   if (!fetch_nodes.empty() || !target_node_names.empty()) {
391     TF_RETURN_IF_ERROR(
392         PruneForTargets(g, name_index, fetch_nodes, target_node_names));
393   }
394 
395   return OkStatus();
396 }
397 
398 }  // namespace subgraph
399 
400 }  // namespace tensorflow
401