1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/graph_execution_state.h"
17 
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <unordered_set>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/memory/memory.h"
27 #include "absl/strings/str_join.h"
28 #include "tensorflow/core/common_runtime/device.h"
29 #include "tensorflow/core/common_runtime/graph_constructor.h"
30 #include "tensorflow/core/common_runtime/metrics.h"
31 #include "tensorflow/core/common_runtime/optimization_registry.h"
32 #include "tensorflow/core/common_runtime/placer.h"
33 #include "tensorflow/core/framework/attr_value.pb.h"
34 #include "tensorflow/core/framework/function.h"
35 #include "tensorflow/core/framework/function.pb.h"
36 #include "tensorflow/core/framework/graph.pb.h"
37 #include "tensorflow/core/framework/graph_def_util.h"
38 #include "tensorflow/core/framework/node_def.pb.h"
39 #include "tensorflow/core/framework/op.h"
40 #include "tensorflow/core/framework/tensor.pb.h"
41 #include "tensorflow/core/framework/versions.pb.h"
42 #include "tensorflow/core/graph/algorithm.h"
43 #include "tensorflow/core/graph/collective_order.h"
44 #include "tensorflow/core/graph/graph.h"
45 #include "tensorflow/core/graph/subgraph.h"
46 #include "tensorflow/core/graph/tensor_id.h"
47 #include "tensorflow/core/graph/validate.h"
48 #include "tensorflow/core/lib/core/errors.h"
49 #include "tensorflow/core/lib/core/status.h"
50 #include "tensorflow/core/lib/gtl/flatset.h"
51 #include "tensorflow/core/lib/strings/stringprintf.h"
52 #include "tensorflow/core/platform/logging.h"
53 #include "tensorflow/core/platform/types.h"
54 #include "tensorflow/core/util/device_name_utils.h"
55 #include "tensorflow/core/util/util.h"
56 
57 #ifndef IS_MOBILE_PLATFORM
58 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
59 #include "tensorflow/core/grappler/grappler_item.h"
60 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
61 #endif  // IS_MOBILE_PLATFORM
62 
63 namespace tensorflow {
64 
65 namespace {
66 bool IsCollectiveV2(const string& op) {
67   return op == "CollectiveReduceV2" || op == "CollectiveGatherV2" ||
68          op == "CollectiveBcastRecvV2" || op == "CollectiveBcastSendV2";
69 }
70 }  // namespace
71 
72 GraphExecutionState::GraphExecutionState(
73     std::unique_ptr<GraphDef>&& graph_def,
74     std::unique_ptr<FunctionLibraryDefinition>&& flib_def,
75     const GraphExecutionStateOptions& options)
76     : stateful_placements_(options.stateful_placements),
77       original_graph_def_(std::move(graph_def)),
78       device_set_(options.device_set),
79       session_options_(options.session_options),
80       session_handle_(options.session_handle),
81       flib_def_(std::move(flib_def)),
82       graph_(nullptr) {}
83 
84 GraphExecutionState::~GraphExecutionState() {
85   node_name_to_cost_id_map_.clear();
86   delete graph_;
87 }
88 
89 /* static */ Status GraphExecutionState::MakeForBaseGraph(
90     GraphDef&& graph_def, const GraphExecutionStateOptions& options,
91     std::unique_ptr<GraphExecutionState>* out_state) {
92 #ifndef __ANDROID__
93   VLOG(4) << "Graph proto is \n" << graph_def.DebugString();
94 #endif  // __ANDROID__
95 
96   auto flib_def = absl::make_unique<FunctionLibraryDefinition>(
97       OpRegistry::Global(), graph_def.library());
98 
99   TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&graph_def, *flib_def, 0));
100 
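  // The original GraphDef is kept only when the session may later prune or
  // extend the graph; otherwise it can be dropped once the Graph is built.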
101   if (options.session_options->config.graph_options().place_pruned_graph() ||
102       !options.session_options->config.experimental()
103            .optimize_for_static_graph()) {
104     auto ret = absl::WrapUnique(new GraphExecutionState(
105         absl::make_unique<GraphDef>(std::move(graph_def)), std::move(flib_def),
106         options));
107 
108     // When place_pruned_graph is true, a different Graph* will be initialized
109     // each time we prune the original graph, so there is no need to
110     // construct a Graph* in this case.
111     if (!options.session_options->config.graph_options().place_pruned_graph()) {
112       auto base_graph = absl::make_unique<Graph>(OpRegistry::Global());
113       TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *ret->original_graph_def_,
114                                                 base_graph.get()));
115       TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph)));
116     }
117     *out_state = std::move(ret);
118   } else {
119     auto ret = absl::WrapUnique(
120         new GraphExecutionState(nullptr, std::move(flib_def), options));
121     auto base_graph = absl::make_unique<Graph>(OpRegistry::Global());
122     TF_RETURN_IF_ERROR(
123         ConvertGraphDefToGraph({}, std::move(graph_def), base_graph.get()));
124     TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph)));
125     *out_state = std::move(ret);
126   }
127   return Status::OK();
128 }
129 
130 /* static */ Status GraphExecutionState::MakeForPrunedGraph(
131     const GraphExecutionState& base_execution_state,
132     const GraphExecutionStateOptions& options,
133     const BuildGraphOptions& subgraph_options,
134     std::unique_ptr<GraphExecutionState>* out_state,
135     std::unique_ptr<ClientGraph>* out_client_graph) {
136   if (!(base_execution_state.session_options_->config.graph_options()
137             .place_pruned_graph() &&
138         options.session_options->config.graph_options().place_pruned_graph())) {
139     return errors::Internal(
140         "MakeForPrunedGraph is only supported when the `place_pruned_graph` "
141         "option is true.");
142   }
143   if (!base_execution_state.original_graph_def_) {
144     // NOTE(mrry): By adding this restriction, which matches the only current
145     // usage of this (fairly obscure) method, we do not need to store a
146     // redundant copy of the original graph in `*out_state`.
147     return errors::Internal(
148         "MakeForPrunedGraph is only supported when `base_execution_state` is "
149         "the Session-level `GraphExecutionState`.");
150   }
151 
152   // NOTE(mrry): This makes a copy of `graph_def`, which is
153   // regrettable. We could make `GraphDef` objects sharable between
154   // execution states to optimize pruned graph execution, but since
155   // this case is primarily used for interactive sessions, we make the
156   // bet that graph construction is not performance-critical. (Note
157   // also that the previous version used `Extend()`, which is strictly
158   // more expensive than copying a `GraphDef`.)
159   GraphDef temp(*base_execution_state.original_graph_def_);
160   auto flib_def = absl::make_unique<FunctionLibraryDefinition>(
161       OpRegistry::Global(), temp.library());
162   TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&temp, *flib_def, 0));
163   auto ret = absl::WrapUnique(
164       new GraphExecutionState(nullptr, std::move(flib_def), options));
165 
166   auto base_graph = absl::make_unique<Graph>(OpRegistry::Global());
167   TF_RETURN_IF_ERROR(
168       ConvertGraphDefToGraph({}, std::move(temp), base_graph.get()));
169 
170   // Rewrite the graph before placement.
171   ret->rewrite_metadata_.reset(new subgraph::RewriteGraphMetadata);
172   TF_RETURN_IF_ERROR(ret->PruneGraph(subgraph_options, base_graph.get(),
173                                      ret->rewrite_metadata_.get()));
174   TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph)));
175   TF_RETURN_IF_ERROR(ret->BuildGraph(subgraph_options, out_client_graph));
176   *out_state = std::move(ret);
177   return Status::OK();
178 }
179 
180 Status GraphExecutionState::Extend(
181     const GraphDef& extension_def,
182     std::unique_ptr<GraphExecutionState>* out) const {
183   if (session_options_->config.experimental().optimize_for_static_graph()) {
184     return errors::FailedPrecondition(
185         "Extending the graph is not supported when "
186         "`optimize_for_static_graph` is true.");
187   }
188 
189   GraphDef gdef;
190 
191   // 1. Copy the function library.
192   TF_RETURN_IF_ERROR(flib_def_->AddLibrary(extension_def.library()));
193   *gdef.mutable_library() = flib_def_->ToProto();
194 
195   // 2. Build an index of the new node names.
196   std::unordered_set<string> new_names;
197   for (const NodeDef& node : extension_def.node()) {
198     new_names.insert(node.name());
199   }
200 
201   // 3. Add the non-duplicates from the old graph to the new graph.
202   //    Return an error if the same node name appears in both the
203   //    old graph and the extension.
204   for (const NodeDef& node : original_graph_def_->node()) {
205     if (new_names.count(node.name()) == 0) {
206       *gdef.add_node() = node;
207     } else {
208       return errors::InvalidArgument(
209           "GraphDef argument to Extend includes node '", node.name(),
210           "', which was created by a previous call to Create or Extend in this "
211           "session.");
212     }
213   }
214 
215   // 4. Merge the nodes and versions from the extension.
216   int old_node_size = gdef.node_size();
217   gdef.mutable_node()->MergeFrom(extension_def.node());
218   TF_RETURN_IF_ERROR(
219       AddDefaultAttrsToGraphDef(&gdef, *flib_def_, old_node_size));
220   // Merge versions
221   if (gdef.has_versions()) {
222     if (gdef.versions().producer() != extension_def.versions().producer()) {
223       return errors::InvalidArgument(
224           "Can't extend GraphDef at version ", gdef.versions().producer(),
225           " with graph at version ", extension_def.versions().producer());
226     }
227     VersionDef* versions = gdef.mutable_versions();
228     versions->set_min_consumer(std::max(
229         versions->min_consumer(), extension_def.versions().min_consumer()));
230     if (extension_def.versions().bad_consumers_size()) {
231       // Add new bad_consumers that aren't already marked bad.
232       //
233       // Note: This implementation is quadratic time if there are many calls to
234       // ExtendLocked with many bad consumers.  Since this is unlikely, and
235       // fixing it would require data structures outside of this routine,
236       // quadratic time it is.
237       auto* bad_consumers = versions->mutable_bad_consumers();
238       const std::unordered_set<int> existing(bad_consumers->begin(),
239                                              bad_consumers->end());
240       for (const int v : extension_def.versions().bad_consumers()) {
241         if (existing.find(v) == existing.end()) {
242           bad_consumers->Add(v);
243         }
244       }
245     }
246 
247   } else {
248     gdef.mutable_versions()->CopyFrom(extension_def.versions());
249   }
250 
251   // 5. Validate that the final graphdef is valid.
252   if (gdef.versions().producer() >= 5) {
253     // Validate the graph: we assume that merging two valid graphs
254     // should maintain graph validity.
255     TF_RETURN_IF_ERROR(graph::ValidateGraphDef(gdef, *flib_def_));
256   }
257 
258   // 6. Add the extension.
259   GraphExecutionStateOptions combined_options;
260   combined_options.device_set = device_set_;
261   combined_options.session_options = session_options_;
262   combined_options.session_handle = session_handle_;
263   combined_options.stateful_placements = stateful_placements_;
264 
265   TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&gdef, *flib_def_, 0));
266   auto flib_def = absl::make_unique<FunctionLibraryDefinition>(
267       OpRegistry::Global(), gdef.library());
268   auto new_execution_state = absl::WrapUnique(
269       new GraphExecutionState(absl::make_unique<GraphDef>(std::move(gdef)),
270                               std::move(flib_def), combined_options));
271 
272   if (!session_options_->config.graph_options().place_pruned_graph()) {
273     auto base_graph = absl::make_unique<Graph>(OpRegistry::Global());
274     TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
275         {}, *new_execution_state->original_graph_def_, base_graph.get()));
276     TF_RETURN_IF_ERROR(
277         new_execution_state->InitBaseGraph(std::move(base_graph)));
278   }
279   *out = std::move(new_execution_state);
280 
281   // NOTE(mrry): Extend() is likely to be used for non-throughput-sensitive
282   // interactive workloads, but in future we may want to transfer other
283   // parts of the placement and/or cost model.
284   return Status::OK();
285 }
286 
287 void GraphExecutionState::SaveStatefulNodes(Graph* graph) {
288   for (Node* n : graph->nodes()) {
289     if (n->op_def().is_stateful()) {
290       VLOG(2) << "Saving " << n->DebugString();
291       stateful_placements_[n->name()] = n->assigned_device_name();
292     }
293   }
294 }
295 
296 void GraphExecutionState::RestoreStatefulNodes(Graph* graph) {
297   for (Node* n : graph->nodes()) {
298     if (n->op_def().is_stateful()) {
299       auto iter = stateful_placements_.find(n->name());
300       if (iter != stateful_placements_.end()) {
301         n->set_assigned_device_name(iter->second);
302         VLOG(2) << "Restored " << n->DebugString();
303       }
304     }
305   }
306 }
307 
308 namespace {
309 
310 class TensorConnectionPruneRewrite : public subgraph::PruneRewrite {
311  public:
312   TensorConnectionPruneRewrite(const string* endpoint_name,
313                                NodeBuilder::NodeOut from_tensor)
314       : subgraph::PruneRewrite(endpoint_name, nullptr /* device_info */),
315         from_tensor_(std::move(from_tensor)) {}
316 
317   Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
318                  Node** out_node) override {
319     Status s;
320     auto check_no_cycle_fn = [this, feed_tensor, &s](Node* n) {
321       if (n == feed_tensor.node) {
322         s.Update(errors::InvalidArgument(
323             "Requested Tensor connection between nodes \"",
324             feed_tensor.node->name(), "\" and \"", from_tensor_.node->name(),
325             "\" would create a cycle."));
326       }
327     };
328     ReverseDFSFrom(*g, {from_tensor_.node}, std::move(check_no_cycle_fn),
329                    nullptr);
330     TF_RETURN_IF_ERROR(s);
331 
332     TF_RETURN_IF_ERROR(
333         NodeBuilder(strings::StrCat("_identity_", feed_tensor.node->name(), "_",
334                                     feed_tensor.index),
335                     "Identity")
336             .Input(from_tensor_)
337             .Attr("T",
338                   BaseType(from_tensor_.node->output_type(from_tensor_.index)))
339             .Finalize(g, out_node));
340 
341     (*out_node)->set_assigned_device_name(
342         feed_tensor.node->assigned_device_name());
343     return Status::OK();
344   }
345 
346  private:
347   NodeBuilder::NodeOut from_tensor_;
348 };
349 
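// Looks up the device requested for `tensor_name` in `tensor2device` and
// returns its attributes, falling back to the client device when no device
// was specified for that tensor.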
350 template <class Map>
351 Status LookupDevice(const DeviceSet& device_set, const string& tensor_name,
352                     const Map& tensor2device,
353                     const tensorflow::DeviceAttributes** out_device_attrs) {
354   *out_device_attrs = nullptr;
355   if (tensor2device.empty()) {
356     *out_device_attrs = &device_set.client_device()->attributes();
357     return Status::OK();
358   }
359   const auto it = tensor2device.find(tensor_name);
360   if (it == tensor2device.end()) {
361     *out_device_attrs = &device_set.client_device()->attributes();
362     return Status::OK();
363   }
364   DeviceNameUtils::ParsedName parsed_name;
365   if (!DeviceNameUtils::ParseFullName(it->second, &parsed_name)) {
366     return errors::InvalidArgument("Invalid device name ('", it->second,
367                                    "') provided for the tensor '", tensor_name,
368                                    "' in CallableOptions");
369   }
370   Device* device = device_set.FindDeviceByName(
371       DeviceNameUtils::ParsedNameToString(parsed_name));
372   if (device == nullptr) {
373     return errors::InvalidArgument("Device '", it->second,
374                                    "' specified for tensor '", tensor_name,
375                                    "' in CallableOptions does not exist");
376   }
377   *out_device_attrs = &device->attributes();
378   return Status::OK();
379 }
380 
381 struct TensorAndDevice {
382   // WARNING: backing memory for the 'tensor' field is NOT owned.
383   const TensorId tensor;
384   // WARNING: device pointer is not owned, so it must outlive TensorAndDevice.
385   const DeviceAttributes* device;
386 };
387 
388 // Tensors of some DataTypes cannot be placed in device memory as feeds or
389 // fetches. Validate against an allowlist of those known to work.
390 bool IsFeedAndFetchSupported(DataType dtype, const string& device_type) {
391   // The mechanism for supporting feeds of device-backed Tensors requires
392   // the _Arg kernel to be registered for the corresponding type (and that
393   // the input to the kernel be in device and not host memory).
394   //
395   // The mechanism for supporting fetches of device-backed Tensors requires
396   // the _Retval kernel to be registered for the corresponding type (and
397   // that the output is produced in device and not host memory).
398   //
399   // For now, we return true iff there are _Arg AND _Retval kernels for dtype on
400   // the device. False negatives are okay, false positives would be bad.
401   //
402   // TODO(ashankar): Instead of an allowlist here, perhaps we could query
403   // the kernel registry for _Arg and _Retval kernels.
404   if (device_type == DEVICE_CPU) return true;
405   if (device_type != DEVICE_GPU) return false;
406   switch (dtype) {
407     case DT_BFLOAT16:
408     case DT_BOOL:
409     case DT_COMPLEX128:
410     case DT_COMPLEX64:
411     case DT_DOUBLE:
412     case DT_FLOAT:
413     case DT_HALF:
414     case DT_INT16:
415     case DT_INT64:
416     case DT_INT8:
417     case DT_UINT16:
418     case DT_UINT8:
419       return true;
420     default:
421       return false;
422   }
423 }
424 
425 Status ValidateFeedAndFetchDevices(
426     const Graph& graph,
427     const std::vector<TensorAndDevice>& tensors_and_devices) {
428   if (tensors_and_devices.empty()) return Status::OK();
429   std::vector<bool> found(tensors_and_devices.size(), false);
430   for (const Node* node : graph.nodes()) {
431     // Linearly looping through all nodes and then all feed+fetch tensors isn't
432     // quite efficient. At the time of this writing, the expectation was that
433     // tensors_and_devices.size() is really small in practice, so this won't be
434     // problematic.
435     // Revisit and make a more efficient lookup possible if needed (e.g., perhaps
436     // Graph can maintain a map from node name to Node*).
437     for (int i = 0; i < tensors_and_devices.size(); ++i) {
438       const TensorAndDevice& td = tensors_and_devices[i];
439       if (td.tensor.first != node->name()) continue;
440       found[i] = true;
441       TF_RETURN_IF_ERROR(graph.IsValidOutputTensor(node, td.tensor.second));
442       const DataType dtype = node->output_type(td.tensor.second);
443       if (!IsFeedAndFetchSupported(dtype, td.device->device_type())) {
444         return errors::Unimplemented(
445             "Cannot feed or fetch tensor '", td.tensor.ToString(),
446             "' from device ", td.device->name(), " as feeding/fetching from ",
447             td.device->device_type(), " devices is not yet supported for ",
448             DataTypeString(dtype), " tensors");
449       }
450     }
451   }
452   for (int i = 0; i < found.size(); ++i) {
453     if (!found[i]) {
454       return errors::InvalidArgument(
455           "Tensor ", tensors_and_devices[i].tensor.ToString(),
456           ", specified in either feed_devices or fetch_devices was not found "
457           "in the Graph");
458     }
459   }
460   return Status::OK();
461 }
462 
463 Status GetFeedShapeAndTypeFromAttribute(const NodeDef& node,
464                                         PartialTensorShape* shape,
465                                         DataType* type) {
466   static const gtl::FlatSet<string>* const kHasExplicitShapeAttribute =
467       CHECK_NOTNULL((new gtl::FlatSet<string>{
468           "Placeholder", "PlaceholderV2", "PlaceholderWithDefault",
469           "ParallelConcat", "ImmutableConst", "_ParallelConcatStart",
470           "InfeedDequeue", "OutfeedDequeue", "CollectiveBcastSend",
471           "CollectiveBcastRecv", "AccumulateNV2", "VariableV2", "Variable",
472           "TemporaryVariable", "NcclBroadcast", "_ScopedAllocator",
473           "_ScopedAllocatorConcat"}));
474 
475   // All the node types handled here have their output datatype set in
476   // either attribute 'dtype' or 'T'.
477   if (!TryGetNodeAttr(node, "dtype", type) &&
478       !TryGetNodeAttr(node, "T", type)) {
479     return errors::InvalidArgument(
480         "Could not determine output type for feed node: ", node.name(),
481         " of type ", node.op());
482   }
483 
484   // First handle the case of feeding a const node.
485   if (node.op() == "Const" && HasNodeAttr(node, "value")) {
486     *shape =
487         PartialTensorShape(node.attr().at("value").tensor().tensor_shape());
488   } else if (kHasExplicitShapeAttribute->find(node.op()) !=
489              kHasExplicitShapeAttribute->end()) {
490     TF_RETURN_IF_ERROR(GetNodeAttr(node, "shape", shape));
491   } else {
492     return errors::InvalidArgument("Could not determine shape for feed node: ",
493                                    node.name(), " of type ", node.op());
494   }
495   return Status::OK();
496 }
497 
498 }  // namespace
499 
500 Status GraphExecutionState::PruneGraph(
501     const BuildGraphOptions& options, Graph* graph,
502     subgraph::RewriteGraphMetadata* out_rewrite_metadata) {
503   std::vector<std::unique_ptr<subgraph::PruneRewrite>> feed_rewrites;
504   feed_rewrites.reserve(options.callable_options.feed_size());
505   std::vector<std::unique_ptr<subgraph::PruneRewrite>> fetch_rewrites;
506   fetch_rewrites.reserve(options.callable_options.fetch_size());
507   if (options.use_function_convention) {
508     std::vector<TensorAndDevice> tensors_and_devices;
509     for (int i = 0; i < options.callable_options.feed_size(); ++i) {
510       // WARNING: feed MUST be a reference, since ArgFeedRewrite and
511       // tensors_and_devices hold on to its address.
512       const string& feed = options.callable_options.feed(i);
513       const DeviceAttributes* device_info;
514       TF_RETURN_IF_ERROR(LookupDevice(*device_set_, feed,
515                                       options.callable_options.feed_devices(),
516                                       &device_info));
517       feed_rewrites.emplace_back(
518           new subgraph::ArgFeedRewrite(&feed, device_info, i));
519       tensors_and_devices.push_back({ParseTensorName(feed), device_info});
520     }
521     if (!options.callable_options.fetch_devices().empty() &&
522         !options.callable_options.fetch_skip_sync()) {
523       return errors::Unimplemented(
524           "CallableOptions.fetch_skip_sync = false is not yet implemented. You "
525           "can set it to true instead, but MUST ensure that Device::Sync() is "
526           "invoked on the Device corresponding to the fetched tensor before "
527           "dereferencing the Tensor's memory.");
528     }
529     for (int i = 0; i < options.callable_options.fetch_size(); ++i) {
530       // WARNING: fetch MUST be a reference, since RetvalFetchRewrite and
531       // tensors_and_devices hold on to its address.
532       const string& fetch = options.callable_options.fetch(i);
533       const DeviceAttributes* device_info;
534       TF_RETURN_IF_ERROR(LookupDevice(*device_set_, fetch,
535                                       options.callable_options.fetch_devices(),
536                                       &device_info));
537       fetch_rewrites.emplace_back(
538           new subgraph::RetvalFetchRewrite(&fetch, device_info, i));
539       tensors_and_devices.push_back({ParseTensorName(fetch), device_info});
540     }
541     TF_RETURN_IF_ERROR(
542         ValidateFeedAndFetchDevices(*graph, tensors_and_devices));
543   } else {
544     if (!options.callable_options.feed_devices().empty() ||
545         !options.callable_options.fetch_devices().empty()) {
546       return errors::Unimplemented(
547           "CallableOptions::feed_devices and CallableOptions::fetch_devices "
548           "to configure feeding/fetching tensors to/from device memory is not "
549           "yet supported when using a remote session.");
550     }
551     const DeviceAttributes* device_info =
552         &device_set_->client_device()->attributes();
553     for (const string& feed : options.callable_options.feed()) {
554       feed_rewrites.emplace_back(
555           new subgraph::RecvFeedRewrite(&feed, device_info));
556     }
557     for (const string& fetch : options.callable_options.fetch()) {
558       fetch_rewrites.emplace_back(
559           new subgraph::SendFetchRewrite(&fetch, device_info));
560     }
561   }
562 
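  // Turn each requested tensor connection into a feed rewrite that replaces
  // the "to" tensor with an Identity node driven by the "from" tensor.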
563   for (const TensorConnection& tensor_connection :
564        options.callable_options.tensor_connection()) {
565     Node* from_node = nullptr;
566     TensorId from_id(ParseTensorName(tensor_connection.from_tensor()));
567 
568     for (Node* n : graph->nodes()) {
569       if (n->name() == from_id.first) {
570         from_node = n;
571         break;
572       }
573     }
574     if (from_node == nullptr) {
575       return errors::InvalidArgument(
576           "Requested tensor connection from unknown node: \"",
577           tensor_connection.to_tensor(), "\".");
578     }
579     if (from_id.second >= from_node->num_outputs()) {
580       return errors::InvalidArgument(
581           "Requested tensor connection from unknown edge: \"",
582           tensor_connection.to_tensor(),
583           "\" (actual number of outputs = ", from_node->num_outputs(), ").");
584     }
585 
586     feed_rewrites.emplace_back(new TensorConnectionPruneRewrite(
587         &tensor_connection.to_tensor(), {from_node, from_id.second}));
588   }
589 
590   std::vector<string> target_node_names(
591       options.callable_options.target().begin(),
592       options.callable_options.target().end());
593   TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
594       graph, feed_rewrites, fetch_rewrites, target_node_names,
595       out_rewrite_metadata));
596 
597   CHECK_EQ(out_rewrite_metadata->feed_types.size(),
598            options.callable_options.feed_size() +
599                options.callable_options.tensor_connection_size());
600   for (int i = 0; i < options.callable_options.tensor_connection_size(); ++i) {
601     out_rewrite_metadata->feed_types.pop_back();
602   }
603   return Status::OK();
604 }
605 
606 Status GraphExecutionState::InitBaseGraph(std::unique_ptr<Graph>&& new_graph) {
607   // Restore any previously saved stateful placements before placing.
608   RestoreStatefulNodes(new_graph.get());
609 
610   GraphOptimizationPassOptions optimization_options;
611   optimization_options.session_handle = session_handle_;
612   optimization_options.session_options = session_options_;
613   optimization_options.graph = &new_graph;
614   optimization_options.flib_def = flib_def_.get();
615   optimization_options.device_set = device_set_;
616 
617   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
618       OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
619 
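  // Run the Placer to assign a device to every node that does not already
  // have one, honoring the stateful placements restored above.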
620   Placer placer(new_graph.get(), "", flib_def_.get(), device_set_,
621                 /* default_local_device= */ nullptr,
622                 session_options_ == nullptr ||
623                     session_options_->config.allow_soft_placement(),
624                 session_options_ != nullptr &&
625                     session_options_->config.log_device_placement());
626   // TODO(mrry): Consider making the Placer cancellable.
627   TF_RETURN_IF_ERROR(placer.Run());
628 
629   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
630       OptimizationPassRegistry::POST_PLACEMENT, optimization_options));
631 
632   for (const Node* n : new_graph->nodes()) {
633     VLOG(2) << "Mapping " << n->name() << " to " << n->cost_id();
634     node_name_to_cost_id_map_[n->name()] = n->cost_id();
635   }
636 
637   SaveStatefulNodes(new_graph.get());
638   graph_ = new_graph.release();
639   return Status::OK();
640 }
641 
642 Status GraphExecutionState::OptimizeGraph(
643     const BuildGraphOptions& options, std::unique_ptr<Graph>* optimized_graph,
644     std::unique_ptr<FunctionLibraryDefinition>* optimized_flib) {
645 #ifdef IS_MOBILE_PLATFORM
646   return errors::InvalidArgument("Mobile platforms not supported");
647 #else
648   if (session_options_->config.graph_options().place_pruned_graph()) {
649     return errors::InvalidArgument("Can't optimize a pruned graph");
650   }
651 
652   if (grappler::MetaOptimizerEnabled(session_options_->config)) {
653     // Here we build the GrapplerItem before calling the optimizer.
654     grappler::GrapplerItem item;
655     item.id = "tf_graph";
656 
657     // Add devices to the GrapplerItem
658     // It's ok to skip invalid device annotations in Grappler.
659     for (const Device* d : device_set_->devices()) {
660       Status added_device = item.AddDevice(d->name());
661       if (!added_device.ok()) VLOG(3) << added_device.error_message();
662     }
663     VLOG(3) << "Grappler available devices: "
664             << absl::StrJoin(item.devices(), ", ");
665 
666     // Add fetches to the GrapplerItem.
667     item.fetch.insert(item.fetch.end(),
668                       options.callable_options.fetch().begin(),
669                       options.callable_options.fetch().end());
670     item.fetch.insert(item.fetch.end(),
671                       options.callable_options.target().begin(),
672                       options.callable_options.target().end());
673 
674     for (const TensorConnection& tensor_connection :
675          options.callable_options.tensor_connection()) {
676       item.fetch.push_back(tensor_connection.from_tensor());
677     }
678 
679     // Add feeds to the GrapplerItem if we know them.
680     absl::flat_hash_set<absl::string_view> node_names;
681     if (!(options.callable_options.feed().empty() &&
682           options.callable_options.tensor_connection().empty())) {
683       std::vector<SafeTensorId> feeds;
684 
685       for (const string& feed : options.callable_options.feed()) {
686         feeds.emplace_back(ParseTensorName(feed));
687       }
688       for (const TensorConnection& tensor_connection :
689            options.callable_options.tensor_connection()) {
690         feeds.emplace_back(ParseTensorName(tensor_connection.to_tensor()));
691       }
692 
693       // For feeds with tensor index 0 we try to find the corresponding node in
694       // the graph to infer feed data type and shape.
695       absl::flat_hash_set<absl::string_view> feed_nodes;
696 
697       // For feeds with tensor index larger than 0, we can't infer data type or
698       // shape from the graph. Currently we only support type and shape
699       // inference from a small set of node types: Placeholder, Const, etc...
700       for (const SafeTensorId& feed : feeds) {
701         if (feed.index() > 0) {
702           VLOG(3) << "Add undefined feed for: " << feed.ToString();
703           Tensor fake_input(DT_INVALID, {0});
704           item.feed.emplace_back(feed.ToString(), fake_input);
705         } else {
706           VLOG(3) << "Add node for feed inference: " << feed.ToString();
707           feed_nodes.insert(feed.node());
708           continue;
709         }
710       }
711 
712       // For feeds with tensor index == 0 we try to infer data type and tensor
713       // shape from the graph, by looking at the fed node attributes.
714       node_names.reserve(graph_->num_nodes());
715       for (const Node* node : graph_->nodes()) {
716         node_names.insert(node->name());
717         if (feed_nodes.find(node->name()) == feed_nodes.end()) continue;
718 
719         // Try to get the type and shape of the feed node.
720         PartialTensorShape partial_shape;
721         DataType type;
722         Status st = GetFeedShapeAndTypeFromAttribute(node->def(),
723                                                      &partial_shape, &type);
724 
725         // Failed to get type and shape of the feed node.
726         if (!st.ok()) {
727           VLOG(3) << "Failed to infer feed node type and shape."
728                   << " Add undefined feed for: " << node->name();
729           Tensor fake_input(DT_INVALID, {0});
730           item.feed.emplace_back(node->name(), fake_input);
731           continue;
732         }
733 
734         // If the shape of the placeholder is only partially known, we are free
735         // to set unknown dimensions of its shape to any value we desire. We
736         // choose 0 to minimize the memory impact. Note that this only matters
737         // if an optimizer chooses to run the graph.
738         TensorShape shape;
739         if (partial_shape.unknown_rank()) {
740           shape = TensorShape({0});
741         } else {
742           for (int i = 0; i < partial_shape.dims(); ++i) {
743             if (partial_shape.dim_size(i) < 0) {
744               partial_shape.set_dim(i, 0);
745             }
746           }
747           if (!partial_shape.AsTensorShape(&shape)) {
748             return errors::InvalidArgument(
749                 "Could not derive shape for feed node: ",
750                 node->def().DebugString());
751           }
752         }
753 
754         VLOG(3) << "Add feed for: " << node->name() << "; type: " << type
755                 << "; shape: " << shape;
756         Tensor fake_input(type, shape);
757         item.feed.emplace_back(node->name(), fake_input);
758       }
759     }
760 
761     // Validate that the feeds and fetches are valid.
762     if (node_names.empty()) {
763       // Collect all node names in the graph if we didn't already.
764       node_names.reserve(graph_->num_nodes());
765       for (const Node* node : graph_->nodes()) {
766         node_names.insert(node->name());
767       }
768     }
769     for (const auto& feed : item.feed) {
770       SafeTensorId tensor_id = ParseTensorName(feed.first);
771       if (node_names.find(tensor_id.node()) == node_names.end()) {
772         return errors::InvalidArgument("Invalid feed, no such node in graph: ",
773                                        feed.first);
774       }
775     }
776     for (const auto& fetch : item.fetch) {
777       SafeTensorId tensor_id = ParseTensorName(fetch);
778       if (node_names.find(tensor_id.node()) == node_names.end()) {
779         return errors::InvalidArgument("Invalid fetch, no such node in graph: ",
780                                        fetch);
781       }
782     }
783 
784     // Convert Graph to GraphDef and add it to the GrapplerItem.
785     graph_->ToGraphDef(&item.graph);
786     // TODO(b/114748242): Add a unit test to test this bug fix.
787     if (flib_def_) {
788       *item.graph.mutable_library() = flib_def_->ToProto();
789     }
790 
791     // Construct a virtual cluster and find the cpu_device, which the
792     // ConstantFolding optimizer will use for partial evaluation of the graph.
793     grappler::VirtualCluster cluster(device_set_);
794     Device* cpu_device = nullptr;
795     for (const auto& device : device_set_->devices()) {
796       if (device->parsed_name().id == 0 &&
797           StringPiece(device->parsed_name().type) == "CPU" &&
798           device->GetAllocator(AllocatorAttributes()) != nullptr) {
799         cpu_device = device;
800       }
801     }
802 
803     // Now we can run the MetaOptimizer on the constructed GrapplerItem.
804     GraphDef new_graph;
805     TF_RETURN_IF_ERROR(
806         grappler::RunMetaOptimizer(std::move(item), session_options_->config,
807                                    cpu_device, &cluster, &new_graph));
808 
809     // Merge the optimized graph's function library with the original
810     // library. The optimized graph might have new functions specialized for
811     // its instantiation context (see the Grappler function optimizer), and
812     // modified function bodies for existing functions.
813     optimized_flib->reset(new FunctionLibraryDefinition(*flib_def_));
814 
815     for (const FunctionDef& fdef : new_graph.library().function()) {
816       const string& func_name = fdef.signature().name();
817 
818       if ((*optimized_flib)->Contains(func_name)) {
819         VLOG(3) << "Replace function: name=" << func_name;
820         TF_RETURN_IF_ERROR((*optimized_flib)->ReplaceFunction(func_name, fdef));
821       } else {
822         VLOG(3) << "Add new function: name=" << func_name;
823         TF_RETURN_IF_ERROR((*optimized_flib)->AddFunctionDef(fdef));
824       }
825     }
826     optimized_graph->reset(new Graph(OpRegistry::Global()));
827 
828     // Convert the optimized GraphDef back to a Graph.
829     GraphConstructorOptions opts;
830     opts.allow_internal_ops = true;
831     TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, std::move(new_graph),
832                                               optimized_graph->get()));
833     // The graph conversion sets the requested device names but not the
834     // assigned device names. However, since at this point the graph is placed,
835     // TF expects an assigned device name for every node. Therefore we copy
836     // the requested device into the assigned device field.
837     for (Node* node : optimized_graph->get()->nodes()) {
838       node->set_assigned_device_name(node->requested_device());
839     }
840     return Status::OK();
841   } else {
842     return errors::InvalidArgument("Meta Optimizer disabled");
843   }
844 #endif  // IS_MOBILE_PLATFORM
845 }
846 
847 Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
848                                        std::unique_ptr<ClientGraph>* out) {
849   VLOG(1) << "BuildGraph";
850   const uint64 start_time_usecs = Env::Default()->NowMicros();
851   if (!graph_) {
852     // It is only valid to call this method directly when the original graph
853     // was created with the option `place_pruned_graph == false`.
854     return errors::Internal(
855         "Attempted to prune a graph that has not been fully initialized.");
856   }
857 
858   // Grappler optimization might change the structure of the graph itself,
859   // and it can also add or prune functions in the library.
860   std::unique_ptr<Graph> optimized_graph;
861   std::unique_ptr<FunctionLibraryDefinition> optimized_flib;
862 
863   Status s = OptimizeGraph(options, &optimized_graph, &optimized_flib);
864   if (!s.ok()) {
865     VLOG(2) << "Grappler optimization failed. Error: " << s.error_message();
866     // Simply copy the original graph and the function library if we couldn't
867     // optimize it.
868     optimized_graph.reset(new Graph(flib_def_.get()));
869     CopyGraph(*graph_, optimized_graph.get());
870     optimized_flib.reset(new FunctionLibraryDefinition(*flib_def_));
871   }
872 
873   subgraph::RewriteGraphMetadata rewrite_metadata;
874   if (session_options_ == nullptr ||
875       !session_options_->config.graph_options().place_pruned_graph()) {
876     TF_RETURN_IF_ERROR(
877         PruneGraph(options, optimized_graph.get(), &rewrite_metadata));
878   } else {
879     // This GraphExecutionState represents a graph that was
880     // pruned when this was constructed, so we copy the metadata from
881     // a member variable.
882     CHECK(rewrite_metadata_);
883     rewrite_metadata = *rewrite_metadata_;
884   }
885 
886   CHECK_EQ(options.callable_options.feed_size(),
887            rewrite_metadata.feed_types.size());
888   CHECK_EQ(options.callable_options.fetch_size(),
889            rewrite_metadata.fetch_types.size());
890 
891   // TODO(andydavis): Clarify optimization pass requirements around CostModel.
892   GraphOptimizationPassOptions optimization_options;
893   optimization_options.session_options = session_options_;
894   optimization_options.graph = &optimized_graph;
895   optimization_options.flib_def = optimized_flib.get();
896   optimization_options.device_set = device_set_;
897 
898   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
899       OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
900 
901   int64 collective_graph_key = options.collective_graph_key;
902   if (collective_graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
903     // BuildGraphOptions does not specify a collective_graph_key.  Check all
904     // nodes in the Graph and FunctionLibraryDefinition for collective ops and
905     // if found, initialize a collective_graph_key as a hash of the ordered set
906     // of instance keys.
907     std::set<int32> instance_key_set;
908     bool has_collective_v2 = false;
909     for (Node* node : optimized_graph->nodes()) {
910       if (node->IsCollective()) {
911         int32 instance_key;
912         TF_RETURN_IF_ERROR(
913             GetNodeAttr(node->attrs(), "instance_key", &instance_key));
914         instance_key_set.emplace(instance_key);
915       } else if (IsCollectiveV2(node->type_string())) {
916         has_collective_v2 = true;
917       } else {
918         const FunctionDef* fdef = optimized_flib->Find(node->def().op());
919         if (fdef != nullptr) {
920           for (const NodeDef& ndef : fdef->node_def()) {
921             if (ndef.op() == "CollectiveReduce" ||
922                 ndef.op() == "CollectiveBcastSend" ||
923                 ndef.op() == "CollectiveBcastRecv" ||
924                 ndef.op() == "CollectiveGather") {
925               int32 instance_key;
926               TF_RETURN_IF_ERROR(
927                   GetNodeAttr(ndef, "instance_key", &instance_key));
928               instance_key_set.emplace(instance_key);
929             } else if (IsCollectiveV2(ndef.op())) {
930               has_collective_v2 = true;
931             }
932           }
933         }
934       }
935     }
936     if (!instance_key_set.empty()) {
937       uint64 hash = 0x8774aa605c729c72ULL;
938       for (int32 instance_key : instance_key_set) {
939         hash = Hash64Combine(instance_key, hash);
940       }
941       collective_graph_key = hash;
942     } else if (has_collective_v2) {
943       collective_graph_key = 0x8774aa605c729c72ULL;
944     }
945   }
946 
947   // Make collective execution order deterministic if needed.
948   if (options.collective_order != GraphCollectiveOrder::kNone) {
949     TF_RETURN_IF_ERROR(
950         OrderCollectives(optimized_graph.get(), options.collective_order));
951   }
952 
953   // Copy the extracted graph in order to make its node ids dense,
954   // since the local CostModel used to record its stats is sized by
955   // the largest node id.
956   std::unique_ptr<ClientGraph> dense_copy(
957       new ClientGraph(std::move(optimized_flib), rewrite_metadata.feed_types,
958                       rewrite_metadata.fetch_types, collective_graph_key));
959   CopyGraph(*optimized_graph, &dense_copy->graph);
960 
961   // TODO(vrv): We should check invariants of the graph here.
962   metrics::UpdateGraphBuildTime(Env::Default()->NowMicros() - start_time_usecs);
963   *out = std::move(dense_copy);
964   return Status::OK();
965 }
966 
967 }  // namespace tensorflow
968