1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/graph_execution_state.h"
17 
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <unordered_set>
22 #include <utility>
23 #include <vector>
24 
25 #include "tensorflow/core/common_runtime/device.h"
26 #include "tensorflow/core/common_runtime/metrics.h"
27 #include "tensorflow/core/common_runtime/optimization_registry.h"
28 #include "tensorflow/core/common_runtime/placer.h"
29 #include "tensorflow/core/framework/attr_value.pb.h"
30 #include "tensorflow/core/framework/graph.pb_text.h"
31 #include "tensorflow/core/framework/graph_def_util.h"
32 #include "tensorflow/core/framework/node_def.pb.h"
33 #include "tensorflow/core/framework/tensor.pb.h"
34 #include "tensorflow/core/framework/versions.pb.h"
35 #include "tensorflow/core/graph/algorithm.h"
36 #include "tensorflow/core/graph/collective_order.h"
37 #include "tensorflow/core/graph/graph.h"
38 #include "tensorflow/core/graph/graph_constructor.h"
39 #include "tensorflow/core/graph/subgraph.h"
40 #include "tensorflow/core/graph/tensor_id.h"
41 #include "tensorflow/core/graph/validate.h"
42 #include "tensorflow/core/lib/core/errors.h"
43 #include "tensorflow/core/lib/core/status.h"
44 #include "tensorflow/core/lib/gtl/flatset.h"
45 #include "tensorflow/core/lib/strings/stringprintf.h"
46 #include "tensorflow/core/platform/logging.h"
47 #include "tensorflow/core/platform/types.h"
48 #include "tensorflow/core/util/device_name_utils.h"
49 #include "tensorflow/core/util/util.h"
50 
51 #ifndef IS_MOBILE_PLATFORM
52 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
53 #include "tensorflow/core/grappler/grappler_item.h"
54 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
55 #endif  // IS_MOBILE_PLATFORM
56 
57 namespace tensorflow {
58 
59 GraphExecutionState::GraphExecutionState(
60     GraphDef* graph_def, const GraphExecutionStateOptions& options)
61     : stateful_placements_(options.stateful_placements),
62       device_set_(options.device_set),
63       session_options_(options.session_options),
64       session_handle_(options.session_handle),
65       flib_def_(new FunctionLibraryDefinition(OpRegistry::Global(),
66                                               graph_def->library())),
67       graph_(nullptr) {
68   // NOTE(mrry): GraphDef does not have a move constructor, so we pass
69   // a non-const pointer and use `Swap()` to transfer the contents
70   // without copying.
71   original_graph_def_.Swap(graph_def);
72   // TODO(mrry): Publish placement visualizations or handle the log
73   // placement option.
74 }
75 
76 GraphExecutionState::~GraphExecutionState() {
77   node_name_to_cost_id_map_.clear();
78   delete graph_;
79 }
80 
81 /* static */ Status GraphExecutionState::MakeForBaseGraph(
82     GraphDef* graph_def, const GraphExecutionStateOptions& options,
83     std::unique_ptr<GraphExecutionState>* out_state) {
84 #ifndef __ANDROID__
85   VLOG(4) << "Graph proto is \n" << graph_def->DebugString();
86 #endif  // __ANDROID__
87 
88   std::unique_ptr<GraphExecutionState> ret(
89       new GraphExecutionState(graph_def, options));
90 
91   TF_RETURN_IF_ERROR(
92       AddDefaultAttrsToGraphDef(&ret->original_graph_def_, *ret->flib_def_, 0));
93   // TODO(mrry): Refactor InitBaseGraph() so that we don't have to
94   // pass an empty BuildGraphOptions (that isn't going to be used when
95   // place_pruned_graph is false).
96   if (!ret->session_options_->config.graph_options().place_pruned_graph()) {
97     TF_RETURN_IF_ERROR(ret->InitBaseGraph(BuildGraphOptions()));
98   }
99   *out_state = std::move(ret);
100   return Status::OK();
101 }
102 
103 /* static */ Status GraphExecutionState::MakeForPrunedGraph(
104     const FunctionDefLibrary& func_def_lib,
105     const GraphExecutionStateOptions& options, const GraphDef& graph_def,
106     const BuildGraphOptions& subgraph_options,
107     std::unique_ptr<GraphExecutionState>* out_state,
108     std::unique_ptr<ClientGraph>* out_client_graph) {
109   DCHECK(options.session_options->config.graph_options().place_pruned_graph());
110   // NOTE(mrry): This makes a copy of `graph_def`, which is
111   // regrettable. We could make `GraphDef` objects sharable between
112   // execution states to optimize pruned graph execution, but since
113   // this case is primarily used for interactive sessions, we make the
114   // bet that graph construction is not performance-critical. (Note
115   // also that the previous version used `Extend()`, which is strictly
116   // more expensive than copying a `GraphDef`.)
117   GraphDef temp(graph_def);
118   std::unique_ptr<GraphExecutionState> ret(
119       new GraphExecutionState(&temp, options));
120   TF_RETURN_IF_ERROR(
121       AddDefaultAttrsToGraphDef(&ret->original_graph_def_, *ret->flib_def_, 0));
122   TF_RETURN_IF_ERROR(ret->InitBaseGraph(subgraph_options));
123   TF_RETURN_IF_ERROR(ret->BuildGraph(subgraph_options, out_client_graph));
124   *out_state = std::move(ret);
125   return Status::OK();
126 }
127 
128 Status GraphExecutionState::Extend(
129     const GraphDef& extension_def,
130     std::unique_ptr<GraphExecutionState>* out) const {
131   GraphDef gdef;
132 
133   // 1. Copy the function library.
134   TF_RETURN_IF_ERROR(flib_def_->AddLibrary(extension_def.library()));
135   *gdef.mutable_library() = flib_def_->ToProto();
136 
137   // 2. Build an index of the new node names.
138   std::unordered_set<string> new_names;
139   for (const NodeDef& node : extension_def.node()) {
140     new_names.insert(node.name());
141   }
142 
143   // 3. Add the non-duplicates from the old graph to the new graph.
144   //    Return an error if the same node name appears in both the
145   //    old graph and the extension.
146   for (const NodeDef& node : original_graph_def_.node()) {
147     if (new_names.count(node.name()) == 0) {
148       *gdef.add_node() = node;
149     } else {
150       return errors::InvalidArgument(tensorflow::strings::Printf(
151           "GraphDef argument to Extend includes node '%s', which was created "
152           "by a previous call to Create or Extend in this session.",
153           node.name().c_str()));
154     }
155   }
156 
157   // 4. Add the nodes from the extension, then merge the versions field.
158   int old_node_size = gdef.node_size();
159   gdef.mutable_node()->MergeFrom(extension_def.node());
160   TF_RETURN_IF_ERROR(
161       AddDefaultAttrsToGraphDef(&gdef, *flib_def_, old_node_size));
162   // Merge versions
163   if (gdef.has_versions()) {
164     if (gdef.versions().producer() != extension_def.versions().producer()) {
165       return errors::InvalidArgument(
166           "Can't extend GraphDef at version ", gdef.versions().producer(),
167           " with graph at version ", extension_def.versions().producer());
168     }
169     VersionDef* versions = gdef.mutable_versions();
170     versions->set_min_consumer(std::max(
171         versions->min_consumer(), extension_def.versions().min_consumer()));
172     if (extension_def.versions().bad_consumers_size()) {
173       // Add new bad_consumers that aren't already marked bad.
174       //
175       // Note: This implementation is quadratic time if there are many calls to
176       // ExtendLocked with many bad consumers.  Since this is unlikely, and
177       // fixing it would require data structures outside of this routine,
178       // quadratic time it is.
179       auto* bad_consumers = versions->mutable_bad_consumers();
180       const std::unordered_set<int> existing(bad_consumers->begin(),
181                                              bad_consumers->end());
182       for (const int v : extension_def.versions().bad_consumers()) {
183         if (existing.find(v) == existing.end()) {
184           bad_consumers->Add(v);
185         }
186       }
187     }
188 
189   } else {
190     gdef.mutable_versions()->CopyFrom(extension_def.versions());
191   }
192 
193   // 5. Validate that the final graphdef is valid.
194   if (gdef.versions().producer() >= 5) {
195     // Validate the graph: we assume that merging two valid graphs
196     // should maintain graph validity.
197     TF_RETURN_IF_ERROR(graph::ValidateGraphDef(gdef, *flib_def_));
198   }
199 
200   // 6. Build a new execution state from the merged GraphDef.
201   GraphExecutionStateOptions combined_options;
202   combined_options.device_set = device_set_;
203   combined_options.session_options = session_options_;
204   combined_options.session_handle = session_handle_;
205   combined_options.stateful_placements = stateful_placements_;
206 
207   // NOTE(mrry): `gdef` is no longer valid after the constructor
208   // executes.
209   std::unique_ptr<GraphExecutionState> new_execution_state(
210       new GraphExecutionState(&gdef, combined_options));
211 
212   TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
213       &new_execution_state->original_graph_def_, *flib_def_, 0));
214   if (!session_options_->config.graph_options().place_pruned_graph()) {
215     // TODO(mrry): Refactor InitBaseGraph() so that we don't have to
216     // pass an empty BuildGraphOptions (that isn't going to be used
217     // when place_pruned_graph is false).
218     TF_RETURN_IF_ERROR(new_execution_state->InitBaseGraph(BuildGraphOptions()));
219   }
220   *out = std::move(new_execution_state);
221 
222   // TODO(mrry): This is likely to be used for non-throughput-sensitive
223   // interactive workloads, but in future we may want to transfer other
224   // parts of the placement and/or cost model.
225   return Status::OK();
226 }
227 
228 void GraphExecutionState::SaveStatefulNodes(Graph* graph) {
229   for (Node* n : graph->nodes()) {
230     if (n->op_def().is_stateful()) {
231       VLOG(2) << "Saving " << n->DebugString();
232       stateful_placements_[n->name()] = n->assigned_device_name();
233     }
234   }
235 }
236 
237 void GraphExecutionState::RestoreStatefulNodes(Graph* graph) {
238   for (Node* n : graph->nodes()) {
239     if (n->op_def().is_stateful()) {
240       auto iter = stateful_placements_.find(n->name());
241       if (iter != stateful_placements_.end()) {
242         n->set_assigned_device_name(iter->second);
243         VLOG(2) << "Restored " << n->DebugString();
244       }
245     }
246   }
247 }
248 
249 namespace {
250 
251 class TensorConnectionPruneRewrite : public subgraph::PruneRewrite {
252  public:
253   TensorConnectionPruneRewrite(const string* endpoint_name,
254                                NodeBuilder::NodeOut from_tensor)
255       : subgraph::PruneRewrite(endpoint_name, nullptr /* device_info */),
256         from_tensor_(std::move(from_tensor)) {}
257 
258   Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
259                  Node** out_node) override {
260     Status s;
261     auto check_no_cycle_fn = [this, feed_tensor, &s](Node* n) {
262       if (n == feed_tensor.node) {
263         s.Update(errors::InvalidArgument(
264             "Requested Tensor connection between nodes \"",
265             feed_tensor.node->name(), "\" and \"", from_tensor_.node->name(),
266             "\" would create a cycle."));
267       }
268     };
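    // Walk the graph backwards from `from_tensor_`; if `feed_tensor.node` is
    // reachable, wiring up the requested connection would close a cycle.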
269     ReverseDFSFrom(*g, {from_tensor_.node}, std::move(check_no_cycle_fn),
270                    nullptr);
271     TF_RETURN_IF_ERROR(s);
272 
273     TF_RETURN_IF_ERROR(
274         NodeBuilder(strings::StrCat("_identity_", feed_tensor.node->name(), "_",
275                                     feed_tensor.index),
276                     "Identity")
277             .Input(from_tensor_)
278             .Attr("T",
279                   BaseType(from_tensor_.node->output_type(from_tensor_.index)))
280             .Finalize(g, out_node));
281 
282     (*out_node)->set_assigned_device_name(
283         feed_tensor.node->assigned_device_name());
284     return Status::OK();
285   }
286 
287  private:
288   NodeBuilder::NodeOut from_tensor_;
289 };
290 
291 template <class Map>
292 Status LookupDevice(const DeviceSet& device_set, const string& tensor_name,
293                     const Map& tensor2device,
294                     const tensorflow::DeviceAttributes** out_device_attrs) {
295   *out_device_attrs = nullptr;
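  // When no per-tensor device overrides are given, feeds and fetches default
  // to the client (host) device.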
296   if (tensor2device.empty()) {
297     *out_device_attrs = &device_set.client_device()->attributes();
298     return Status::OK();
299   }
300   const auto it = tensor2device.find(tensor_name);
301   if (it == tensor2device.end()) {
302     *out_device_attrs = &device_set.client_device()->attributes();
303     return Status::OK();
304   }
305   DeviceNameUtils::ParsedName parsed_name;
306   if (!DeviceNameUtils::ParseFullName(it->second, &parsed_name)) {
307     return errors::InvalidArgument("Invalid device name ('", it->second,
308                                    "') provided for the tensor '", tensor_name,
309                                    "' in CallableOptions");
310   }
311   Device* device = device_set.FindDeviceByName(
312       DeviceNameUtils::ParsedNameToString(parsed_name));
313   if (device == nullptr) {
314     return errors::InvalidArgument("Device '", it->second,
315                                    "' specified for tensor '", tensor_name,
316                                    "' in CallableOptions does not exist");
317   }
318   *out_device_attrs = &device->attributes();
319   return Status::OK();
320 }
321 
322 struct TensorAndDevice {
323   // WARNING: backing memory for the 'tensor' field is NOT owned.
324   const TensorId tensor;
325   // WARNING: device pointer is not owned, so must outlive TensorAndDevice.
326   const DeviceAttributes* device;
327 };
328 
329 // Tensors of some DataTypes cannot be placed in device memory as feeds or
330 // fetches. Validate against a whitelist of those known to work.
331 bool IsFeedAndFetchSupported(DataType dtype, const string& device_type) {
332   // The mechanism for supporting feeds of device-backed Tensors requires
333   // the _Arg kernel to be registered for the corresponding type (and that
334   // the input to the kernel be in device and not host memory).
335   //
336   // The mechanism for supporting fetches of device-backed Tensors requires
337   // the _Retval kernel to be registered for the corresponding type (and
338   // that the output is produced in device and not host memory).
339   //
340   // For now, we return true iff there are _Arg AND _Retval kernels for dtype on
341   // the device. False negatives are okay, false positives would be bad.
342   //
343   // TODO(ashankar): Instead of a whitelist here, perhaps we could query
344   // the kernel registry for _Arg and _Retval kernels.
345   if (device_type == DEVICE_CPU) return true;
346   if (device_type != DEVICE_GPU) return false;
347   switch (dtype) {
348     case DT_BFLOAT16:
349     case DT_BOOL:
350     case DT_COMPLEX128:
351     case DT_COMPLEX64:
352     case DT_DOUBLE:
353     case DT_FLOAT:
354     case DT_HALF:
355     case DT_INT16:
356     case DT_INT64:
357     case DT_INT8:
358     case DT_UINT16:
359     case DT_UINT8:
360       return true;
361     default:
362       return false;
363   }
364 }
365 
366 Status ValidateFeedAndFetchDevices(
367     const Graph& graph,
368     const std::vector<TensorAndDevice>& tensors_and_devices) {
369   if (tensors_and_devices.empty()) return Status::OK();
370   std::vector<bool> found(tensors_and_devices.size(), false);
371   for (const Node* node : graph.nodes()) {
372     // Linearly looping through all nodes and then all feed+fetch tensors isn't
373     // quite efficient. At the time of this writing, the expectation was that
374     // tensors_and_devices.size() is really small in practice, so this won't be
375     // problematic.
376     // Revisit and make a more efficient lookup possible if needed (e.g., perhaps
377     // Graph can maintain a map from node name to Node*).
378     for (int i = 0; i < tensors_and_devices.size(); ++i) {
379       const TensorAndDevice& td = tensors_and_devices[i];
380       if (td.tensor.first != node->name()) continue;
381       found[i] = true;
382       TF_RETURN_IF_ERROR(graph.IsValidOutputTensor(node, td.tensor.second));
383       const DataType dtype = node->output_type(td.tensor.second);
384       if (!IsFeedAndFetchSupported(dtype, td.device->device_type())) {
385         return errors::Unimplemented(
386             "Cannot feed or fetch tensor '", td.tensor.ToString(),
387             "' from device ", td.device->name(), " as feeding/fetching from ",
388             td.device->device_type(), " devices is not yet supported for ",
389             DataTypeString(dtype), " tensors");
390       }
391     }
392   }
393   for (int i = 0; i < found.size(); ++i) {
394     if (!found[i]) {
395       return errors::InvalidArgument(
396           "Tensor ", tensors_and_devices[i].tensor.ToString(),
397           ", specified in either feed_devices or fetch_devices, was not found "
398           "in the Graph");
399     }
400   }
401   return Status::OK();
402 }
403 
404 Status GetFeedShapeAndTypeFromAttribute(const NodeDef& node,
405                                         PartialTensorShape* shape,
406                                         DataType* type) {
407   static const gtl::FlatSet<string>* const kHasExplicitShapeAttribute =
408       CHECK_NOTNULL((new gtl::FlatSet<string>{
409           "Placeholder", "PlaceholderV2", "PlaceholderWithDefault",
410           "ParallelConcat", "ImmutableConst", "_ParallelConcatStart",
411           "InfeedDequeue", "OutfeedDequeue", "CollectiveBcastSend",
412           "CollectiveBcastRecv", "AccumulateNV2", "VariableV2", "Variable",
413           "TemporaryVariable", "NcclBroadcast", "_ScopedAllocator",
414           "_ScopedAllocatorConcat"}));
415 
416   // All the node types handled here have their output datatype set in
417   // either attribute 'dtype' or 'T'.
418   if (!GetNodeAttr(node, "dtype", type).ok() &&
419       !GetNodeAttr(node, "T", type).ok()) {
420     return errors::InvalidArgument(
421         "Could not determine output type for feed node: ", node.name(),
422         " of type ", node.op());
423   }
424 
425   // First handle the case of feeding a const node.
426   if (node.op() == "Const" && HasNodeAttr(node, "value")) {
427     *shape =
428         PartialTensorShape(node.attr().at("value").tensor().tensor_shape());
429   } else if (kHasExplicitShapeAttribute->find(node.op()) !=
430              kHasExplicitShapeAttribute->end()) {
431     TF_RETURN_IF_ERROR(GetNodeAttr(node, "shape", shape));
432   } else {
433     return errors::InvalidArgument("Could not determine shape for feed node: ",
434                                    node.name(), " of type ", node.op());
435   }
436   return Status::OK();
437 }
438 
439 }  // namespace
440 
441 Status GraphExecutionState::PruneGraph(
442     const BuildGraphOptions& options, Graph* graph,
443     subgraph::RewriteGraphMetadata* out_rewrite_metadata) {
444   std::vector<std::unique_ptr<subgraph::PruneRewrite>> feed_rewrites;
445   feed_rewrites.reserve(options.callable_options.feed_size());
446   std::vector<std::unique_ptr<subgraph::PruneRewrite>> fetch_rewrites;
447   fetch_rewrites.reserve(options.callable_options.fetch_size());
448   if (options.use_function_convention) {
449     std::vector<TensorAndDevice> tensors_and_devices;
450     for (int i = 0; i < options.callable_options.feed_size(); ++i) {
451       // WARNING: feed MUST be a reference, since ArgFeedRewrite and
452       // tensors_and_devices hold on to its address.
453       const string& feed = options.callable_options.feed(i);
454       const DeviceAttributes* device_info;
455       TF_RETURN_IF_ERROR(LookupDevice(*device_set_, feed,
456                                       options.callable_options.feed_devices(),
457                                       &device_info));
458       feed_rewrites.emplace_back(
459           new subgraph::ArgFeedRewrite(&feed, device_info, i));
460       tensors_and_devices.push_back({ParseTensorName(feed), device_info});
461     }
462     if (!options.callable_options.fetch_devices().empty() &&
463         !options.callable_options.fetch_skip_sync()) {
464       return errors::Unimplemented(
465           "CallableOptions.fetch_skip_sync = false is not yet implemented. You "
466           "can set it to true instead, but MUST ensure that Device::Sync() is "
467           "invoked on the Device corresponding to the fetched tensor before "
468           "dereferencing the Tensor's memory.");
469     }
470     for (int i = 0; i < options.callable_options.fetch_size(); ++i) {
471       // WARNING: fetch MUST be a reference, since RetvalFetchRewrite and
472       // tensors_and_devices hold on to its address.
473       const string& fetch = options.callable_options.fetch(i);
474       const DeviceAttributes* device_info;
475       TF_RETURN_IF_ERROR(LookupDevice(*device_set_, fetch,
476                                       options.callable_options.fetch_devices(),
477                                       &device_info));
478       fetch_rewrites.emplace_back(
479           new subgraph::RetvalFetchRewrite(&fetch, device_info, i));
480       tensors_and_devices.push_back({ParseTensorName(fetch), device_info});
481     }
482     TF_RETURN_IF_ERROR(
483         ValidateFeedAndFetchDevices(*graph, tensors_and_devices));
484   } else {
485     if (!options.callable_options.feed_devices().empty() ||
486         !options.callable_options.fetch_devices().empty()) {
487       return errors::Unimplemented(
488           "CallableOptions::feed_devices and CallableOptions::fetch_devices "
489           "to configure feeding/fetching tensors to/from device memory is not "
490           "yet supported when using a remote session.");
491     }
492     const DeviceAttributes* device_info =
493         &device_set_->client_device()->attributes();
494     for (const string& feed : options.callable_options.feed()) {
495       feed_rewrites.emplace_back(
496           new subgraph::RecvFeedRewrite(&feed, device_info));
497     }
498     for (const string& fetch : options.callable_options.fetch()) {
499       fetch_rewrites.emplace_back(
500           new subgraph::SendFetchRewrite(&fetch, device_info));
501     }
502   }
503 
504   for (const TensorConnection& tensor_connection :
505        options.callable_options.tensor_connection()) {
506     Node* from_node = nullptr;
507     TensorId from_id(ParseTensorName(tensor_connection.from_tensor()));
508 
509     for (Node* n : graph->nodes()) {
510       if (n->name() == from_id.first) {
511         from_node = n;
512         break;
513       }
514     }
515     if (from_node == nullptr) {
516       return errors::InvalidArgument(
517           "Requested tensor connection from unknown node: \"",
518           tensor_connection.from_tensor(), "\".");
519     }
520     if (from_id.second >= from_node->num_outputs()) {
521       return errors::InvalidArgument(
522           "Requested tensor connection from unknown edge: \"",
523           tensor_connection.from_tensor(),
524           "\" (actual number of outputs = ", from_node->num_outputs(), ").");
525     }
526 
527     feed_rewrites.emplace_back(new TensorConnectionPruneRewrite(
528         &tensor_connection.to_tensor(), {from_node, from_id.second}));
529   }
530 
531   std::vector<string> target_node_names(
532       options.callable_options.target().begin(),
533       options.callable_options.target().end());
534   TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
535       graph, feed_rewrites, fetch_rewrites, target_node_names,
536       out_rewrite_metadata));
537 
538   CHECK_EQ(out_rewrite_metadata->feed_types.size(),
539            options.callable_options.feed_size() +
540                options.callable_options.tensor_connection_size());
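  // The tensor-connection rewrites appended one internal feed each, after the
  // client-specified feeds; drop their types so that feed_types matches only
  // the feeds the client will actually supply.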
541   for (int i = 0; i < options.callable_options.tensor_connection_size(); ++i) {
542     out_rewrite_metadata->feed_types.pop_back();
543   }
544   return Status::OK();
545 }
546 
547 Status GraphExecutionState::InitBaseGraph(const BuildGraphOptions& options) {
548   const GraphDef* graph_def = &original_graph_def_;
549 
550   std::unique_ptr<Graph> new_graph(new Graph(OpRegistry::Global()));
551   GraphConstructorOptions opts;
552   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, *graph_def, new_graph.get()));
553   if (session_options_ &&
554       session_options_->config.graph_options().place_pruned_graph()) {
555     // Rewrite the graph before placement.
556     rewrite_metadata_.reset(new subgraph::RewriteGraphMetadata);
557     TF_RETURN_IF_ERROR(
558         PruneGraph(options, new_graph.get(), rewrite_metadata_.get()));
559   }
560 
561   // Restore placements previously saved for stateful nodes before placing.
562   RestoreStatefulNodes(new_graph.get());
563 
564   GraphOptimizationPassOptions optimization_options;
565   optimization_options.session_handle = session_handle_;
566   optimization_options.session_options = session_options_;
567   optimization_options.graph = &new_graph;
568   optimization_options.flib_def = flib_def_.get();
569   optimization_options.device_set = device_set_;
570 
571   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
572       OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
573 
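  // When no session options are provided, soft placement defaults to enabled
  // and device placement logging to disabled.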
574   Placer placer(new_graph.get(), device_set_, /* default_device= */ nullptr,
575                 session_options_ == nullptr ||
576                     session_options_->config.allow_soft_placement(),
577                 session_options_ != nullptr &&
578                     session_options_->config.log_device_placement());
579   // TODO(mrry): Consider making the Placer cancelable.
580   TF_RETURN_IF_ERROR(placer.Run());
581 
582   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
583       OptimizationPassRegistry::POST_PLACEMENT, optimization_options));
584 
585   for (const Node* n : new_graph->nodes()) {
586     VLOG(2) << "Mapping " << n->name() << " to " << n->cost_id();
587     node_name_to_cost_id_map_[n->name()] = n->cost_id();
588   }
589 
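  // Record the devices assigned to stateful nodes; `stateful_placements_` is
  // forwarded to execution states created by Extend(), so those nodes keep
  // their placement across graph extensions.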
590   SaveStatefulNodes(new_graph.get());
591   graph_ = new_graph.release();
592   return Status::OK();
593 }
594 
595 Status GraphExecutionState::OptimizeGraph(
596     const BuildGraphOptions& options, std::unique_ptr<Graph>* optimized_graph,
597     std::unique_ptr<FunctionLibraryDefinition>* optimized_flib) {
598 #ifndef IS_MOBILE_PLATFORM
599   if (session_options_->config.graph_options().place_pruned_graph()) {
600     return errors::InvalidArgument("Can't optimize a pruned graph");
601   }
602 
603   if (grappler::MetaOptimizerEnabled(session_options_->config)) {
604     grappler::GrapplerItem item;
605     item.id = "tf_graph";
606     graph_->ToGraphDef(&item.graph);
607 
608     // It's ok to skip invalid device annotations in Grappler.
609     Status inferred_devices = item.InferDevicesFromGraph();
610     if (!inferred_devices.ok()) {
611       VLOG(3) << inferred_devices.error_message();
612     }
613 
614     // TODO(b/114748242): Add a unit test to test this bug fix.
615     if (flib_def_) {
616       *item.graph.mutable_library() = flib_def_->ToProto();
617     }
618 
619     item.fetch.insert(item.fetch.end(),
620                       options.callable_options.fetch().begin(),
621                       options.callable_options.fetch().end());
622     item.fetch.insert(item.fetch.end(),
623                       options.callable_options.target().begin(),
624                       options.callable_options.target().end());
625 
626     for (const TensorConnection& tensor_connection :
627          options.callable_options.tensor_connection()) {
628       item.fetch.push_back(tensor_connection.from_tensor());
629     }
630 
631     if (!(options.callable_options.feed().empty() &&
632           options.callable_options.tensor_connection().empty())) {
633       std::unordered_set<string> feeds;
634       for (const string& feed : options.callable_options.feed()) {
635         TensorId id = ParseTensorName(feed);
636         if (id.second != 0) {
637           return errors::InvalidArgument("Unsupported feed: ", feed);
638         }
639         feeds.emplace(id.first);
640       }
641       for (const TensorConnection& tensor_connection :
642            options.callable_options.tensor_connection()) {
643         TensorId id = ParseTensorName(tensor_connection.to_tensor());
644         if (id.second != 0) {
645           return errors::InvalidArgument("Unsupported feed: ",
646                                          tensor_connection.to_tensor());
647         }
648         feeds.emplace(id.first);
649       }
650       for (const NodeDef& node : original_graph_def_.node()) {
651         if (feeds.find(node.name()) == feeds.end()) {
652           continue;
653         }
654         // Get the type and shape of the feed node.
655         PartialTensorShape partial_shape;
656         DataType type;
657         TF_RETURN_IF_ERROR(
658             GetFeedShapeAndTypeFromAttribute(node, &partial_shape, &type));
659         // If the shape of the placeholder is only partially known, we are free
660         // to set unknown dimensions of its shape to any value we desire. We
661         // choose 0 to minimize the memory impact. Note that this only matters
662         // if an optimizer chooses to run the graph.
663         TensorShape shape;
664         if (partial_shape.unknown_rank()) {
665           shape = TensorShape({0});
666         } else {
667           for (int i = 0; i < partial_shape.dims(); ++i) {
668             if (partial_shape.dim_size(i) < 0) {
669               partial_shape.set_dim(i, 0);
670             }
671           }
672           if (!partial_shape.AsTensorShape(&shape)) {
673             return errors::InvalidArgument(
674                 "Could not derive shape for feed node: ", node.DebugString());
675           }
676         }
677 
678         Tensor fake_input(type, shape);
679         item.feed.emplace_back(node.name(), fake_input);
680       }
681     }
682 
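    // Find a host CPU device (id 0) with a usable allocator to hand to the
    // meta optimizer; Grappler is expected to use it when it needs to run
    // kernels on the host (e.g. during constant folding).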
683     Device* cpu_device = nullptr;
684     for (const auto& device : device_set_->devices()) {
685       if (device->parsed_name().id == 0 &&
686           StringPiece(device->parsed_name().type) == "CPU" &&
687           device->GetAllocator(AllocatorAttributes()) != nullptr) {
688         cpu_device = device;
689       }
690     }
691     grappler::VirtualCluster cluster(device_set_);
692     GraphDef new_graph;
693     TF_RETURN_IF_ERROR(grappler::RunMetaOptimizer(
694         item, session_options_->config, cpu_device, &cluster, &new_graph));
695 
696     // Merge the optimized graph's function library with the original library.
697     // The optimized graph might have new functions specialized for its
698     // instantiation context (see the Grappler function optimizer), and
699     // modified function bodies for existing functions.
700     optimized_flib->reset(new FunctionLibraryDefinition(*flib_def_));
701 
702     for (const FunctionDef& fdef : new_graph.library().function()) {
703       const string& func_name = fdef.signature().name();
704 
705       if ((*optimized_flib)->Contains(func_name)) {
706         VLOG(3) << "Replace function: name=" << func_name;
707         TF_RETURN_IF_ERROR((*optimized_flib)->ReplaceFunction(func_name, fdef));
708       } else {
709         VLOG(3) << "Add new function: name=" << func_name;
710         TF_RETURN_IF_ERROR((*optimized_flib)->AddFunctionDef(fdef));
711       }
712     }
713 
714     optimized_graph->reset(new Graph(OpRegistry::Global()));
715 
716     GraphConstructorOptions opts;
717     opts.allow_internal_ops = true;
718     TF_RETURN_IF_ERROR(
719         ConvertGraphDefToGraph(opts, new_graph, optimized_graph->get()));
720     // The graph conversion sets the requested device names but not the
721     // assigned device names. However, since at this point the graph is placed,
722     // TF expects an assigned device name for every node. Therefore we copy
723     // the requested device into the assigned device field.
724     for (Node* node : optimized_graph->get()->nodes()) {
725       node->set_assigned_device_name(node->requested_device());
726     }
727     return Status::OK();
728   } else {
729     return errors::InvalidArgument("Meta Optimizer disabled");
730   }
731 #else
732   return errors::InvalidArgument("Mobile platforms not supported");
733 #endif  // IS_MOBILE_PLATFORM
734 }
735 
736 Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
737                                        std::unique_ptr<ClientGraph>* out) {
738   VLOG(1) << "BuildGraph";
739   const uint64 start_time_usecs = Env::Default()->NowMicros();
740   if (!graph_) {
741     // It is only valid to call this method directly when the original graph
742     // was created with the option `place_pruned_graph == false`.
743     return errors::Internal(
744         "Attempted to prune a graph that has not been fully initialized.");
745   }
746 
747   // Grappler optimization might change the structure of the graph itself, and
748   // it can also add functions to or prune functions from the library.
749   std::unique_ptr<Graph> optimized_graph;
750   std::unique_ptr<FunctionLibraryDefinition> optimized_flib;
751 
752   Status s = OptimizeGraph(options, &optimized_graph, &optimized_flib);
753   if (!s.ok()) {
754     VLOG(2) << "Grappler optimization failed. Error: " << s.error_message();
755     // Simply copy the original graph and the function library if we couldn't
756     // optimize it.
757     optimized_graph.reset(new Graph(flib_def_.get()));
758     CopyGraph(*graph_, optimized_graph.get());
759     optimized_flib.reset(new FunctionLibraryDefinition(*flib_def_));
760   }
761 
762   subgraph::RewriteGraphMetadata rewrite_metadata;
763   if (session_options_ == nullptr ||
764       !session_options_->config.graph_options().place_pruned_graph()) {
765     TF_RETURN_IF_ERROR(
766         PruneGraph(options, optimized_graph.get(), &rewrite_metadata));
767   } else {
768     // This GraphExecutionState represents a graph that was
769     // pruned when this was constructed, so we copy the metadata from
770     // a member variable.
771     CHECK(rewrite_metadata_);
772     rewrite_metadata = *rewrite_metadata_;
773   }
774 
775   CHECK_EQ(options.callable_options.feed_size(),
776            rewrite_metadata.feed_types.size());
777   CHECK_EQ(options.callable_options.fetch_size(),
778            rewrite_metadata.fetch_types.size());
779 
780   // TODO(andydavis): Clarify optimization pass requirements around CostModel.
781   GraphOptimizationPassOptions optimization_options;
782   optimization_options.session_options = session_options_;
783   optimization_options.graph = &optimized_graph;
784   optimization_options.flib_def = optimized_flib.get();
785   optimization_options.device_set = device_set_;
786 
787   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
788       OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
789 
790   int64 collective_graph_key = options.collective_graph_key;
791   if (collective_graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
792     // BuildGraphOptions does not specify a collective_graph_key.  Check all
793     // nodes in the Graph and FunctionLibraryDefinition for collective ops and
794     // if found, initialize a collective_graph_key as a hash of the ordered set
795     // of instance keys.
796     std::set<int32> instance_key_set;
797     for (Node* node : optimized_graph->nodes()) {
798       if (node->IsCollective()) {
799         int32 instance_key;
800         TF_RETURN_IF_ERROR(
801             GetNodeAttr(node->attrs(), "instance_key", &instance_key));
802         instance_key_set.emplace(instance_key);
803       } else {
804         const FunctionDef* fdef = optimized_flib->Find(node->def().op());
805         if (fdef != nullptr) {
806           for (const NodeDef& ndef : fdef->node_def()) {
807             if (ndef.op() == "CollectiveReduce" ||
808                 ndef.op() == "CollectiveBcastSend" ||
809                 ndef.op() == "CollectiveBcastRecv") {
810               int32 instance_key;
811               TF_RETURN_IF_ERROR(
812                   GetNodeAttr(ndef, "instance_key", &instance_key));
813               instance_key_set.emplace(instance_key);
814             }
815           }
816         }
817       }
818     }
819     if (!instance_key_set.empty()) {
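      // std::set iterates in ascending order, so the combined hash (and the
      // resulting collective_graph_key) is deterministic for a given set of
      // instance keys.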
820       uint64 hash = 0x8774aa605c729c72ULL;
821       for (int32 instance_key : instance_key_set) {
822         hash = Hash64Combine(instance_key, hash);
823       }
824       collective_graph_key = hash;
825     }
826   }
827 
828   // Make collective execution order deterministic if needed.
829   if (options.collective_order != GraphCollectiveOrder::kNone) {
830     TF_RETURN_IF_ERROR(
831         OrderCollectives(optimized_graph.get(), options.collective_order));
832   }
833 
834   // Copy the extracted graph in order to make its node ids dense,
835   // since the local CostModel used to record its stats is sized by
836   // the largest node id.
837   std::unique_ptr<ClientGraph> dense_copy(
838       new ClientGraph(std::move(optimized_flib), rewrite_metadata.feed_types,
839                       rewrite_metadata.fetch_types, collective_graph_key));
840   CopyGraph(*optimized_graph, &dense_copy->graph);
841 
842   // TODO(vrv): We should check invariants of the graph here.
843   metrics::UpdateGraphBuildTime(Env::Default()->NowMicros() - start_time_usecs);
844   *out = std::move(dense_copy);
845   return Status::OK();
846 }
847 
848 }  // namespace tensorflow
849