/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/graph_execution_state.h"

#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/placer.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/collective_order.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/util.h"

#ifndef IS_MOBILE_PLATFORM
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#endif  // IS_MOBILE_PLATFORM

namespace tensorflow {

namespace {
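// Returns true if `op` names one of the V2 collective ops. Unlike the V1
// collective ops, these do not carry an `instance_key` attribute, so
// BuildGraph() only records their presence when deriving a collective
// graph key.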
bool IsCollectiveV2(const string& op) {
  return op == "CollectiveReduceV2" || op == "CollectiveGatherV2" ||
         op == "CollectiveBcastRecvV2" || op == "CollectiveBcastSendV2";
}
}  // namespace

GraphExecutionState::GraphExecutionState(
    std::unique_ptr<GraphDef>&& graph_def,
    std::unique_ptr<FunctionLibraryDefinition>&& flib_def,
    const GraphExecutionStateOptions& options)
    : stateful_placements_(options.stateful_placements),
      original_graph_def_(std::move(graph_def)),
      device_set_(options.device_set),
      session_options_(options.session_options),
      session_handle_(options.session_handle),
      flib_def_(std::move(flib_def)),
      graph_(nullptr) {}

GraphExecutionState::~GraphExecutionState() {
  node_name_to_cost_id_map_.clear();
  delete graph_;
}

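// Typical usage (a rough sketch, not the only call pattern): a session builds
// a GraphExecutionState from its GraphDef and then extracts a ClientGraph for
// each run:
//
//   std::unique_ptr<GraphExecutionState> state;
//   TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
//       std::move(graph_def), options, &state));
//   std::unique_ptr<ClientGraph> client_graph;
//   TF_RETURN_IF_ERROR(state->BuildGraph(build_graph_options, &client_graph));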
/* static */ Status GraphExecutionState::MakeForBaseGraph(
    GraphDef&& graph_def, const GraphExecutionStateOptions& options,
    std::unique_ptr<GraphExecutionState>* out_state) {
#ifndef __ANDROID__
  VLOG(4) << "Graph proto is \n" << graph_def.DebugString();
#endif  // __ANDROID__

  auto flib_def = absl::make_unique<FunctionLibraryDefinition>(
      OpRegistry::Global(), graph_def.library());

  TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&graph_def, *flib_def, 0));

  if (options.session_options->config.graph_options().place_pruned_graph() ||
      !options.session_options->config.experimental()
           .optimize_for_static_graph()) {
    auto ret = absl::WrapUnique(new GraphExecutionState(
        absl::make_unique<GraphDef>(std::move(graph_def)), std::move(flib_def),
        options));

    // When place_pruned_graph is true, a different Graph* will be initialized
    // each time we prune the original graph, so there is no need to
    // construct a Graph* in this case.
    if (!options.session_options->config.graph_options()
             .place_pruned_graph()) {
      auto base_graph = absl::make_unique<Graph>(OpRegistry::Global());
      TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *ret->original_graph_def_,
                                                base_graph.get()));
      TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph)));
    }
    *out_state = std::move(ret);
  } else {
    auto ret = absl::WrapUnique(
        new GraphExecutionState(nullptr, std::move(flib_def), options));
    auto base_graph = absl::make_unique<Graph>(OpRegistry::Global());
    TF_RETURN_IF_ERROR(
        ConvertGraphDefToGraph({}, std::move(graph_def), base_graph.get()));
    TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph)));
    *out_state = std::move(ret);
  }
  return Status::OK();
}

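// Builds a pruned execution state (and the corresponding ClientGraph) from an
// already-constructed Session-level state. Only supported when the
// `place_pruned_graph` option is enabled.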
/* static */ Status GraphExecutionState::MakeForPrunedGraph(
    const GraphExecutionState& base_execution_state,
    const GraphExecutionStateOptions& options,
    const BuildGraphOptions& subgraph_options,
    std::unique_ptr<GraphExecutionState>* out_state,
    std::unique_ptr<ClientGraph>* out_client_graph) {
  if (!(base_execution_state.session_options_->config.graph_options()
            .place_pruned_graph() &&
        options.session_options->config.graph_options()
            .place_pruned_graph())) {
    return errors::Internal(
        "MakeForPrunedGraph is only supported when the `place_pruned_graph` "
        "option is true.");
  }
  if (!base_execution_state.original_graph_def_) {
    // NOTE(mrry): By adding this restriction, which matches the only current
    // usage of this (fairly obscure) method, we do not need to store a
    // redundant copy of the original graph in `*out_state`.
    return errors::Internal(
        "MakeForPrunedGraph is only supported when `base_execution_state` is "
        "the Session-level `GraphExecutionState`.");
  }

  // NOTE(mrry): This makes a copy of `graph_def`, which is
  // regrettable. We could make `GraphDef` objects sharable between
  // execution states to optimize pruned graph execution, but since
  // this case is primarily used for interactive sessions, we make the
  // bet that graph construction is not performance-critical. (Note
  // also that the previous version used `Extend()`, which is strictly
  // more expensive than copying a `GraphDef`.)
  GraphDef temp(*base_execution_state.original_graph_def_);
  auto flib_def = absl::make_unique<FunctionLibraryDefinition>(
      OpRegistry::Global(), temp.library());
  TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&temp, *flib_def, 0));
  auto ret = absl::WrapUnique(
      new GraphExecutionState(nullptr, std::move(flib_def), options));

  auto base_graph = absl::make_unique<Graph>(OpRegistry::Global());
  TF_RETURN_IF_ERROR(
      ConvertGraphDefToGraph({}, std::move(temp), base_graph.get()));

  // Rewrite the graph before placement.
  ret->rewrite_metadata_.reset(new subgraph::RewriteGraphMetadata);
  TF_RETURN_IF_ERROR(ret->PruneGraph(subgraph_options, base_graph.get(),
                                     ret->rewrite_metadata_.get()));
  TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph)));
  TF_RETURN_IF_ERROR(ret->BuildGraph(subgraph_options, out_client_graph));
  *out_state = std::move(ret);
  return Status::OK();
}

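// Creates a new GraphExecutionState containing the union of the original
// graph and `extension_def`. Fails if `extension_def` reuses a node name from
// the original graph, or if the session was configured with
// `optimize_for_static_graph`.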
Status GraphExecutionState::Extend(
    const GraphDef& extension_def,
    std::unique_ptr<GraphExecutionState>* out) const {
  if (session_options_->config.experimental().optimize_for_static_graph()) {
    return errors::FailedPrecondition(
        "Extending the graph is not supported when "
        "`optimize_for_static_graph` is true.");
  }

  GraphDef gdef;

  // 1. Copy the function library.
  TF_RETURN_IF_ERROR(flib_def_->AddLibrary(extension_def.library()));
  *gdef.mutable_library() = flib_def_->ToProto();

  // 2. Build an index of the new node names.
  std::unordered_set<string> new_names;
  for (const NodeDef& node : extension_def.node()) {
    new_names.insert(node.name());
  }

  // 3. Add the non-duplicates from the old graph to the new graph.
  // Return an error if the same node name appears in both the
  // old graph and the extension.
  for (const NodeDef& node : original_graph_def_->node()) {
    if (new_names.count(node.name()) == 0) {
      *gdef.add_node() = node;
    } else {
      return errors::InvalidArgument(
          "GraphDef argument to Extend includes node '", node.name(),
          "', which was created by a previous call to Create or Extend in this "
          "session.");
    }
  }

  // 4. Add the new nodes and merge the versions field.
  int old_node_size = gdef.node_size();
  gdef.mutable_node()->MergeFrom(extension_def.node());
  TF_RETURN_IF_ERROR(
      AddDefaultAttrsToGraphDef(&gdef, *flib_def_, old_node_size));
  // Merge versions.
  if (gdef.has_versions()) {
    if (gdef.versions().producer() != extension_def.versions().producer()) {
      return errors::InvalidArgument(
          "Can't extend GraphDef at version ", gdef.versions().producer(),
          " with graph at version ", extension_def.versions().producer());
    }
    VersionDef* versions = gdef.mutable_versions();
    versions->set_min_consumer(std::max(
        versions->min_consumer(), extension_def.versions().min_consumer()));
    if (extension_def.versions().bad_consumers_size()) {
      // Add new bad_consumers that aren't already marked bad.
      //
      // Note: This implementation is quadratic time if there are many calls to
      // ExtendLocked with many bad consumers. Since this is unlikely, and
      // fixing it would require data structures outside of this routine,
      // quadratic time it is.
      auto* bad_consumers = versions->mutable_bad_consumers();
      const std::unordered_set<int> existing(bad_consumers->begin(),
                                             bad_consumers->end());
      for (const int v : extension_def.versions().bad_consumers()) {
        if (existing.find(v) == existing.end()) {
          bad_consumers->Add(v);
        }
      }
    }

  } else {
    gdef.mutable_versions()->CopyFrom(extension_def.versions());
  }

  // 5. Validate that the final graphdef is valid.
  if (gdef.versions().producer() >= 5) {
    // Validate the graph: we assume that merging two valid graphs
    // should maintain graph validity.
    TF_RETURN_IF_ERROR(graph::ValidateGraphDef(gdef, *flib_def_));
  }

  // 6. Add the extension.
  GraphExecutionStateOptions combined_options;
  combined_options.device_set = device_set_;
  combined_options.session_options = session_options_;
  combined_options.session_handle = session_handle_;
  combined_options.stateful_placements = stateful_placements_;

  TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&gdef, *flib_def_, 0));
  auto flib_def = absl::make_unique<FunctionLibraryDefinition>(
      OpRegistry::Global(), gdef.library());
  auto new_execution_state = absl::WrapUnique(
      new GraphExecutionState(absl::make_unique<GraphDef>(std::move(gdef)),
                              std::move(flib_def), combined_options));

  if (!session_options_->config.graph_options().place_pruned_graph()) {
    auto base_graph = absl::make_unique<Graph>(OpRegistry::Global());
    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
        {}, *new_execution_state->original_graph_def_, base_graph.get()));
    TF_RETURN_IF_ERROR(
        new_execution_state->InitBaseGraph(std::move(base_graph)));
  }
  *out = std::move(new_execution_state);

  // NOTE(mrry): Extend() is likely to be used for non-throughput-sensitive
  // interactive workloads, but in the future we may want to transfer other
  // parts of the placement and/or cost model.
  return Status::OK();
}

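// Records the device assigned to every stateful node (e.g. variables and
// queues) in `graph`, keyed by node name, so that subsequent placements keep
// those nodes on the same devices.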
void GraphExecutionState::SaveStatefulNodes(Graph* graph) {
  for (Node* n : graph->nodes()) {
    if (n->op_def().is_stateful()) {
      VLOG(2) << "Saving " << n->DebugString();
      stateful_placements_[n->name()] = n->assigned_device_name();
    }
  }
}

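// Re-applies any previously saved device assignments to the stateful nodes in
// `graph`.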
void GraphExecutionState::RestoreStatefulNodes(Graph* graph) {
  for (Node* n : graph->nodes()) {
    if (n->op_def().is_stateful()) {
      auto iter = stateful_placements_.find(n->name());
      if (iter != stateful_placements_.end()) {
        n->set_assigned_device_name(iter->second);
        VLOG(2) << "Restored " << n->DebugString();
      }
    }
  }
}

namespace {

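// A PruneRewrite that implements `CallableOptions.tensor_connection`: the fed
// endpoint is replaced by an Identity node that reads from `from_tensor`,
// after checking that the new edge would not introduce a cycle.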
class TensorConnectionPruneRewrite : public subgraph::PruneRewrite {
 public:
  TensorConnectionPruneRewrite(const string* endpoint_name,
                               NodeBuilder::NodeOut from_tensor)
      : subgraph::PruneRewrite(endpoint_name, nullptr /* device_info */),
        from_tensor_(std::move(from_tensor)) {}

  Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
                 Node** out_node) override {
    Status s;
    auto check_no_cycle_fn = [this, feed_tensor, &s](Node* n) {
      if (n == feed_tensor.node) {
        s.Update(errors::InvalidArgument(
            "Requested Tensor connection between nodes \"",
            feed_tensor.node->name(), "\" and \"", from_tensor_.node->name(),
            "\" would create a cycle."));
      }
    };
    ReverseDFSFrom(*g, {from_tensor_.node}, std::move(check_no_cycle_fn),
                   nullptr);
    TF_RETURN_IF_ERROR(s);

    TF_RETURN_IF_ERROR(
        NodeBuilder(strings::StrCat("_identity_", feed_tensor.node->name(),
                                    "_", feed_tensor.index),
                    "Identity")
            .Input(from_tensor_)
            .Attr("T",
                  BaseType(from_tensor_.node->output_type(from_tensor_.index)))
            .Finalize(g, out_node));

    (*out_node)->set_assigned_device_name(
        feed_tensor.node->assigned_device_name());
    return Status::OK();
  }

 private:
  NodeBuilder::NodeOut from_tensor_;
};

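// Resolves the device for `tensor_name` using the feed/fetch device map in
// `tensor2device`, falling back to the client device when no mapping is
// present. Returns an error for malformed or unknown device names.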
template <class Map>
Status LookupDevice(const DeviceSet& device_set, const string& tensor_name,
                    const Map& tensor2device,
                    const tensorflow::DeviceAttributes** out_device_attrs) {
  *out_device_attrs = nullptr;
  if (tensor2device.empty()) {
    *out_device_attrs = &device_set.client_device()->attributes();
    return Status::OK();
  }
  const auto it = tensor2device.find(tensor_name);
  if (it == tensor2device.end()) {
    *out_device_attrs = &device_set.client_device()->attributes();
    return Status::OK();
  }
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(it->second, &parsed_name)) {
    return errors::InvalidArgument("Invalid device name ('", it->second,
                                   "') provided for the tensor '", tensor_name,
                                   "' in CallableOptions");
  }
  Device* device = device_set.FindDeviceByName(
      DeviceNameUtils::ParsedNameToString(parsed_name));
  if (device == nullptr) {
    return errors::InvalidArgument("Device '", it->second,
                                   "' specified for tensor '", tensor_name,
                                   "' in CallableOptions does not exist");
  }
  *out_device_attrs = &device->attributes();
  return Status::OK();
}

struct TensorAndDevice {
  // WARNING: backing memory for the 'tensor' field is NOT owned.
  const TensorId tensor;
  // WARNING: device pointer is not owned, so must outlive TensorAndDevice.
  const DeviceAttributes* device;
};

// Tensors of some DataTypes cannot be placed in device memory as feeds or
// fetches. Validate against an allowlist of those known to work.
bool IsFeedAndFetchSupported(DataType dtype, const string& device_type) {
  // The mechanism for supporting feeds of device-backed Tensors requires
  // the _Arg kernel to be registered for the corresponding type (and that
  // the input to the kernel be in device and not host memory).
  //
  // The mechanism for supporting fetches of device-backed Tensors requires
  // the _Retval kernel to be registered for the corresponding type (and
  // that the output is produced in device and not host memory).
  //
  // For now, we return true iff there are _Arg AND _Retval kernels for dtype
  // on the device. False negatives are okay, false positives would be bad.
  //
  // TODO(ashankar): Instead of an allowlist here, perhaps we could query
  // the kernel registry for _Arg and _Retval kernels instead.
  if (device_type == DEVICE_CPU) return true;
  if (device_type != DEVICE_GPU) return false;
  switch (dtype) {
    case DT_BFLOAT16:
    case DT_BOOL:
    case DT_COMPLEX128:
    case DT_COMPLEX64:
    case DT_DOUBLE:
    case DT_FLOAT:
    case DT_HALF:
    case DT_INT16:
    case DT_INT64:
    case DT_INT8:
    case DT_UINT16:
    case DT_UINT8:
      return true;
    default:
      return false;
  }
}

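// Checks that every tensor named in feed_devices/fetch_devices exists in
// `graph`, refers to a valid output slot, and has a dtype that can be fed or
// fetched on the requested device type.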
Status ValidateFeedAndFetchDevices(
    const Graph& graph,
    const std::vector<TensorAndDevice>& tensors_and_devices) {
  if (tensors_and_devices.empty()) return Status::OK();
  std::vector<bool> found(tensors_and_devices.size(), false);
  for (const Node* node : graph.nodes()) {
    // Linearly looping through all nodes and then all feed+fetch tensors isn't
    // particularly efficient. At the time of this writing, the expectation was
    // that tensors_and_devices.size() is really small in practice, so this
    // won't be problematic.
    // Revisit and make a more efficient lookup possible if needed (e.g.,
    // perhaps Graph can maintain a map from node name to Node*).
    for (int i = 0; i < tensors_and_devices.size(); ++i) {
      const TensorAndDevice& td = tensors_and_devices[i];
      if (td.tensor.first != node->name()) continue;
      found[i] = true;
      TF_RETURN_IF_ERROR(graph.IsValidOutputTensor(node, td.tensor.second));
      const DataType dtype = node->output_type(td.tensor.second);
      if (!IsFeedAndFetchSupported(dtype, td.device->device_type())) {
        return errors::Unimplemented(
            "Cannot feed or fetch tensor '", td.tensor.ToString(),
            "' from device ", td.device->name(), " as feeding/fetching from ",
            td.device->device_type(), " devices is not yet supported for ",
            DataTypeString(dtype), " tensors");
      }
    }
  }
  for (int i = 0; i < found.size(); ++i) {
    if (!found[i]) {
      return errors::InvalidArgument(
          "Tensor ", tensors_and_devices[i].tensor.ToString(),
          ", specified in either feed_devices or fetch_devices, was not found "
          "in the Graph");
    }
  }
  return Status::OK();
}

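// Infers the dtype and (possibly partial) shape of a feed from the attributes
// of the node being fed. This only works for Const nodes and for a small
// allowlist of ops that carry an explicit `shape` attribute.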
Status GetFeedShapeAndTypeFromAttribute(const NodeDef& node,
                                        PartialTensorShape* shape,
                                        DataType* type) {
  static const gtl::FlatSet<string>* const kHasExplicitShapeAttribute =
      CHECK_NOTNULL((new gtl::FlatSet<string>{
          "Placeholder", "PlaceholderV2", "PlaceholderWithDefault",
          "ParallelConcat", "ImmutableConst", "_ParallelConcatStart",
          "InfeedDequeue", "OutfeedDequeue", "CollectiveBcastSend",
          "CollectiveBcastRecv", "AccumulateNV2", "VariableV2", "Variable",
          "TemporaryVariable", "NcclBroadcast", "_ScopedAllocator",
          "_ScopedAllocatorConcat"}));

  // All the node types handled here have their output datatype set in
  // either attribute 'dtype' or 'T'.
  if (!TryGetNodeAttr(node, "dtype", type) &&
      !TryGetNodeAttr(node, "T", type)) {
    return errors::InvalidArgument(
        "Could not determine output type for feed node: ", node.name(),
        " of type ", node.op());
  }

  // First handle the case of feeding a const node.
  if (node.op() == "Const" && HasNodeAttr(node, "value")) {
    *shape =
        PartialTensorShape(node.attr().at("value").tensor().tensor_shape());
  } else if (kHasExplicitShapeAttribute->find(node.op()) !=
             kHasExplicitShapeAttribute->end()) {
    TF_RETURN_IF_ERROR(GetNodeAttr(node, "shape", shape));
  } else {
    return errors::InvalidArgument("Could not determine shape for feed node: ",
                                   node.name(), " of type ", node.op());
  }
  return Status::OK();
}

}  // namespace

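// Rewrites `graph` in place so that it computes exactly the feeds, fetches,
// and target nodes named in `options.callable_options`, recording the feed
// and fetch types in `out_rewrite_metadata`.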
Status GraphExecutionState::PruneGraph(
    const BuildGraphOptions& options, Graph* graph,
    subgraph::RewriteGraphMetadata* out_rewrite_metadata) {
  std::vector<std::unique_ptr<subgraph::PruneRewrite>> feed_rewrites;
  feed_rewrites.reserve(options.callable_options.feed_size());
  std::vector<std::unique_ptr<subgraph::PruneRewrite>> fetch_rewrites;
  fetch_rewrites.reserve(options.callable_options.fetch_size());
  if (options.use_function_convention) {
    std::vector<TensorAndDevice> tensors_and_devices;
    for (int i = 0; i < options.callable_options.feed_size(); ++i) {
      // WARNING: feed MUST be a reference, since ArgFeedRewrite and
      // tensors_and_devices hold on to its address.
      const string& feed = options.callable_options.feed(i);
      const DeviceAttributes* device_info;
      TF_RETURN_IF_ERROR(LookupDevice(*device_set_, feed,
                                      options.callable_options.feed_devices(),
                                      &device_info));
      feed_rewrites.emplace_back(
          new subgraph::ArgFeedRewrite(&feed, device_info, i));
      tensors_and_devices.push_back({ParseTensorName(feed), device_info});
    }
    if (!options.callable_options.fetch_devices().empty() &&
        !options.callable_options.fetch_skip_sync()) {
      return errors::Unimplemented(
          "CallableOptions.fetch_skip_sync = false is not yet implemented. You "
          "can set it to true instead, but MUST ensure that Device::Sync() is "
          "invoked on the Device corresponding to the fetched tensor before "
          "dereferencing the Tensor's memory.");
    }
    for (int i = 0; i < options.callable_options.fetch_size(); ++i) {
      // WARNING: fetch MUST be a reference, since RetvalFetchRewrite and
      // tensors_and_devices hold on to its address.
      const string& fetch = options.callable_options.fetch(i);
      const DeviceAttributes* device_info;
      TF_RETURN_IF_ERROR(LookupDevice(*device_set_, fetch,
                                      options.callable_options.fetch_devices(),
                                      &device_info));
      fetch_rewrites.emplace_back(
          new subgraph::RetvalFetchRewrite(&fetch, device_info, i));
      tensors_and_devices.push_back({ParseTensorName(fetch), device_info});
    }
    TF_RETURN_IF_ERROR(
        ValidateFeedAndFetchDevices(*graph, tensors_and_devices));
  } else {
    if (!options.callable_options.feed_devices().empty() ||
        !options.callable_options.fetch_devices().empty()) {
      return errors::Unimplemented(
          "CallableOptions::feed_devices and CallableOptions::fetch_devices "
          "to configure feeding/fetching tensors to/from device memory is not "
          "yet supported when using a remote session.");
    }
    const DeviceAttributes* device_info =
        &device_set_->client_device()->attributes();
    for (const string& feed : options.callable_options.feed()) {
      feed_rewrites.emplace_back(
          new subgraph::RecvFeedRewrite(&feed, device_info));
    }
    for (const string& fetch : options.callable_options.fetch()) {
      fetch_rewrites.emplace_back(
          new subgraph::SendFetchRewrite(&fetch, device_info));
    }
  }

  for (const TensorConnection& tensor_connection :
       options.callable_options.tensor_connection()) {
    Node* from_node = nullptr;
    TensorId from_id(ParseTensorName(tensor_connection.from_tensor()));

    for (Node* n : graph->nodes()) {
      if (n->name() == from_id.first) {
        from_node = n;
        break;
      }
    }
    if (from_node == nullptr) {
      return errors::InvalidArgument(
          "Requested tensor connection from unknown node: \"",
          tensor_connection.to_tensor(), "\".");
    }
    if (from_id.second >= from_node->num_outputs()) {
      return errors::InvalidArgument(
          "Requested tensor connection from unknown edge: \"",
          tensor_connection.to_tensor(),
          "\" (actual number of outputs = ", from_node->num_outputs(), ").");
    }

    feed_rewrites.emplace_back(new TensorConnectionPruneRewrite(
        &tensor_connection.to_tensor(), {from_node, from_id.second}));
  }

  std::vector<string> target_node_names(
      options.callable_options.target().begin(),
      options.callable_options.target().end());
  TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
      graph, feed_rewrites, fetch_rewrites, target_node_names,
      out_rewrite_metadata));

  CHECK_EQ(out_rewrite_metadata->feed_types.size(),
           options.callable_options.feed_size() +
               options.callable_options.tensor_connection_size());
  for (int i = 0; i < options.callable_options.tensor_connection_size(); ++i) {
    out_rewrite_metadata->feed_types.pop_back();
  }
  return Status::OK();
}

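// Restores any saved stateful placements, runs the PRE_PLACEMENT optimization
// passes, places `new_graph` on the available devices, runs the
// POST_PLACEMENT passes, and takes ownership of the placed graph as the base
// graph.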
Status GraphExecutionState::InitBaseGraph(std::unique_ptr<Graph>&& new_graph) {
  // Restore any previously saved stateful-node placements before placing.
  RestoreStatefulNodes(new_graph.get());

  GraphOptimizationPassOptions optimization_options;
  optimization_options.session_handle = session_handle_;
  optimization_options.session_options = session_options_;
  optimization_options.graph = &new_graph;
  optimization_options.flib_def = flib_def_.get();
  optimization_options.device_set = device_set_;

  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));

  Placer placer(new_graph.get(), "", flib_def_.get(), device_set_,
                /* default_local_device= */ nullptr,
                session_options_ == nullptr ||
                    session_options_->config.allow_soft_placement(),
                session_options_ != nullptr &&
                    session_options_->config.log_device_placement());
  // TODO(mrry): Consider making the Placer cancellable.
  TF_RETURN_IF_ERROR(placer.Run());

  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_PLACEMENT, optimization_options));

  for (const Node* n : new_graph->nodes()) {
    VLOG(2) << "Mapping " << n->name() << " to " << n->cost_id();
    node_name_to_cost_id_map_[n->name()] = n->cost_id();
  }

  SaveStatefulNodes(new_graph.get());
  graph_ = new_graph.release();
  return Status::OK();
}

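// Builds a GrapplerItem from the placed base graph and the feeds/fetches in
// `options`, runs the Grappler MetaOptimizer over it, and returns the
// optimized graph together with a function library that merges any functions
// Grappler added or specialized.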
Status GraphExecutionState::OptimizeGraph(
    const BuildGraphOptions& options, std::unique_ptr<Graph>* optimized_graph,
    std::unique_ptr<FunctionLibraryDefinition>* optimized_flib) {
#ifdef IS_MOBILE_PLATFORM
  return errors::InvalidArgument("Mobile platforms not supported");
#else
  if (session_options_->config.graph_options().place_pruned_graph()) {
    return errors::InvalidArgument("Can't optimize a pruned graph");
  }

  if (grappler::MetaOptimizerEnabled(session_options_->config)) {
    // Here we build the GrapplerItem before calling the optimizer.
    grappler::GrapplerItem item;
    item.id = "tf_graph";

    // Add devices to the GrapplerItem.
    // It's ok to skip invalid device annotations in Grappler.
    for (const Device* d : device_set_->devices()) {
      Status added_device = item.AddDevice(d->name());
      if (!added_device.ok()) VLOG(3) << added_device.error_message();
    }
    VLOG(3) << "Grappler available devices: "
            << absl::StrJoin(item.devices(), ", ");

    // Add fetches to the GrapplerItem.
    item.fetch.insert(item.fetch.end(),
                      options.callable_options.fetch().begin(),
                      options.callable_options.fetch().end());
    item.fetch.insert(item.fetch.end(),
                      options.callable_options.target().begin(),
                      options.callable_options.target().end());

    for (const TensorConnection& tensor_connection :
         options.callable_options.tensor_connection()) {
      item.fetch.push_back(tensor_connection.from_tensor());
    }

    // Add feeds to the GrapplerItem if we know them.
    absl::flat_hash_set<absl::string_view> node_names;
    if (!(options.callable_options.feed().empty() &&
          options.callable_options.tensor_connection().empty())) {
      std::vector<SafeTensorId> feeds;

      for (const string& feed : options.callable_options.feed()) {
        feeds.emplace_back(ParseTensorName(feed));
      }
      for (const TensorConnection& tensor_connection :
           options.callable_options.tensor_connection()) {
        feeds.emplace_back(ParseTensorName(tensor_connection.to_tensor()));
      }

      // For feeds with tensor index 0 we try to find the corresponding node in
      // the graph to infer feed data type and shape.
      absl::flat_hash_set<absl::string_view> feed_nodes;

      // For feeds with tensor index larger than 0, we can't infer data type or
      // shape from the graph. Currently we only support type and shape
      // inference from a small set of node types: Placeholder, Const, etc...
      for (const SafeTensorId& feed : feeds) {
        if (feed.index() > 0) {
          VLOG(3) << "Add undefined feed for: " << feed.ToString();
          Tensor fake_input(DT_INVALID, {0});
          item.feed.emplace_back(feed.ToString(), fake_input);
        } else {
          VLOG(3) << "Add node for feed inference: " << feed.ToString();
          feed_nodes.insert(feed.node());
          continue;
        }
      }

      // For feeds with tensor index == 0 we try to infer data type and tensor
      // shape from the graph, by looking at the fed node attributes.
      node_names.reserve(graph_->num_nodes());
      for (const Node* node : graph_->nodes()) {
        node_names.insert(node->name());
        if (feed_nodes.find(node->name()) == feed_nodes.end()) continue;

        // Try to get the type and shape of the feed node.
        PartialTensorShape partial_shape;
        DataType type;
        Status st = GetFeedShapeAndTypeFromAttribute(node->def(),
                                                     &partial_shape, &type);

        // Failed to get type and shape of the feed node.
        if (!st.ok()) {
          VLOG(3) << "Failed to infer feed node type and shape."
                  << " Add undefined feed for: " << node->name();
          Tensor fake_input(DT_INVALID, {0});
          item.feed.emplace_back(node->name(), fake_input);
          continue;
        }

        // If the shape of the placeholder is only partially known, we are free
        // to set unknown dimensions of its shape to any value we desire. We
        // choose 0 to minimize the memory impact. Note that this only matters
        // if an optimizer chooses to run the graph.
        TensorShape shape;
        if (partial_shape.unknown_rank()) {
          shape = TensorShape({0});
        } else {
          for (int i = 0; i < partial_shape.dims(); ++i) {
            if (partial_shape.dim_size(i) < 0) {
              partial_shape.set_dim(i, 0);
            }
          }
          if (!partial_shape.AsTensorShape(&shape)) {
            return errors::InvalidArgument(
                "Could not derive shape for feed node: ",
                node->def().DebugString());
          }
        }

        VLOG(3) << "Add feed for: " << node->name() << "; type: " << type
                << "; shape: " << shape;
        Tensor fake_input(type, shape);
        item.feed.emplace_back(node->name(), fake_input);
      }
    }

    // Validate that the feeds and fetches are valid.
    if (node_names.empty()) {
      // Collect all node names in the graph if we didn't already.
      node_names.reserve(graph_->num_nodes());
      for (const Node* node : graph_->nodes()) {
        node_names.insert(node->name());
      }
    }
    for (const auto& feed : item.feed) {
      SafeTensorId tensor_id = ParseTensorName(feed.first);
      if (node_names.find(tensor_id.node()) == node_names.end()) {
        return errors::InvalidArgument("Invalid feed, no such node in graph: ",
                                       feed.first);
      }
    }
    for (const auto& fetch : item.fetch) {
      SafeTensorId tensor_id = ParseTensorName(fetch);
      if (node_names.find(tensor_id.node()) == node_names.end()) {
        return errors::InvalidArgument("Invalid fetch, no such node in graph: ",
                                       fetch);
      }
    }

    // Convert Graph to GraphDef and add it to the GrapplerItem.
    graph_->ToGraphDef(&item.graph);
    // TODO(b/114748242): Add a unit test to test this bug fix.
    if (flib_def_) {
      *item.graph.mutable_library() = flib_def_->ToProto();
    }

    // Construct a virtual cluster and find the cpu_device, which the
    // ConstantFolding optimizer will use for partial evaluation of the graph.
    grappler::VirtualCluster cluster(device_set_);
    Device* cpu_device = nullptr;
    for (const auto& device : device_set_->devices()) {
      if (device->parsed_name().id == 0 &&
          StringPiece(device->parsed_name().type) == "CPU" &&
          device->GetAllocator(AllocatorAttributes()) != nullptr) {
        cpu_device = device;
      }
    }

    // Now we can run the MetaOptimizer on the constructed GrapplerItem.
    GraphDef new_graph;
    TF_RETURN_IF_ERROR(
        grappler::RunMetaOptimizer(std::move(item), session_options_->config,
                                   cpu_device, &cluster, &new_graph));

    // Merge the optimized graph's function library with the original library.
    // The optimized graph might have new functions specialized for its
    // instantiation context (see the Grappler function optimizer), and
    // modified function bodies for the existing functions.
    optimized_flib->reset(new FunctionLibraryDefinition(*flib_def_));

    for (const FunctionDef& fdef : new_graph.library().function()) {
      const string& func_name = fdef.signature().name();

      if ((*optimized_flib)->Contains(func_name)) {
        VLOG(3) << "Replace function: name=" << func_name;
        TF_RETURN_IF_ERROR((*optimized_flib)->ReplaceFunction(func_name, fdef));
      } else {
        VLOG(3) << "Add new function: name=" << func_name;
        TF_RETURN_IF_ERROR((*optimized_flib)->AddFunctionDef(fdef));
      }
    }
    optimized_graph->reset(new Graph(OpRegistry::Global()));

    // Convert the optimized GraphDef back to a Graph.
    GraphConstructorOptions opts;
    opts.allow_internal_ops = true;
    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, std::move(new_graph),
                                              optimized_graph->get()));
    // The graph conversion sets the requested device names but not the
    // assigned device names. However, since at this point the graph is placed,
    // TF expects an assigned device name for every node. Therefore we copy
    // the requested device into the assigned device field.
    for (Node* node : optimized_graph->get()->nodes()) {
      node->set_assigned_device_name(node->requested_device());
    }
    return Status::OK();
  } else {
    return errors::InvalidArgument("Meta Optimizer disabled");
  }
#endif  // IS_MOBILE_PLATFORM
}

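// Produces the ClientGraph for a single run: optimizes the base graph with
// Grappler when possible (falling back to a plain copy on failure), prunes it
// to the requested feeds/fetches/targets unless the graph was pre-pruned,
// runs the POST_REWRITE_FOR_EXEC optimization passes, and derives a
// collective graph key if collective ops are present.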
Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
                                       std::unique_ptr<ClientGraph>* out) {
  VLOG(1) << "BuildGraph";
  const uint64 start_time_usecs = Env::Default()->NowMicros();
  if (!graph_) {
    // It is only valid to call this method directly when the original graph
    // was created with the option `place_pruned_graph == false`.
    return errors::Internal(
        "Attempted to prune a graph that has not been fully initialized.");
  }

  // Grappler optimization might change the structure of the graph itself, and
  // it can also add or prune functions in the library.
  std::unique_ptr<Graph> optimized_graph;
  std::unique_ptr<FunctionLibraryDefinition> optimized_flib;

  Status s = OptimizeGraph(options, &optimized_graph, &optimized_flib);
  if (!s.ok()) {
    VLOG(2) << "Grappler optimization failed. Error: " << s.error_message();
    // Simply copy the original graph and the function library if we couldn't
    // optimize it.
    optimized_graph.reset(new Graph(flib_def_.get()));
    CopyGraph(*graph_, optimized_graph.get());
    optimized_flib.reset(new FunctionLibraryDefinition(*flib_def_));
  }

  subgraph::RewriteGraphMetadata rewrite_metadata;
  if (session_options_ == nullptr ||
      !session_options_->config.graph_options().place_pruned_graph()) {
    TF_RETURN_IF_ERROR(
        PruneGraph(options, optimized_graph.get(), &rewrite_metadata));
  } else {
    // This GraphExecutionState represents a graph that was
    // pruned when this was constructed, so we copy the metadata from
    // a member variable.
    CHECK(rewrite_metadata_);
    rewrite_metadata = *rewrite_metadata_;
  }

  CHECK_EQ(options.callable_options.feed_size(),
           rewrite_metadata.feed_types.size());
  CHECK_EQ(options.callable_options.fetch_size(),
           rewrite_metadata.fetch_types.size());

  // TODO(andydavis): Clarify optimization pass requirements around CostModel.
  GraphOptimizationPassOptions optimization_options;
  optimization_options.session_options = session_options_;
  optimization_options.graph = &optimized_graph;
  optimization_options.flib_def = optimized_flib.get();
  optimization_options.device_set = device_set_;

  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));

  int64 collective_graph_key = options.collective_graph_key;
  if (collective_graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
    // BuildGraphOptions does not specify a collective_graph_key. Check all
    // nodes in the Graph and FunctionLibraryDefinition for collective ops and,
    // if found, initialize a collective_graph_key as a hash of the ordered set
    // of instance keys.
    std::set<int32> instance_key_set;
    bool has_collective_v2 = false;
    for (Node* node : optimized_graph->nodes()) {
      if (node->IsCollective()) {
        int32 instance_key;
        TF_RETURN_IF_ERROR(
            GetNodeAttr(node->attrs(), "instance_key", &instance_key));
        instance_key_set.emplace(instance_key);
      } else if (IsCollectiveV2(node->type_string())) {
        has_collective_v2 = true;
      } else {
        const FunctionDef* fdef = optimized_flib->Find(node->def().op());
        if (fdef != nullptr) {
          for (const NodeDef& ndef : fdef->node_def()) {
            if (ndef.op() == "CollectiveReduce" ||
                ndef.op() == "CollectiveBcastSend" ||
                ndef.op() == "CollectiveBcastRecv" ||
                ndef.op() == "CollectiveGather") {
              int32 instance_key;
              TF_RETURN_IF_ERROR(
                  GetNodeAttr(ndef, "instance_key", &instance_key));
              instance_key_set.emplace(instance_key);
            } else if (IsCollectiveV2(ndef.op())) {
              has_collective_v2 = true;
            }
          }
        }
      }
    }
    if (!instance_key_set.empty()) {
      uint64 hash = 0x8774aa605c729c72ULL;
      for (int32 instance_key : instance_key_set) {
        hash = Hash64Combine(instance_key, hash);
      }
      collective_graph_key = hash;
    } else if (has_collective_v2) {
      collective_graph_key = 0x8774aa605c729c72ULL;
    }
  }

  // Make collective execution order deterministic if needed.
  if (options.collective_order != GraphCollectiveOrder::kNone) {
    TF_RETURN_IF_ERROR(
        OrderCollectives(optimized_graph.get(), options.collective_order));
  }

  // Copy the extracted graph in order to make its node ids dense,
  // since the local CostModel used to record its stats is sized by
  // the largest node id.
  std::unique_ptr<ClientGraph> dense_copy(
      new ClientGraph(std::move(optimized_flib), rewrite_metadata.feed_types,
                      rewrite_metadata.fetch_types, collective_graph_key));
  CopyGraph(*optimized_graph, &dense_copy->graph);

  // TODO(vrv): We should check invariants of the graph here.
  metrics::UpdateGraphBuildTime(Env::Default()->NowMicros() - start_time_usecs);
  *out = std::move(dense_copy);
  return Status::OK();
}

}  // namespace tensorflow