• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/subgraph.h"
17 
18 #include <algorithm>
19 #include <deque>
20 #include <string>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <vector>
24 
25 #include "tensorflow/core/framework/graph.pb.h"
26 #include "tensorflow/core/framework/node_def_util.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/graph/algorithm.h"
29 #include "tensorflow/core/graph/graph.h"
30 #include "tensorflow/core/graph/graph_constructor.h"
31 #include "tensorflow/core/graph/tensor_id.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/lib/core/status.h"
34 #include "tensorflow/core/platform/logging.h"
35 
36 namespace tensorflow {
37 namespace subgraph {
38 
39 // ----------------------------------------------------------------------------
40 // Subgraph construction-related routines
41 // ----------------------------------------------------------------------------
42 // TODO(vrv): Profile the unordered_set and unordered_map use in this file to
43 // see if we should use an alternative implementation.
44 
45 namespace {
46 
47 typedef std::unordered_map<StringPiece, Node*, StringPieceHasher> NameIndex;
48 
49 // Rewrite graph by replacing the output tensors specified in
50 // "fed_outputs" with special feed nodes for each specified output
51 // tensor, and removing any nodes that are now disconnected from the
52 // part of the graph that reaches the sink node.  The set of special
53 // feed nodes added to the graph are returned in "*feed_nodes".
54 //
55 // Return true on success.  On error, return false and sets *error to
56 // an appropriate error message (and *g is left in an indeterminate
57 // state).
FeedInputs(Graph * g,const std::vector<std::unique_ptr<PruneRewrite>> & feed_rewrites,NameIndex * name_index,DataTypeVector * out_feed_types)58 Status FeedInputs(
59     Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& feed_rewrites,
60     NameIndex* name_index, DataTypeVector* out_feed_types) {
61   out_feed_types->clear();
62   out_feed_types->reserve(feed_rewrites.size());
63   for (size_t i = 0; i < feed_rewrites.size(); ++i) {
64     const string& t = feed_rewrites[i]->endpoint_name();
65     TensorId id(ParseTensorName(t));
66 
67     auto iter = name_index->find(id.first);
68     if (iter == name_index->end()) {
69       return errors::NotFound("FeedInputs: unable to find feed output ", t);
70     }
71     Node* n = iter->second;
72     DCHECK_EQ(n->name(), id.first);
73     if (id.second >= n->num_outputs()) {
74       return errors::InvalidArgument(
75           "FeedInputs: ", t, " should have output index < ", n->num_outputs());
76     }
77 
78     Node* feed_node;
79     TF_RETURN_IF_ERROR(
80         feed_rewrites[i]->AddNode(g, {n, id.second}, &feed_node));
81 
82     // Update name_index
83     (*name_index)[feed_node->name()] = feed_node;
84     // Duplicate control edges aren't allowed, but feed_node was *just* created
85     // so there's no need to check for a duplicate.
86     g->AddControlEdge(g->source_node(), feed_node, true);
87 
88     // Look through edges coming out of "n" for edges whose src_output() index
89     // matches "output_index".  If found, replace the edges with a connection
90     // from the special feed node.
91     std::vector<const Edge*> to_remove;
92     for (const Edge* e : n->out_edges()) {
93       if (e->src_output() == id.second) {
94         to_remove.emplace_back(e);
95       } else if (e->src_output() == Graph::kControlSlot &&
96                  (n->type_string() == "Placeholder" ||
97                   n->type_string() == "PlaceholderV2")) {
98         // When feeding a Placeholder node, any outgoing control edges
99         // will be replaced with a control edge from the replacement
100         // feed_node.
101         // TODO(josh11b,mrry): Come up with a more elegant way of addressing
102         // the general version of this problem.
103         to_remove.emplace_back(e);
104       }
105     }
106 
107     for (const Edge* e : to_remove) {
108       if (e->src_output() == id.second) {
109         g->AddEdge(feed_node, 0, e->dst(), e->dst_input());
110       } else {
111         CHECK_EQ(Graph::kControlSlot, e->src_output());
112         // Duplicate control edges aren't allowed, but feed_node was *just*
113         // created so there's no need to check for a duplicate.
114         g->AddControlEdge(feed_node, e->dst(), true);
115       }
116       g->RemoveEdge(e);
117     }
118     out_feed_types->push_back(BaseType(n->output_type(id.second)));
119   }
120   return Status::OK();
121 }
122 
FetchOutputs(Graph * g,const std::vector<std::unique_ptr<PruneRewrite>> & fetch_rewrites,NameIndex * name_index,std::vector<Node * > * out_fetch_nodes,DataTypeVector * out_fetch_types)123 Status FetchOutputs(
124     Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& fetch_rewrites,
125     NameIndex* name_index, std::vector<Node*>* out_fetch_nodes,
126     DataTypeVector* out_fetch_types) {
127   out_fetch_nodes->clear();
128   out_fetch_nodes->reserve(fetch_rewrites.size());
129   for (size_t i = 0; i < fetch_rewrites.size(); ++i) {
130     const string& t = fetch_rewrites[i]->endpoint_name();
131 
132     // Parse t into node_name and output_index.
133     TensorId id(ParseTensorName(t));
134 
135     // Find node in graph with that name.
136     auto iter = name_index->find(id.first);
137     if (iter == name_index->end()) {
138       return errors::NotFound("FetchOutputs node ", t, ": not found");
139     }
140     Node* n = iter->second;
141     DCHECK_EQ(n->name(), id.first);
142     VLOG(2) << "Found fetch node for " << t;
143 
144     // Validate output_index
145     if (n->num_outputs() == 0) {
146       return errors::InvalidArgument(
147           "Tried to fetch data for '", t,
148           "', which produces no output.  To run to a node but not fetch any "
149           "data, pass '",
150           t,
151           "' as an argument to the 'target_node_names' argument of the "
152           "Session::Run API.");
153     } else if (id.second >= n->num_outputs()) {
154       return errors::InvalidArgument("FetchOutputs ", t,
155                                      ": output index too large, must be < ",
156                                      n->num_outputs());
157     }
158 
159     // Create the fetch Node and connect it up
160     Node* fetch_node;
161     TF_RETURN_IF_ERROR(
162         fetch_rewrites[i]->AddNode(g, {n, id.second}, &fetch_node));
163 
164     // Update the index.
165     (*name_index)[fetch_node->name()] = fetch_node;
166 
167     // Duplicate control edges aren't allowed, but fetch_node was *just* created
168     // so there's no need to check for a duplicate.
169     g->AddControlEdge(fetch_node, g->sink_node(), true);
170     out_fetch_nodes->push_back(fetch_node);
171     out_fetch_types->push_back(BaseType(n->output_type(id.second)));
172   }
173 
174   return Status::OK();
175 }
176 
AddNodeToTargets(const string & node_or_tensor_name,const NameIndex & name_index,std::unordered_set<const Node * > * targets)177 bool AddNodeToTargets(const string& node_or_tensor_name,
178                       const NameIndex& name_index,
179                       std::unordered_set<const Node*>* targets) {
180   TensorId id = ParseTensorName(node_or_tensor_name);
181   auto iter = name_index.find(id.first);
182   if (iter == name_index.end()) {
183     return false;
184   }
185   const Node* n = iter->second;
186   CHECK_EQ(n->name(), id.first);
187   targets->insert(n);
188   return true;
189 }
190 
PruneForTargets(Graph * g,const NameIndex & name_index,const std::vector<Node * > & fetch_nodes,const gtl::ArraySlice<string> & target_nodes)191 Status PruneForTargets(Graph* g, const NameIndex& name_index,
192                        const std::vector<Node*>& fetch_nodes,
193                        const gtl::ArraySlice<string>& target_nodes) {
194   string not_found;
195   std::unordered_set<const Node*> targets;
196   for (Node* n : fetch_nodes) {
197     if (!AddNodeToTargets(n->name(), name_index, &targets)) {
198       strings::StrAppend(&not_found, n->name(), " ");
199     }
200   }
201   for (const string& s : target_nodes) {
202     if (!AddNodeToTargets(s, name_index, &targets)) {
203       strings::StrAppend(&not_found, s, " ");
204     }
205   }
206   if (!not_found.empty()) {
207     return errors::NotFound("PruneForTargets: Some target nodes not found: ",
208                             not_found);
209   }
210   PruneForReverseReachability(g, targets);
211 
212   // Reconnect nodes with no outgoing edges to the sink node
213   FixupSourceAndSinkEdges(g);
214 
215   return Status::OK();
216 }
217 
218 }  // namespace
219 
AddNode(Graph * g,NodeBuilder::NodeOut feed_tensor,Node ** out_node)220 Status ArgFeedRewrite::AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
221                                Node** out_node) {
222   // NOTE(mrry): We must include the index as part of the node
223   // name, because _Arg is a "stateful" kernel and therefore
224   // its name must uniquely identify a kernel instance across all
225   // graphs in the same session.
226   TF_RETURN_IF_ERROR(
227       NodeBuilder(strings::StrCat("_arg_", feed_tensor.node->name(), "_",
228                                   feed_tensor.index, "_", arg_index_),
229                   "_Arg")
230           .Attr("T", BaseType(feed_tensor.node->output_type(feed_tensor.index)))
231           .Attr("index", arg_index_)
232           .Finalize(g, out_node));
233   (*out_node)->set_assigned_device_name(device_info().name());
234   return Status::OK();
235 }
236 
AddNode(Graph * g,NodeBuilder::NodeOut feed_tensor,Node ** out_node)237 Status RecvFeedRewrite::AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
238                                 Node** out_node) {
239   TF_RETURN_IF_ERROR(
240       NodeBuilder(strings::StrCat("_recv_", feed_tensor.node->name(), "_",
241                                   feed_tensor.index),
242                   "_Recv")
243           .Attr("tensor_type",
244                 BaseType(feed_tensor.node->output_type(feed_tensor.index)))
245           .Attr("tensor_name", endpoint_name())
246           .Attr("send_device", device_info().name())
247           .Attr("recv_device", device_info().name())
248           .Attr("send_device_incarnation",
249                 static_cast<int64>(device_info().incarnation()))
250           .Attr("client_terminated", true)
251           .Finalize(g, out_node));
252 
253   (*out_node)->set_assigned_device_name(device_info().name());
254   return Status::OK();
255 }
256 
AddNode(Graph * g,NodeBuilder::NodeOut fetch_tensor,Node ** out_node)257 Status RetvalFetchRewrite::AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor,
258                                    Node** out_node) {
259   // NOTE(mrry): We must include the index as part of the node
260   // name, because _Retval is a "stateful" kernel and therefore
261   // its name must uniquely identify a kernel instance across all
262   // graphs in the same session.
263   TF_RETURN_IF_ERROR(
264       NodeBuilder(strings::StrCat("_retval_", fetch_tensor.node->name(), "_",
265                                   fetch_tensor.index, "_", retval_index_),
266                   "_Retval")
267           .Input(fetch_tensor.node, fetch_tensor.index)
268           .Attr("T",
269                 BaseType(fetch_tensor.node->output_type(fetch_tensor.index)))
270           .Attr("index", retval_index_)
271           .Finalize(g, out_node));
272   (*out_node)->set_assigned_device_name(device_info().name());
273   return Status::OK();
274 }
275 
AddNode(Graph * g,NodeBuilder::NodeOut fetch_tensor,Node ** out_node)276 Status SendFetchRewrite::AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor,
277                                  Node** out_node) {
278   TF_RETURN_IF_ERROR(
279       NodeBuilder(strings::StrCat("_send_", fetch_tensor.node->name(), "_",
280                                   fetch_tensor.index),
281                   "_Send")
282           .Input(fetch_tensor.node, fetch_tensor.index)
283           .Attr("tensor_name", endpoint_name())
284           .Attr("send_device", device_info().name())
285           .Attr("recv_device", device_info().name())
286           .Attr("send_device_incarnation",
287                 static_cast<int64>(device_info().incarnation()))
288           .Attr("client_terminated", true)
289           .Finalize(g, out_node));
290   (*out_node)->set_assigned_device_name(device_info().name());
291   return Status::OK();
292 }
293 
RewriteGraphForExecution(Graph * g,const gtl::ArraySlice<string> & fed_outputs,const gtl::ArraySlice<string> & fetch_outputs,const gtl::ArraySlice<string> & target_node_names,const DeviceAttributes & device_info,bool use_function_convention,RewriteGraphMetadata * out_metadata)294 Status RewriteGraphForExecution(
295     Graph* g, const gtl::ArraySlice<string>& fed_outputs,
296     const gtl::ArraySlice<string>& fetch_outputs,
297     const gtl::ArraySlice<string>& target_node_names,
298     const DeviceAttributes& device_info, bool use_function_convention,
299     RewriteGraphMetadata* out_metadata) {
300   std::vector<std::unique_ptr<PruneRewrite>> feed_rewrites;
301   feed_rewrites.reserve(fed_outputs.size());
302   if (use_function_convention) {
303     for (size_t i = 0; i < fed_outputs.size(); ++i) {
304       feed_rewrites.emplace_back(new ArgFeedRewrite(
305           &fed_outputs[i], &device_info, static_cast<int32>(i)));
306     }
307   } else {
308     for (const string& fed_output : fed_outputs) {
309       feed_rewrites.emplace_back(
310           new RecvFeedRewrite(&fed_output, &device_info));
311     }
312   }
313 
314   std::vector<std::unique_ptr<PruneRewrite>> fetch_rewrites;
315   fetch_rewrites.reserve(fetch_outputs.size());
316   if (use_function_convention) {
317     for (size_t i = 0; i < fetch_outputs.size(); ++i) {
318       fetch_rewrites.emplace_back(new RetvalFetchRewrite(
319           &fetch_outputs[i], &device_info, static_cast<int32>(i)));
320     }
321   } else {
322     for (const string& fetch_output : fetch_outputs) {
323       fetch_rewrites.emplace_back(
324           new SendFetchRewrite(&fetch_output, &device_info));
325     }
326   }
327 
328   return RewriteGraphForExecution(g, feed_rewrites, fetch_rewrites,
329                                   target_node_names, out_metadata);
330 }
331 
namespace {
// Copies every element of "field" into a new std::vector<string>, preserving
// iteration order. NOTE(review): this helper appears unused in this file.
template <typename StringContainer>
std::vector<string> ConvertToVector(StringContainer field) {
  std::vector<string> copied(field.begin(), field.end());
  return copied;
}
}  // namespace
338 
RewriteGraphForExecution(Graph * g,const std::vector<std::unique_ptr<PruneRewrite>> & feed_rewrites,const std::vector<std::unique_ptr<PruneRewrite>> & fetch_rewrites,const gtl::ArraySlice<string> & target_node_names,RewriteGraphMetadata * out_metadata)339 Status RewriteGraphForExecution(
340     Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& feed_rewrites,
341     const std::vector<std::unique_ptr<PruneRewrite>>& fetch_rewrites,
342     const gtl::ArraySlice<string>& target_node_names,
343     RewriteGraphMetadata* out_metadata) {
344   if (fetch_rewrites.empty() && target_node_names.empty()) {
345     return errors::InvalidArgument(
346         "Must specify at least one target to fetch or execute.");
347   }
348 
349   std::unordered_set<string> endpoints;
350   for (const auto& feed_rewrite : feed_rewrites) {
351     auto result = endpoints.insert(feed_rewrite->endpoint_name());
352     if (!result.second) {
353       return errors::InvalidArgument("Endpoint \"",
354                                      feed_rewrite->endpoint_name(),
355                                      "\" fed more than once.");
356     }
357   }
358 
359   for (const auto& fetch_rewrite : fetch_rewrites) {
360     if (endpoints.count(fetch_rewrite->endpoint_name()) > 0) {
361       return errors::InvalidArgument(fetch_rewrite->endpoint_name(),
362                                      " is both fed and fetched.");
363     }
364   }
365 
366   // A separate index mapping name to Node*, for use by FeedInputs,
367   // FetchOutputs, and PruneForTargets
368   NameIndex name_index;
369   name_index.reserve(g->num_nodes());
370   for (Node* n : g->nodes()) {
371     name_index[n->name()] = n;
372   }
373 
374   // Add the feeds.  This may replace nodes in the graph, including the nodes
375   // currently listed in "fetch_rewrites".  We pass "name_index" so the index is
376   // kept up to date.
377   if (!feed_rewrites.empty()) {
378     TF_RETURN_IF_ERROR(
379         FeedInputs(g, feed_rewrites, &name_index, &out_metadata->feed_types));
380   }
381 
382   // Add the fetch nodes, also updating "name_index".
383   std::vector<Node*> fetch_nodes;
384   if (!fetch_rewrites.empty()) {
385     TF_RETURN_IF_ERROR(FetchOutputs(g, fetch_rewrites, &name_index,
386                                     &fetch_nodes, &out_metadata->fetch_types));
387   }
388 
389   // Prune the graph to only compute what is needed for the fetch nodes and the
390   // target nodes.
391   if (!fetch_nodes.empty() || !target_node_names.empty()) {
392     TF_RETURN_IF_ERROR(
393         PruneForTargets(g, name_index, fetch_nodes, target_node_names));
394   }
395 
396   return Status::OK();
397 }
398 
399 }  // namespace subgraph
400 
401 }  // namespace tensorflow
402