/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/graph_execution_state.h"

#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/placer.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_factory.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/collective_order.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/util.h"

#ifndef IS_MOBILE_PLATFORM
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#endif  // IS_MOBILE_PLATFORM

namespace tensorflow {

namespace {
bool IsCollectiveV2(const string& op) {
  return op == "CollectiveReduceV2" || op == "CollectiveGatherV2" ||
         op == "CollectiveBcastRecvV2" || op == "CollectiveBcastSendV2";
}
}  // namespace

GraphExecutionState::GraphExecutionState(
    std::unique_ptr<GraphDef>&& graph_def,
    std::unique_ptr<FunctionLibraryDefinition>&& flib_def,
    const GraphExecutionStateOptions& options)
    : stateful_placements_(options.stateful_placements),
      original_graph_def_(std::move(graph_def)),
      device_set_(options.device_set),
      session_options_(options.session_options),
      session_handle_(options.session_handle),
      flib_def_(std::move(flib_def)),
      graph_(nullptr) {}

GraphExecutionState::~GraphExecutionState() {
  node_name_to_cost_id_map_.clear();
  delete graph_;
}

/* static */ Status GraphExecutionState::MakeForBaseGraph(
    GraphDef&& graph_def, const GraphExecutionStateOptions& options,
    std::unique_ptr<GraphExecutionState>* out_state) {
#ifndef __ANDROID__
  VLOG(4) << "Graph proto is \n" << graph_def.DebugString();
#endif  // __ANDROID__

  auto flib_def = std::make_unique<FunctionLibraryDefinition>(
      OpRegistry::Global(), graph_def.library());

  TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&graph_def, *flib_def, 0));

  if (options.session_options->config.graph_options().place_pruned_graph() ||
      !options.session_options->config.experimental()
           .optimize_for_static_graph()) {
    auto ret = absl::WrapUnique(new GraphExecutionState(
        std::make_unique<GraphDef>(std::move(graph_def)), std::move(flib_def),
        options));

    // When place_pruned_graph is true, a different Graph* will be initialized
    // each time we prune the original graph, so there is no need to
    // construct a Graph* in this case.
    if (!options.session_options->config.graph_options().place_pruned_graph()) {
      auto base_graph = std::make_unique<Graph>(OpRegistry::Global());
      TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *ret->original_graph_def_,
                                                base_graph.get()));
      TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph)));
    }
    *out_state = std::move(ret);
  } else {
    auto ret = absl::WrapUnique(
        new GraphExecutionState(nullptr, std::move(flib_def), options));
    auto base_graph = std::make_unique<Graph>(OpRegistry::Global());
    TF_RETURN_IF_ERROR(
        ConvertGraphDefToGraph({}, std::move(graph_def), base_graph.get()));
    TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph)));
    *out_state = std::move(ret);
  }
  return OkStatus();
}
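
// Illustrative usage (caller-side sketch, not defined in this file; the
// surrounding session objects are assumed): a session-level caller fills in
// GraphExecutionStateOptions and then builds the base state, e.g.
//
//   GraphExecutionStateOptions opts;
//   opts.device_set = &device_set;            // must outlive the state
//   opts.session_options = &session_options;
//   std::unique_ptr<GraphExecutionState> state;
//   TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
//       std::move(graph_def), opts, &state));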

/* static */ Status GraphExecutionState::MakeForPrunedGraph(
    const GraphExecutionState& base_execution_state,
    const GraphExecutionStateOptions& options,
    const BuildGraphOptions& subgraph_options,
    std::unique_ptr<GraphExecutionState>* out_state,
    std::unique_ptr<ClientGraph>* out_client_graph) {
  if (!(base_execution_state.session_options_->config.graph_options()
            .place_pruned_graph() &&
        options.session_options->config.graph_options().place_pruned_graph())) {
    return errors::Internal(
        "MakeForPrunedGraph is only supported when the `place_pruned_graph` "
        "option is true.");
  }
  if (!base_execution_state.original_graph_def_) {
    // NOTE(mrry): By adding this restriction, which matches the only current
    // usage of this (fairly obscure) method, we do not need to store a
    // redundant copy of the original graph in `*out_state`.
    return errors::Internal(
        "MakeForPrunedGraph is only supported when `base_execution_state` is "
        "the Session-level `GraphExecutionState`.");
  }

  // NOTE(mrry): This makes a copy of `graph_def`, which is
  // regrettable. We could make `GraphDef` objects shareable between
  // execution states to optimize pruned graph execution, but since
  // this case is primarily used for interactive sessions, we make the
  // bet that graph construction is not performance-critical. (Note
  // also that the previous version used `Extend()`, which is strictly
  // more expensive than copying a `GraphDef`.)
  GraphDef temp(*base_execution_state.original_graph_def_);
  auto flib_def = std::make_unique<FunctionLibraryDefinition>(
      OpRegistry::Global(), temp.library());
  TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&temp, *flib_def, 0));
  auto ret = absl::WrapUnique(
      new GraphExecutionState(nullptr, std::move(flib_def), options));

  auto base_graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_RETURN_IF_ERROR(
      ConvertGraphDefToGraph({}, std::move(temp), base_graph.get()));

  // Rewrite the graph before placement.
  ret->rewrite_metadata_.reset(new subgraph::RewriteGraphMetadata);
  TF_RETURN_IF_ERROR(ret->PruneGraph(subgraph_options, base_graph.get(),
                                     ret->rewrite_metadata_.get()));
  TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph)));
  TF_RETURN_IF_ERROR(ret->BuildGraph(subgraph_options, out_client_graph));
  *out_state = std::move(ret);
  return OkStatus();
}

Status GraphExecutionState::Extend(
    const GraphDef& extension_def,
    std::unique_ptr<GraphExecutionState>* out) const {
  if (session_options_->config.experimental().optimize_for_static_graph()) {
    return errors::FailedPrecondition(
        "Extending the graph is not supported when "
        "`optimize_for_static_graph` is true.");
  }

  GraphDef gdef;

  // 1. Copy the function library.
  TF_RETURN_IF_ERROR(flib_def_->AddLibrary(extension_def.library()));
  *gdef.mutable_library() = flib_def_->ToProto();

  // 2. Build an index of the new node names.
  std::unordered_set<string> new_names;
  for (const NodeDef& node : extension_def.node()) {
    new_names.insert(node.name());
  }

  // 3. Add the non-duplicates from the old graph to the new graph.
  //    Return an error if the same node name appears in both the
  //    old graph and the extension.
  for (const NodeDef& node : original_graph_def_->node()) {
    if (new_names.count(node.name()) == 0) {
      *gdef.add_node() = node;
    } else {
      return errors::InvalidArgument(
          "GraphDef argument to Extend includes node '", node.name(),
          "', which was created by a previous call to Create or Extend in this "
          "session.");
    }
  }

  // 4. Merge the versions field.
  int old_node_size = gdef.node_size();
  gdef.mutable_node()->MergeFrom(extension_def.node());
  TF_RETURN_IF_ERROR(
      AddDefaultAttrsToGraphDef(&gdef, *flib_def_, old_node_size));
  // Merge versions
  if (gdef.has_versions()) {
    if (gdef.versions().producer() != extension_def.versions().producer()) {
      return errors::InvalidArgument(
          "Can't extend GraphDef at version ", gdef.versions().producer(),
          " with graph at version ", extension_def.versions().producer());
    }
    VersionDef* versions = gdef.mutable_versions();
    versions->set_min_consumer(std::max(
        versions->min_consumer(), extension_def.versions().min_consumer()));
    if (extension_def.versions().bad_consumers_size()) {
      // Add new bad_consumers that aren't already marked bad.
      //
      // Note: This implementation is quadratic time if there are many calls to
      // ExtendLocked with many bad consumers.  Since this is unlikely, and
      // fixing it would require data structures outside of this routine,
      // quadratic time it is.
      auto* bad_consumers = versions->mutable_bad_consumers();
      const std::unordered_set<int> existing(bad_consumers->begin(),
                                             bad_consumers->end());
      for (const int v : extension_def.versions().bad_consumers()) {
        if (existing.find(v) == existing.end()) {
          bad_consumers->Add(v);
        }
      }
    }

  } else {
    gdef.mutable_versions()->CopyFrom(extension_def.versions());
  }

  // 5. Validate that the final graphdef is valid.
  if (gdef.versions().producer() >= 5) {
    // Validate the graph: we assume that merging two valid graphs
    // should maintain graph validity.
    TF_RETURN_IF_ERROR(graph::ValidateGraphDef(gdef, *flib_def_));
  }

  // 6. Add the extension.
  GraphExecutionStateOptions combined_options;
  combined_options.device_set = device_set_;
  combined_options.session_options = session_options_;
  combined_options.session_handle = session_handle_;
  combined_options.stateful_placements = stateful_placements_;

  TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&gdef, *flib_def_, 0));
  auto flib_def = std::make_unique<FunctionLibraryDefinition>(
      OpRegistry::Global(), gdef.library());
  auto new_execution_state = absl::WrapUnique(
      new GraphExecutionState(std::make_unique<GraphDef>(std::move(gdef)),
                              std::move(flib_def), combined_options));

  if (!session_options_->config.graph_options().place_pruned_graph()) {
    auto base_graph = std::make_unique<Graph>(OpRegistry::Global());
    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
        {}, *new_execution_state->original_graph_def_, base_graph.get()));
    TF_RETURN_IF_ERROR(
        new_execution_state->InitBaseGraph(std::move(base_graph)));
  }
  *out = std::move(new_execution_state);

  // NOTE(mrry): Extend() is likely to be used for non-throughput-sensitive
  // interactive workloads, but in future we may want to transfer other
  // parts of the placement and/or cost model.
  return OkStatus();
}
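
// Illustrative usage (caller-side sketch, not defined in this file): extending
// a session's graph produces a new execution state rather than mutating the
// old one, e.g.
//
//   std::unique_ptr<GraphExecutionState> extended_state;
//   TF_RETURN_IF_ERROR(state->Extend(extension_def, &extended_state));
//   state = std::move(extended_state);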

void GraphExecutionState::SaveStatefulNodes(Graph* graph) {
  for (Node* n : graph->nodes()) {
    if (n->op_def().is_stateful()) {
      VLOG(2) << "Saving " << n->DebugString();
      stateful_placements_[n->name()] = n->assigned_device_name();
    }
  }
}

void GraphExecutionState::RestoreStatefulNodes(Graph* graph) {
  for (Node* n : graph->nodes()) {
    if (n->op_def().is_stateful()) {
      auto iter = stateful_placements_.find(n->name());
      if (iter != stateful_placements_.end()) {
        n->set_assigned_device_name(iter->second);
        VLOG(2) << "Restored " << n->DebugString();
      }
    }
  }
}

namespace {

class TensorConnectionPruneRewrite : public subgraph::PruneRewrite {
 public:
  TensorConnectionPruneRewrite(const string* endpoint_name,
                               NodeBuilder::NodeOut from_tensor)
      : subgraph::PruneRewrite(endpoint_name, nullptr /* device_info */),
        from_tensor_(std::move(from_tensor)) {}

  Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
                 Node** out_node) override {
    Status s;
    auto check_no_cycle_fn = [this, feed_tensor, &s](Node* n) {
      if (n == feed_tensor.node) {
        s.Update(errors::InvalidArgument(
            "Requested Tensor connection between nodes \"",
            feed_tensor.node->name(), "\" and \"", from_tensor_.node->name(),
            "\" would create a cycle."));
      }
    };
    ReverseDFSFrom(*g, {from_tensor_.node}, std::move(check_no_cycle_fn),
                   nullptr);
    TF_RETURN_IF_ERROR(s);

    TF_RETURN_IF_ERROR(
        NodeBuilder(strings::StrCat("_identity_", feed_tensor.node->name(), "_",
                                    feed_tensor.index),
                    "Identity")
            .Input(from_tensor_)
            .Attr("T",
                  BaseType(from_tensor_.node->output_type(from_tensor_.index)))
            .Finalize(g, out_node));

    (*out_node)->set_assigned_device_name(
        feed_tensor.node->assigned_device_name());
    return OkStatus();
  }

 private:
  NodeBuilder::NodeOut from_tensor_;
};

template <class Map>
Status LookupDevice(const DeviceSet& device_set, const string& tensor_name,
                    const Map& tensor2device,
                    const tensorflow::DeviceAttributes** out_device_attrs) {
  *out_device_attrs = nullptr;
  if (tensor2device.empty()) {
    *out_device_attrs = &device_set.client_device()->attributes();
    return OkStatus();
  }
  const auto it = tensor2device.find(tensor_name);
  if (it == tensor2device.end()) {
    *out_device_attrs = &device_set.client_device()->attributes();
    return OkStatus();
  }
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(it->second, &parsed_name)) {
    return errors::InvalidArgument("Invalid device name ('", it->second,
                                   "') provided for the tensor '", tensor_name,
                                   "' in CallableOptions");
  }
  Device* device = device_set.FindDeviceByName(
      DeviceNameUtils::ParsedNameToString(parsed_name));
  if (device == nullptr) {
    return errors::InvalidArgument("Device '", it->second,
                                   "' specified for tensor '", tensor_name,
                                   "' in CallableOptions does not exist");
  }
  *out_device_attrs = &device->attributes();
  return OkStatus();
}

struct TensorAndDevice {
  // WARNING: backing memory for the 'tensor' field is NOT owned.
  const TensorId tensor;
  // WARNING: device pointer is not owned, so must outlive TensorAndDevice.
  const DeviceAttributes* device;
};

// Tensors of some DataTypes cannot be placed in device memory as feeds or
// fetches. Validate against an allowlist of those known to work.
bool IsFeedAndFetchSupported(DataType dtype, const string& device_type) {
  // The mechanism for supporting feeds of device-backed Tensors requires
  // the _Arg kernel to be registered for the corresponding type (and that
  // the input to the kernel be in device and not host memory).
  //
  // The mechanism for supporting fetches of device-backed Tensors requires
  // the _Retval kernel to be registered for the corresponding type (and
  // that the output is produced in device and not host memory).
  //
  // For now, we return true iff there are _Arg AND _Retval kernels for dtype on
  // the device. False negatives are okay, false positives would be bad.
  //
  // TODO(ashankar): Instead of an allowlist here, perhaps we could query
  // the kernel registry for _Arg and _Retval kernels.
  if (device_type == DEVICE_CPU) return true;
  if (device_type != DEVICE_GPU &&
      !DeviceFactory::IsPluggableDevice(device_type))
    return false;
  switch (dtype) {
    case DT_BFLOAT16:
    case DT_BOOL:
    case DT_COMPLEX128:
    case DT_COMPLEX64:
    case DT_DOUBLE:
    case DT_FLOAT:
    case DT_HALF:
    case DT_INT16:
    case DT_INT64:
    case DT_INT8:
    case DT_UINT16:
    case DT_UINT8:
      return true;
    default:
      return false;
  }
}

Status ValidateFeedAndFetchDevices(
    const Graph& graph,
    const std::vector<TensorAndDevice>& tensors_and_devices) {
  if (tensors_and_devices.empty()) return OkStatus();
  std::vector<bool> found(tensors_and_devices.size(), false);
  for (const Node* node : graph.nodes()) {
    // Linearly looping through all nodes and then all feed+fetch tensors isn't
    // quite efficient. At the time of this writing, the expectation was that
    // tensors_and_devices.size() is really small in practice, so this won't be
    // problematic.
    // Revisit and make a more efficient lookup possible if needed (e.g., perhaps
    // Graph can maintain a map from node name to Node*).
    for (int i = 0; i < tensors_and_devices.size(); ++i) {
      const TensorAndDevice& td = tensors_and_devices[i];
      if (td.tensor.first != node->name()) continue;
      found[i] = true;
      TF_RETURN_IF_ERROR(graph.IsValidOutputTensor(node, td.tensor.second));
      const DataType dtype = node->output_type(td.tensor.second);
      if (!IsFeedAndFetchSupported(dtype, td.device->device_type())) {
        return errors::Unimplemented(
            "Cannot feed or fetch tensor '", td.tensor.ToString(),
            "' from device ", td.device->name(), " as feeding/fetching from ",
            td.device->device_type(), " devices is not yet supported for ",
            DataTypeString(dtype), " tensors");
      }
    }
  }
  for (int i = 0; i < found.size(); ++i) {
    if (!found[i]) {
      return errors::InvalidArgument(
          "Tensor ", tensors_and_devices[i].tensor.ToString(),
          ", specified in either feed_devices or fetch_devices was not found "
          "in the Graph");
    }
  }
  return OkStatus();
}

Status GetFeedShapeAndTypeFromAttribute(const NodeDef& node,
                                        PartialTensorShape* shape,
                                        DataType* type) {
  static const gtl::FlatSet<string>* const kHasExplicitShapeAttribute =
      CHECK_NOTNULL((new gtl::FlatSet<string>{
          "Placeholder", "PlaceholderV2", "PlaceholderWithDefault",
          "ParallelConcat", "ImmutableConst", "_ParallelConcatStart",
          "InfeedDequeue", "OutfeedDequeue", "CollectiveBcastSend",
          "CollectiveBcastRecv", "AccumulateNV2", "VariableV2", "Variable",
          "TemporaryVariable", "NcclBroadcast", "_ScopedAllocator",
          "_ScopedAllocatorConcat"}));

  // All the node types handled here have their output datatype set in
  // either attribute 'dtype' or 'T'.
  if (!TryGetNodeAttr(node, "dtype", type) &&
      !TryGetNodeAttr(node, "T", type)) {
    return errors::InvalidArgument(
        "Could not determine output type for feed node: ", node.name(),
        " of type ", node.op());
  }

  // First handle the case of feeding a const node.
  if (node.op() == "Const" && HasNodeAttr(node, "value")) {
    *shape =
        PartialTensorShape(node.attr().at("value").tensor().tensor_shape());
  } else if (kHasExplicitShapeAttribute->find(node.op()) !=
             kHasExplicitShapeAttribute->end()) {
    TF_RETURN_IF_ERROR(GetNodeAttr(node, "shape", shape));
  } else {
    return errors::InvalidArgument("Could not determine shape for feed node: ",
                                   node.name(), " of type ", node.op());
  }
  return OkStatus();
}

}  // namespace

Status GraphExecutionState::PruneGraph(
    const BuildGraphOptions& options, Graph* graph,
    subgraph::RewriteGraphMetadata* out_rewrite_metadata) {
  std::vector<std::unique_ptr<subgraph::PruneRewrite>> feed_rewrites;
  feed_rewrites.reserve(options.callable_options.feed_size());
  std::vector<std::unique_ptr<subgraph::PruneRewrite>> fetch_rewrites;
  fetch_rewrites.reserve(options.callable_options.fetch_size());
  if (options.use_function_convention) {
    std::vector<TensorAndDevice> tensors_and_devices;
    for (int i = 0; i < options.callable_options.feed_size(); ++i) {
      // WARNING: feed MUST be a reference, since ArgFeedRewrite and
      // tensors_and_devices hold on to its address.
      const string& feed = options.callable_options.feed(i);
      const DeviceAttributes* device_info;
      TF_RETURN_IF_ERROR(LookupDevice(*device_set_, feed,
                                      options.callable_options.feed_devices(),
                                      &device_info));
      feed_rewrites.emplace_back(
          new subgraph::ArgFeedRewrite(&feed, device_info, i));
      tensors_and_devices.push_back({ParseTensorName(feed), device_info});
    }
    if (!options.callable_options.fetch_devices().empty() &&
        !options.callable_options.fetch_skip_sync()) {
      return errors::Unimplemented(
          "CallableOptions.fetch_skip_sync = false is not yet implemented. You "
          "can set it to true instead, but MUST ensure that Device::Sync() is "
          "invoked on the Device corresponding to the fetched tensor before "
          "dereferencing the Tensor's memory.");
    }
    for (int i = 0; i < options.callable_options.fetch_size(); ++i) {
      // WARNING: fetch MUST be a reference, since RetvalFetchRewrite and
      // tensors_and_devices hold on to its address.
      const string& fetch = options.callable_options.fetch(i);
      const DeviceAttributes* device_info;
      TF_RETURN_IF_ERROR(LookupDevice(*device_set_, fetch,
                                      options.callable_options.fetch_devices(),
                                      &device_info));
      fetch_rewrites.emplace_back(
          new subgraph::RetvalFetchRewrite(&fetch, device_info, i));
      tensors_and_devices.push_back({ParseTensorName(fetch), device_info});
    }
    TF_RETURN_IF_ERROR(
        ValidateFeedAndFetchDevices(*graph, tensors_and_devices));
  } else {
    if (!options.callable_options.feed_devices().empty() ||
        !options.callable_options.fetch_devices().empty()) {
      return errors::Unimplemented(
          "CallableOptions::feed_devices and CallableOptions::fetch_devices "
          "to configure feeding/fetching tensors to/from device memory is not "
          "yet supported when using a remote session.");
    }
    const DeviceAttributes* device_info =
        &device_set_->client_device()->attributes();
    for (const string& feed : options.callable_options.feed()) {
      feed_rewrites.emplace_back(
          new subgraph::RecvFeedRewrite(&feed, device_info));
    }
    for (const string& fetch : options.callable_options.fetch()) {
      fetch_rewrites.emplace_back(
          new subgraph::SendFetchRewrite(&fetch, device_info));
    }
  }

  for (const TensorConnection& tensor_connection :
       options.callable_options.tensor_connection()) {
    Node* from_node = nullptr;
    TensorId from_id(ParseTensorName(tensor_connection.from_tensor()));

    for (Node* n : graph->nodes()) {
      if (n->name() == from_id.first) {
        from_node = n;
        break;
      }
    }
    if (from_node == nullptr) {
      return errors::InvalidArgument(
          "Requested tensor connection from unknown node: \"",
          tensor_connection.to_tensor(), "\".");
    }
    if (from_id.second >= from_node->num_outputs()) {
      return errors::InvalidArgument(
          "Requested tensor connection from unknown edge: \"",
          tensor_connection.to_tensor(),
          "\" (actual number of outputs = ", from_node->num_outputs(), ").");
    }

    feed_rewrites.emplace_back(new TensorConnectionPruneRewrite(
        &tensor_connection.to_tensor(), {from_node, from_id.second}));
  }

  std::vector<string> target_node_names(
      options.callable_options.target().begin(),
      options.callable_options.target().end());
  TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
      graph, feed_rewrites, fetch_rewrites, target_node_names,
      out_rewrite_metadata));

  CHECK_EQ(out_rewrite_metadata->feed_types.size(),
           options.callable_options.feed_size() +
               options.callable_options.tensor_connection_size());
  for (int i = 0; i < options.callable_options.tensor_connection_size(); ++i) {
    out_rewrite_metadata->feed_types.pop_back();
  }
  return OkStatus();
}

Status GraphExecutionState::InitBaseGraph(std::unique_ptr<Graph>&& new_graph) {
  // Restore any previously saved stateful placements before placing.
  RestoreStatefulNodes(new_graph.get());

  GraphOptimizationPassOptions optimization_options;
  optimization_options.session_handle = session_handle_;
  optimization_options.session_options = session_options_;
  optimization_options.graph = &new_graph;
  optimization_options.flib_def = flib_def_.get();
  optimization_options.device_set = device_set_;

  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));

  Placer placer(new_graph.get(), "", flib_def_.get(), device_set_,
                /* default_local_device= */ nullptr,
                session_options_ == nullptr ||
                    session_options_->config.allow_soft_placement(),
                session_options_ != nullptr &&
                    session_options_->config.log_device_placement());
  // TODO(mrry): Consider making the Placer cancellable.
  TF_RETURN_IF_ERROR(placer.Run());

  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_PLACEMENT, optimization_options));

  for (const Node* n : new_graph->nodes()) {
    VLOG(2) << "Mapping " << n->name() << " to " << n->cost_id();
    node_name_to_cost_id_map_[n->name()] = n->cost_id();
  }

  SaveStatefulNodes(new_graph.get());
  graph_ = new_graph.release();
  return OkStatus();
}

Status GraphExecutionState::OptimizeGraph(
    const BuildGraphOptions& options, const Graph& graph,
    const FunctionLibraryDefinition* flib_def,
    std::unique_ptr<Graph>* optimized_graph,
    std::unique_ptr<FunctionLibraryDefinition>* optimized_flib) {
#ifdef IS_MOBILE_PLATFORM
  return errors::InvalidArgument("Mobile platforms not supported");
#else
  if (session_options_->config.graph_options().place_pruned_graph()) {
    return errors::InvalidArgument("Can't optimize a pruned graph");
  }

  if (grappler::MetaOptimizerEnabled(session_options_->config)) {
    // Here we build the GrapplerItem before calling the optimizer.
    grappler::GrapplerItem item;
    item.id = "tf_graph";

    // Add devices to the GrapplerItem
    // It's ok to skip invalid device annotations in Grappler.
    for (const Device* d : device_set_->devices()) {
      Status added_device = item.AddDevice(d->name());
      if (!added_device.ok()) VLOG(3) << added_device.error_message();
    }
    VLOG(3) << "Grappler available devices: "
            << absl::StrJoin(item.devices(), ", ");

    // Add fetches to the GrapplerItem.
    item.fetch.insert(item.fetch.end(),
                      options.callable_options.fetch().begin(),
                      options.callable_options.fetch().end());
    item.fetch.insert(item.fetch.end(),
                      options.callable_options.target().begin(),
                      options.callable_options.target().end());

    for (const TensorConnection& tensor_connection :
         options.callable_options.tensor_connection()) {
      item.fetch.push_back(tensor_connection.from_tensor());
    }

    // Add feeds to the GrapplerItem if we know them.
    absl::flat_hash_set<absl::string_view> node_names;
    if (!(options.callable_options.feed().empty() &&
          options.callable_options.tensor_connection().empty())) {
      std::vector<SafeTensorId> feeds;

      for (const string& feed : options.callable_options.feed()) {
        feeds.emplace_back(ParseTensorName(feed));
      }
      for (const TensorConnection& tensor_connection :
           options.callable_options.tensor_connection()) {
        feeds.emplace_back(ParseTensorName(tensor_connection.to_tensor()));
      }

      // For feeds with tensor index 0 we try to find the corresponding node in
      // the graph to infer feed data type and shape.
      absl::flat_hash_set<absl::string_view> feed_nodes;

      // For feeds with tensor index larger than 0, we can't infer data type or
      // shape from the graph. Currently we only support type and shape
      // inference from a small set of node types: Placeholder, Const, etc...
      for (const SafeTensorId& feed : feeds) {
        if (feed.index() > 0) {
          VLOG(3) << "Add undefined feed for: " << feed.ToString();
          Tensor fake_input(DT_INVALID, {0});
          item.feed.emplace_back(feed.ToString(), fake_input);
        } else {
          VLOG(3) << "Add node for feed inference: " << feed.ToString();
          feed_nodes.insert(feed.node());
          continue;
        }
      }

      // For feeds with tensor index == 0 we try to infer data type and tensor
      // shape from the graph, by looking at the fed node attributes.
      node_names.reserve(graph.num_nodes());
      for (const Node* node : graph.nodes()) {
        node_names.insert(node->name());
        if (feed_nodes.find(node->name()) == feed_nodes.end()) continue;

        // Try to get the type and shape of the feed node.
        PartialTensorShape partial_shape;
        DataType type;
        Status st = GetFeedShapeAndTypeFromAttribute(node->def(),
                                                     &partial_shape, &type);

        // Failed to get type and shape of the feed node.
        if (!st.ok()) {
          VLOG(3) << "Failed to infer feed node type and shape."
                  << " Add undefined feed for: " << node->name();
          Tensor fake_input(DT_INVALID, {0});
          item.feed.emplace_back(node->name(), fake_input);
          continue;
        }

        // If the shape of the placeholder is only partially known, we are free
        // to set unknown dimensions of its shape to any value we desire. We
        // choose 0 to minimize the memory impact. Note that this only matters
        // if an optimizer chooses to run the graph.
        TensorShape shape;
        if (partial_shape.unknown_rank()) {
          shape = TensorShape({0});
        } else {
          for (int i = 0; i < partial_shape.dims(); ++i) {
            if (partial_shape.dim_size(i) < 0) {
              partial_shape.set_dim(i, 0);
            }
          }
          if (!partial_shape.AsTensorShape(&shape)) {
            return errors::InvalidArgument(
                "Could not derive shape for feed node: ",
                node->def().DebugString());
          }
        }

        VLOG(3) << "Add feed for: " << node->name() << "; type: " << type
                << "; shape: " << shape;
        Tensor fake_input(type, shape);
        item.feed.emplace_back(node->name(), fake_input);
      }
    }

    // Validate that the feeds and fetches are valid.
    if (node_names.empty()) {
      // Collect all node names in the graph if we didn't already.
      node_names.reserve(graph.num_nodes());
      for (const Node* node : graph.nodes()) {
        node_names.insert(node->name());
      }
    }
    for (const auto& feed : item.feed) {
      SafeTensorId tensor_id = ParseTensorName(feed.first);
      if (node_names.find(tensor_id.node()) == node_names.end()) {
        return errors::InvalidArgument("Invalid feed, no such node in graph: ",
                                       feed.first);
      }
    }
    for (const auto& fetch : item.fetch) {
      SafeTensorId tensor_id = ParseTensorName(fetch);
      if (node_names.find(tensor_id.node()) == node_names.end()) {
        return errors::InvalidArgument("Invalid fetch, no such node in graph: ",
                                       fetch);
      }
    }

    // Convert Graph to GraphDef and add it to the GrapplerItem.
    graph.ToGraphDef(&item.graph);
    // TODO(b/114748242): Add a unit test to test this bug fix.
    if (flib_def) {
      *item.graph.mutable_library() = flib_def->ToProto();
    }

    // Construct a virtual cluster and find the cpu_device, which the
    // ConstantFolding optimizer will use for partial evaluation of the graph.
    grappler::VirtualCluster cluster(device_set_);
    Device* cpu_device = nullptr;
    for (const auto& device : device_set_->devices()) {
      if (device->parsed_name().id == 0 &&
          StringPiece(device->parsed_name().type) == "CPU" &&
          device->GetAllocator(AllocatorAttributes()) != nullptr) {
        cpu_device = device;
      }
    }

    // Now we can run the MetaOptimizer on the constructed GrapplerItem.
    GraphDef new_graph;
    TF_RETURN_IF_ERROR(
        grappler::RunMetaOptimizer(std::move(item), session_options_->config,
                                   cpu_device, &cluster, &new_graph));

    // Merge the optimized graph's function library with the original library.
    // The optimized graph might have new functions specialized for its
    // instantiation context (see the Grappler function optimizer), and
    // modified function bodies for the existing functions.
    optimized_flib->reset(new FunctionLibraryDefinition(*flib_def));

    for (const FunctionDef& fdef : new_graph.library().function()) {
      const string& func_name = fdef.signature().name();

      if ((*optimized_flib)->Contains(func_name)) {
        VLOG(3) << "Replace function: name=" << func_name;
        TF_RETURN_IF_ERROR((*optimized_flib)->ReplaceFunction(func_name, fdef));
      } else {
        VLOG(3) << "Add new function: name=" << func_name;
        TF_RETURN_IF_ERROR((*optimized_flib)->AddFunctionDef(fdef));
      }
    }
    optimized_graph->reset(new Graph(OpRegistry::Global()));

    // Convert the optimized GraphDef back to a Graph.
    GraphConstructorOptions opts;
    opts.allow_internal_ops = true;
    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, std::move(new_graph),
                                              optimized_graph->get()));
    // The graph conversion sets the requested device names but not the
    // assigned device names. However, since at this point the graph is placed,
    // TF expects an assigned device name for every node. Therefore we copy
    // the requested device into the assigned device field.
    for (Node* node : optimized_graph->get()->nodes()) {
      node->set_assigned_device_name(node->requested_device());
    }
    return OkStatus();
  } else {
    return errors::InvalidArgument("Meta Optimizer disabled");
  }
#endif  // IS_MOBILE_PLATFORM
}

Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
                                       std::unique_ptr<ClientGraph>* out) {
  VLOG(1) << "BuildGraph";
  const uint64 start_time_usecs = Env::Default()->NowMicros();
  if (!graph_) {
    // It is only valid to call this method directly when the original graph
    // was created with the option `place_pruned_graph == false`.
    return errors::Internal(
        "Attempted to prune a graph that has not been fully initialized.");
  }

  // Grappler optimization might change the structure of a graph itself, and
  // also it can add/prune functions to/from the library.
  std::unique_ptr<Graph> optimized_graph;
  std::unique_ptr<FunctionLibraryDefinition> optimized_flib;

  Status s = OptimizeGraph(options, *graph_, flib_def_.get(), &optimized_graph,
                           &optimized_flib);
  if (!s.ok()) {
    VLOG(2) << "Grappler optimization failed. Error: " << s.error_message();
    // Simply copy the original graph and the function library if we couldn't
    // optimize it.
    optimized_graph.reset(new Graph(flib_def_.get()));
    CopyGraph(*graph_, optimized_graph.get());
    optimized_flib.reset(new FunctionLibraryDefinition(*flib_def_));
  }

  subgraph::RewriteGraphMetadata rewrite_metadata;
  if (session_options_ == nullptr ||
      !session_options_->config.graph_options().place_pruned_graph()) {
    TF_RETURN_IF_ERROR(
        PruneGraph(options, optimized_graph.get(), &rewrite_metadata));
  } else {
    // This GraphExecutionState represents a graph that was
    // pruned when this was constructed, so we copy the metadata from
    // a member variable.
    CHECK(rewrite_metadata_);
    rewrite_metadata = *rewrite_metadata_;
  }

  CHECK_EQ(options.callable_options.feed_size(),
           rewrite_metadata.feed_types.size());
  CHECK_EQ(options.callable_options.fetch_size(),
           rewrite_metadata.fetch_types.size());

  // TODO(andydavis): Clarify optimization pass requirements around CostModel.
  GraphOptimizationPassOptions optimization_options;
  optimization_options.session_options = session_options_;
  optimization_options.graph = &optimized_graph;
  optimization_options.flib_def = optimized_flib.get();
  optimization_options.device_set = device_set_;

  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));

  int64_t collective_graph_key = options.collective_graph_key;
  if (collective_graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
    // BuildGraphOptions does not specify a collective_graph_key.  Check all
    // nodes in the Graph and FunctionLibraryDefinition for collective ops and
    // if found, initialize a collective_graph_key as a hash of the ordered set
    // of instance keys.
    std::set<int32> instance_key_set;
    bool has_collective_v2 = false;
    for (Node* node : optimized_graph->nodes()) {
      if (node->IsCollective()) {
        int32_t instance_key;
        TF_RETURN_IF_ERROR(
            GetNodeAttr(node->attrs(), "instance_key", &instance_key));
        instance_key_set.emplace(instance_key);
      } else if (IsCollectiveV2(node->type_string())) {
        has_collective_v2 = true;
      } else {
        const FunctionDef* fdef = optimized_flib->Find(node->def().op());
        if (fdef != nullptr) {
          for (const NodeDef& ndef : fdef->node_def()) {
            if (ndef.op() == "CollectiveReduce" ||
                ndef.op() == "CollectiveBcastSend" ||
                ndef.op() == "CollectiveBcastRecv" ||
                ndef.op() == "CollectiveGather") {
              int32_t instance_key;
              TF_RETURN_IF_ERROR(
                  GetNodeAttr(ndef, "instance_key", &instance_key));
              instance_key_set.emplace(instance_key);
            } else if (IsCollectiveV2(ndef.op())) {
              has_collective_v2 = true;
            }
          }
        }
      }
    }
    if (!instance_key_set.empty()) {
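      // For example, instance keys {3, 7} (a std::set iterates in ascending
      // order) produce Hash64Combine(7, Hash64Combine(3, seed)) with the seed
      // below, so the key depends only on the set of instance keys, not on
      // graph traversal order.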
      uint64 hash = 0x8774aa605c729c72ULL;
      for (int32_t instance_key : instance_key_set) {
        hash = Hash64Combine(instance_key, hash);
      }
      collective_graph_key = hash;
    } else if (has_collective_v2) {
      collective_graph_key = 0x8774aa605c729c72ULL;
    }
  }

  // Make collective execution order deterministic if needed.
  if (options.collective_order != GraphCollectiveOrder::kNone) {
    TF_RETURN_IF_ERROR(
        OrderCollectives(optimized_graph.get(), options.collective_order));
  }

  // Copy the extracted graph in order to make its node ids dense,
  // since the local CostModel used to record its stats is sized by
  // the largest node id.
  std::unique_ptr<ClientGraph> dense_copy(
      new ClientGraph(std::move(optimized_flib), rewrite_metadata.feed_types,
                      rewrite_metadata.fetch_types, collective_graph_key));
  CopyGraph(*optimized_graph, &dense_copy->graph);

  // TODO(vrv): We should check invariants of the graph here.
  metrics::UpdateGraphBuildTime(Env::Default()->NowMicros() - start_time_usecs);
  *out = std::move(dense_copy);
  return OkStatus();
}
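
// Illustrative end-to-end flow (caller-side sketch; the BuildGraphOptions
// setup shown is assumed, and "output:0" is a hypothetical tensor name): the
// typical sequence is MakeForBaseGraph() once per session graph, then
// BuildGraph() per requested subgraph, e.g.
//
//   BuildGraphOptions bg_options;
//   bg_options.callable_options.add_fetch("output:0");
//   std::unique_ptr<ClientGraph> client_graph;
//   TF_RETURN_IF_ERROR(state->BuildGraph(bg_options, &client_graph));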

}  // namespace tensorflow