/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/graph_execution_state.h"

#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/placer.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/collective_order.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/util.h"

#ifndef IS_MOBILE_PLATFORM
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#endif  // IS_MOBILE_PLATFORM

namespace tensorflow {

GraphExecutionState::GraphExecutionState(
    GraphDef* graph_def, const GraphExecutionStateOptions& options)
    : stateful_placements_(options.stateful_placements),
      device_set_(options.device_set),
      session_options_(options.session_options),
      session_handle_(options.session_handle),
      flib_def_(new FunctionLibraryDefinition(OpRegistry::Global(),
                                              graph_def->library())),
      graph_(nullptr) {
  // NOTE(mrry): GraphDef does not have a move constructor, so we pass
  // a non-const pointer and use `Swap()` to transfer the contents
  // without copying.
  original_graph_def_.Swap(graph_def);
  // TODO(mrry): Publish placement visualizations or handle the log
  // placement option.
}

GraphExecutionState::~GraphExecutionState() {
  node_name_to_cost_id_map_.clear();
  delete graph_;
}

/* static */ Status GraphExecutionState::MakeForBaseGraph(
    GraphDef* graph_def, const GraphExecutionStateOptions& options,
    std::unique_ptr<GraphExecutionState>* out_state) {
#ifndef __ANDROID__
  VLOG(4) << "Graph proto is \n" << graph_def->DebugString();
#endif  // __ANDROID__

  std::unique_ptr<GraphExecutionState> ret(
      new GraphExecutionState(graph_def, options));

  TF_RETURN_IF_ERROR(
      AddDefaultAttrsToGraphDef(&ret->original_graph_def_, *ret->flib_def_, 0));
  // TODO(mrry): Refactor InitBaseGraph() so that we don't have to
  // pass an empty BuildGraphOptions (that isn't going to be used when
  // place_pruned_graph is false).
  if (!ret->session_options_->config.graph_options().place_pruned_graph()) {
    TF_RETURN_IF_ERROR(ret->InitBaseGraph(BuildGraphOptions()));
  }
  *out_state = std::move(ret);
  return Status::OK();
}

/* static */ Status GraphExecutionState::MakeForPrunedGraph(
    const FunctionDefLibrary& func_def_lib,
    const GraphExecutionStateOptions& options, const GraphDef& graph_def,
    const BuildGraphOptions& subgraph_options,
    std::unique_ptr<GraphExecutionState>* out_state,
    std::unique_ptr<ClientGraph>* out_client_graph) {
  DCHECK(options.session_options->config.graph_options().place_pruned_graph());
  // NOTE(mrry): This makes a copy of `graph_def`, which is
  // regrettable. We could make `GraphDef` objects sharable between
  // execution states to optimize pruned graph execution, but since
  // this case is primarily used for interactive sessions, we make the
  // bet that graph construction is not performance-critical. (Note
  // also that the previous version used `Extend()`, which is strictly
  // more expensive than copying a `GraphDef`.)
  GraphDef temp(graph_def);
  std::unique_ptr<GraphExecutionState> ret(
      new GraphExecutionState(&temp, options));
  TF_RETURN_IF_ERROR(
      AddDefaultAttrsToGraphDef(&ret->original_graph_def_, *ret->flib_def_, 0));
  TF_RETURN_IF_ERROR(ret->InitBaseGraph(subgraph_options));
  TF_RETURN_IF_ERROR(ret->BuildGraph(subgraph_options, out_client_graph));
  *out_state = std::move(ret);
  return Status::OK();
}

Status GraphExecutionState::Extend(
    const GraphDef& extension_def,
    std::unique_ptr<GraphExecutionState>* out) const {
  GraphDef gdef;

  // 1. Copy the function library.
  TF_RETURN_IF_ERROR(flib_def_->AddLibrary(extension_def.library()));
  *gdef.mutable_library() = flib_def_->ToProto();

  // 2. Build an index of the new node names.
  std::unordered_set<string> new_names;
  for (const NodeDef& node : extension_def.node()) {
    new_names.insert(node.name());
  }

  // 3. Add the non-duplicates from the old graph to the new graph.
  //    Return an error if the same node name appears in both the
  //    old graph and the extension.
  for (const NodeDef& node : original_graph_def_.node()) {
    if (new_names.count(node.name()) == 0) {
      *gdef.add_node() = node;
    } else {
      return errors::InvalidArgument(tensorflow::strings::Printf(
          "GraphDef argument to Extend includes node '%s', which was created "
          "by a previous call to Create or Extend in this session.",
          node.name().c_str()));
    }
  }

  // 4. Merge the versions field.
  int old_node_size = gdef.node_size();
  gdef.mutable_node()->MergeFrom(extension_def.node());
  TF_RETURN_IF_ERROR(
      AddDefaultAttrsToGraphDef(&gdef, *flib_def_, old_node_size));
  // Merge versions
  if (gdef.has_versions()) {
    if (gdef.versions().producer() != extension_def.versions().producer()) {
      return errors::InvalidArgument(
          "Can't extend GraphDef at version ", gdef.versions().producer(),
          " with graph at version ", extension_def.versions().producer());
    }
    VersionDef* versions = gdef.mutable_versions();
    versions->set_min_consumer(std::max(
        versions->min_consumer(), extension_def.versions().min_consumer()));
    if (extension_def.versions().bad_consumers_size()) {
      // Add new bad_consumers that aren't already marked bad.
      //
      // Note: This implementation is quadratic time if there are many calls to
      // ExtendLocked with many bad consumers. Since this is unlikely, and
      // fixing it would require data structures outside of this routine,
      // quadratic time it is.
      auto* bad_consumers = versions->mutable_bad_consumers();
      const std::unordered_set<int> existing(bad_consumers->begin(),
                                             bad_consumers->end());
      for (const int v : extension_def.versions().bad_consumers()) {
        if (existing.find(v) == existing.end()) {
          bad_consumers->Add(v);
        }
      }
    }

  } else {
    gdef.mutable_versions()->CopyFrom(extension_def.versions());
  }

  // 5. Validate that the final graphdef is valid.
  if (gdef.versions().producer() >= 5) {
    // Validate the graph: we assume that merging two valid graphs
    // should maintain graph validity.
    TF_RETURN_IF_ERROR(graph::ValidateGraphDef(gdef, *flib_def_));
  }

  // 6. Add the extension.
  GraphExecutionStateOptions combined_options;
  combined_options.device_set = device_set_;
  combined_options.session_options = session_options_;
  combined_options.session_handle = session_handle_;
  combined_options.stateful_placements = stateful_placements_;

  // NOTE(mrry): `gdef` is no longer valid after the constructor
  // executes.
  std::unique_ptr<GraphExecutionState> new_execution_state(
      new GraphExecutionState(&gdef, combined_options));

  TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
      &new_execution_state->original_graph_def_, *flib_def_, 0));
  if (!session_options_->config.graph_options().place_pruned_graph()) {
    // TODO(mrry): Refactor InitBaseGraph() so that we don't have to
    // pass an empty BuildGraphOptions (that isn't going to be used
    // when place_pruned_graph is false).
    TF_RETURN_IF_ERROR(new_execution_state->InitBaseGraph(BuildGraphOptions()));
  }
  *out = std::move(new_execution_state);

  // TODO(mrry): This is likely to be used for non-throughput-sensitive
  // interactive workloads, but in future we may want to transfer other
  // parts of the placement and/or cost model.
  return Status::OK();
}

void GraphExecutionState::SaveStatefulNodes(Graph* graph) {
  for (Node* n : graph->nodes()) {
    if (n->op_def().is_stateful()) {
      VLOG(2) << "Saving " << n->DebugString();
      stateful_placements_[n->name()] = n->assigned_device_name();
    }
  }
}

void GraphExecutionState::RestoreStatefulNodes(Graph* graph) {
  for (Node* n : graph->nodes()) {
    if (n->op_def().is_stateful()) {
      auto iter = stateful_placements_.find(n->name());
      if (iter != stateful_placements_.end()) {
        n->set_assigned_device_name(iter->second);
        VLOG(2) << "Restored " << n->DebugString();
      }
    }
  }
}

namespace {

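// Rewrites a `tensor_connection` feed by inserting an Identity node that
// forwards the connection's source tensor, after first checking that the
// requested connection would not introduce a cycle into the graph.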
class TensorConnectionPruneRewrite : public subgraph::PruneRewrite {
 public:
  TensorConnectionPruneRewrite(const string* endpoint_name,
                               NodeBuilder::NodeOut from_tensor)
      : subgraph::PruneRewrite(endpoint_name, nullptr /* device_info */),
        from_tensor_(std::move(from_tensor)) {}

  Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
                 Node** out_node) override {
    Status s;
    auto check_no_cycle_fn = [this, feed_tensor, &s](Node* n) {
      if (n == feed_tensor.node) {
        s.Update(errors::InvalidArgument(
            "Requested Tensor connection between nodes \"",
            feed_tensor.node->name(), "\" and \"", from_tensor_.node->name(),
            "\" would create a cycle."));
      }
    };
    ReverseDFSFrom(*g, {from_tensor_.node}, std::move(check_no_cycle_fn),
                   nullptr);
    TF_RETURN_IF_ERROR(s);

    TF_RETURN_IF_ERROR(
        NodeBuilder(strings::StrCat("_identity_", feed_tensor.node->name(), "_",
                                    feed_tensor.index),
                    "Identity")
            .Input(from_tensor_)
            .Attr("T",
                  BaseType(from_tensor_.node->output_type(from_tensor_.index)))
            .Finalize(g, out_node));

    (*out_node)->set_assigned_device_name(
        feed_tensor.node->assigned_device_name());
    return Status::OK();
  }

 private:
  NodeBuilder::NodeOut from_tensor_;
};

template <class Map>
Status LookupDevice(const DeviceSet& device_set, const string& tensor_name,
                    const Map& tensor2device,
                    const tensorflow::DeviceAttributes** out_device_attrs) {
  *out_device_attrs = nullptr;
  if (tensor2device.empty()) {
    *out_device_attrs = &device_set.client_device()->attributes();
    return Status::OK();
  }
  const auto it = tensor2device.find(tensor_name);
  if (it == tensor2device.end()) {
    *out_device_attrs = &device_set.client_device()->attributes();
    return Status::OK();
  }
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(it->second, &parsed_name)) {
    return errors::InvalidArgument("Invalid device name ('", it->second,
                                   "') provided for the tensor '", tensor_name,
                                   "' in CallableOptions");
  }
  Device* device = device_set.FindDeviceByName(
      DeviceNameUtils::ParsedNameToString(parsed_name));
  if (device == nullptr) {
    return errors::InvalidArgument("Device '", it->second,
                                   "' specified for tensor '", tensor_name,
                                   "' in CallableOptions does not exist");
  }
  *out_device_attrs = &device->attributes();
  return Status::OK();
}

struct TensorAndDevice {
  // WARNING: backing memory for the 'tensor' field is NOT owned.
  const TensorId tensor;
  // WARNING: device pointer is not owned, so must outlive TensorAndDevice.
  const DeviceAttributes* device;
};

// Tensors of some DataTypes cannot be placed in device memory as feeds or
// fetches. Validate against a whitelist of those known to work.
bool IsFeedAndFetchSupported(DataType dtype, const string& device_type) {
  // The mechanism for supporting feeds of device-backed Tensors requires
  // the _Arg kernel to be registered for the corresponding type (and that
  // the input to the kernel be in device and not host memory).
  //
  // The mechanism for supporting fetches of device-backed Tensors requires
  // the _Retval kernel to be registered for the corresponding type (and
  // that the output is produced in device and not host memory).
  //
  // For now, we return true iff there are _Arg AND _Retval kernels for dtype on
  // the device. False negatives are okay, false positives would be bad.
  //
  // TODO(ashankar): Instead of a whitelist here, perhaps we could query
  // the kernel registry for _Arg and _Retval kernels instead.
  if (device_type == DEVICE_CPU) return true;
  if (device_type != DEVICE_GPU) return false;
  switch (dtype) {
    case DT_BFLOAT16:
    case DT_BOOL:
    case DT_COMPLEX128:
    case DT_COMPLEX64:
    case DT_DOUBLE:
    case DT_FLOAT:
    case DT_HALF:
    case DT_INT16:
    case DT_INT64:
    case DT_INT8:
    case DT_UINT16:
    case DT_UINT8:
      return true;
    default:
      return false;
  }
}

Status ValidateFeedAndFetchDevices(
    const Graph& graph,
    const std::vector<TensorAndDevice>& tensors_and_devices) {
  if (tensors_and_devices.empty()) return Status::OK();
  std::vector<bool> found(tensors_and_devices.size(), false);
  for (const Node* node : graph.nodes()) {
    // Linearly looping through all nodes and then all feed+fetch tensors isn't
    // quite efficient. At the time of this writing, the expectation was that
    // tensors_and_devices.size() is really small in practice, so this won't be
    // problematic.
    // Revisit and make a more efficient lookup possible if needed (e.g.,
    // perhaps Graph can maintain a map from node name to Node*).
    for (int i = 0; i < tensors_and_devices.size(); ++i) {
      const TensorAndDevice& td = tensors_and_devices[i];
      if (td.tensor.first != node->name()) continue;
      found[i] = true;
      TF_RETURN_IF_ERROR(graph.IsValidOutputTensor(node, td.tensor.second));
      const DataType dtype = node->output_type(td.tensor.second);
      if (!IsFeedAndFetchSupported(dtype, td.device->device_type())) {
        return errors::Unimplemented(
            "Cannot feed or fetch tensor '", td.tensor.ToString(),
            "' from device ", td.device->name(), " as feeding/fetching from ",
            td.device->device_type(), " devices is not yet supported for ",
            DataTypeString(dtype), " tensors");
      }
    }
  }
  for (int i = 0; i < found.size(); ++i) {
    if (!found[i]) {
      return errors::InvalidArgument(
          "Tensor ", tensors_and_devices[i].tensor.ToString(),
          ", specified in either feed_devices or fetch_devices was not found "
          "in the Graph");
    }
  }
  return Status::OK();
}

Status GetFeedShapeAndTypeFromAttribute(const NodeDef& node,
                                        PartialTensorShape* shape,
                                        DataType* type) {
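  // Ops that record the shape of their output in an explicit "shape"
  // attribute.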
  static const gtl::FlatSet<string>* const kHasExplicitShapeAttribute =
      CHECK_NOTNULL((new gtl::FlatSet<string>{
          "Placeholder", "PlaceholderV2", "PlaceholderWithDefault",
          "ParallelConcat", "ImmutableConst", "_ParallelConcatStart",
          "InfeedDequeue", "OutfeedDequeue", "CollectiveBcastSend",
          "CollectiveBcastRecv", "AccumulateNV2", "VariableV2", "Variable",
          "TemporaryVariable", "NcclBroadcast", "_ScopedAllocator",
          "_ScopedAllocatorConcat"}));

  // All the node types handled here have their output datatype set in
  // either attribute 'dtype' or 'T'.
  if (!GetNodeAttr(node, "dtype", type).ok() &&
      !GetNodeAttr(node, "T", type).ok()) {
    return errors::InvalidArgument(
        "Could not determine output type for feed node: ", node.name(),
        " of type ", node.op());
  }

  // First handle the case of feeding a const node.
  if (node.op() == "Const" && HasNodeAttr(node, "value")) {
    *shape =
        PartialTensorShape(node.attr().at("value").tensor().tensor_shape());
  } else if (kHasExplicitShapeAttribute->find(node.op()) !=
             kHasExplicitShapeAttribute->end()) {
    TF_RETURN_IF_ERROR(GetNodeAttr(node, "shape", shape));
  } else {
    return errors::InvalidArgument("Could not determine shape for feed node: ",
                                   node.name(), " of type ", node.op());
  }
  return Status::OK();
}

}  // namespace

Status GraphExecutionState::PruneGraph(
    const BuildGraphOptions& options, Graph* graph,
    subgraph::RewriteGraphMetadata* out_rewrite_metadata) {
  std::vector<std::unique_ptr<subgraph::PruneRewrite>> feed_rewrites;
  feed_rewrites.reserve(options.callable_options.feed_size());
  std::vector<std::unique_ptr<subgraph::PruneRewrite>> fetch_rewrites;
  fetch_rewrites.reserve(options.callable_options.fetch_size());
  if (options.use_function_convention) {
    std::vector<TensorAndDevice> tensors_and_devices;
    for (int i = 0; i < options.callable_options.feed_size(); ++i) {
      // WARNING: feed MUST be a reference, since ArgFeedRewrite and
      // tensors_and_devices hold on to its address.
      const string& feed = options.callable_options.feed(i);
      const DeviceAttributes* device_info;
      TF_RETURN_IF_ERROR(LookupDevice(*device_set_, feed,
                                      options.callable_options.feed_devices(),
                                      &device_info));
      feed_rewrites.emplace_back(
          new subgraph::ArgFeedRewrite(&feed, device_info, i));
      tensors_and_devices.push_back({ParseTensorName(feed), device_info});
    }
    if (!options.callable_options.fetch_devices().empty() &&
        !options.callable_options.fetch_skip_sync()) {
      return errors::Unimplemented(
          "CallableOptions.fetch_skip_sync = false is not yet implemented. You "
          "can set it to true instead, but MUST ensure that Device::Sync() is "
          "invoked on the Device corresponding to the fetched tensor before "
          "dereferencing the Tensor's memory.");
    }
    for (int i = 0; i < options.callable_options.fetch_size(); ++i) {
      // WARNING: fetch MUST be a reference, since RetvalFetchRewrite and
      // tensors_and_devices hold on to its address.
      const string& fetch = options.callable_options.fetch(i);
      const DeviceAttributes* device_info;
      TF_RETURN_IF_ERROR(LookupDevice(*device_set_, fetch,
                                      options.callable_options.fetch_devices(),
                                      &device_info));
      fetch_rewrites.emplace_back(
          new subgraph::RetvalFetchRewrite(&fetch, device_info, i));
      tensors_and_devices.push_back({ParseTensorName(fetch), device_info});
    }
    TF_RETURN_IF_ERROR(
        ValidateFeedAndFetchDevices(*graph, tensors_and_devices));
  } else {
    if (!options.callable_options.feed_devices().empty() ||
        !options.callable_options.fetch_devices().empty()) {
      return errors::Unimplemented(
          "CallableOptions::feed_devices and CallableOptions::fetch_devices "
          "to configure feeding/fetching tensors to/from device memory is not "
          "yet supported when using a remote session.");
    }
    const DeviceAttributes* device_info =
        &device_set_->client_device()->attributes();
    for (const string& feed : options.callable_options.feed()) {
      feed_rewrites.emplace_back(
          new subgraph::RecvFeedRewrite(&feed, device_info));
    }
    for (const string& fetch : options.callable_options.fetch()) {
      fetch_rewrites.emplace_back(
          new subgraph::SendFetchRewrite(&fetch, device_info));
    }
  }

  for (const TensorConnection& tensor_connection :
       options.callable_options.tensor_connection()) {
    Node* from_node = nullptr;
    TensorId from_id(ParseTensorName(tensor_connection.from_tensor()));

    for (Node* n : graph->nodes()) {
      if (n->name() == from_id.first) {
        from_node = n;
        break;
      }
    }
    if (from_node == nullptr) {
      return errors::InvalidArgument(
          "Requested tensor connection from unknown node: \"",
          tensor_connection.to_tensor(), "\".");
    }
    if (from_id.second >= from_node->num_outputs()) {
      return errors::InvalidArgument(
          "Requested tensor connection from unknown edge: \"",
          tensor_connection.to_tensor(),
          "\" (actual number of outputs = ", from_node->num_outputs(), ").");
    }

    feed_rewrites.emplace_back(new TensorConnectionPruneRewrite(
        &tensor_connection.to_tensor(), {from_node, from_id.second}));
  }

  std::vector<string> target_node_names(
      options.callable_options.target().begin(),
      options.callable_options.target().end());
  TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
      graph, feed_rewrites, fetch_rewrites, target_node_names,
      out_rewrite_metadata));

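  // The rewrite records a feed type for each tensor connection in addition to
  // each client-requested feed; drop the tensor-connection entries so that
  // `feed_types` lines up with the feeds in `callable_options`.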
  CHECK_EQ(out_rewrite_metadata->feed_types.size(),
           options.callable_options.feed_size() +
               options.callable_options.tensor_connection_size());
  for (int i = 0; i < options.callable_options.tensor_connection_size(); ++i) {
    out_rewrite_metadata->feed_types.pop_back();
  }
  return Status::OK();
}

Status GraphExecutionState::InitBaseGraph(const BuildGraphOptions& options) {
  const GraphDef* graph_def = &original_graph_def_;

  std::unique_ptr<Graph> new_graph(new Graph(OpRegistry::Global()));
  GraphConstructorOptions opts;
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, *graph_def, new_graph.get()));
  if (session_options_ &&
      session_options_->config.graph_options().place_pruned_graph()) {
    // Rewrite the graph before placement.
    rewrite_metadata_.reset(new subgraph::RewriteGraphMetadata);
    TF_RETURN_IF_ERROR(
        PruneGraph(options, new_graph.get(), rewrite_metadata_.get()));
  }

  // Restore any stateful placements recorded from a previous run before
  // placing the graph.
  RestoreStatefulNodes(new_graph.get());

  GraphOptimizationPassOptions optimization_options;
  optimization_options.session_handle = session_handle_;
  optimization_options.session_options = session_options_;
  optimization_options.graph = &new_graph;
  optimization_options.flib_def = flib_def_.get();
  optimization_options.device_set = device_set_;

  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));

  Placer placer(new_graph.get(), device_set_, /* default_device= */ nullptr,
                session_options_ == nullptr ||
                    session_options_->config.allow_soft_placement(),
                session_options_ != nullptr &&
                    session_options_->config.log_device_placement());
  // TODO(mrry): Consider making the Placer cancelable.
  TF_RETURN_IF_ERROR(placer.Run());

  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_PLACEMENT, optimization_options));

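  // Record the cost id assigned to each node, keyed by name, so that cost and
  // stats information gathered for graphs derived from this one can be mapped
  // back to nodes in the base graph.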
  for (const Node* n : new_graph->nodes()) {
    VLOG(2) << "Mapping " << n->name() << " to " << n->cost_id();
    node_name_to_cost_id_map_[n->name()] = n->cost_id();
  }

  SaveStatefulNodes(new_graph.get());
  graph_ = new_graph.release();
  return Status::OK();
}

Status GraphExecutionState::OptimizeGraph(
    const BuildGraphOptions& options, std::unique_ptr<Graph>* optimized_graph,
    std::unique_ptr<FunctionLibraryDefinition>* optimized_flib) {
#ifndef IS_MOBILE_PLATFORM
  if (session_options_->config.graph_options().place_pruned_graph()) {
    return errors::InvalidArgument("Can't optimize a pruned graph");
  }

  if (grappler::MetaOptimizerEnabled(session_options_->config)) {
    grappler::GrapplerItem item;
    item.id = "tf_graph";
    graph_->ToGraphDef(&item.graph);

    // It's ok to skip invalid device annotations in Grappler.
    Status inferred_devices = item.InferDevicesFromGraph();
    if (!inferred_devices.ok()) {
      VLOG(3) << inferred_devices.error_message();
    }

    // TODO(b/114748242): Add a unit test to test this bug fix.
    if (flib_def_) {
      *item.graph.mutable_library() = flib_def_->ToProto();
    }

    item.fetch.insert(item.fetch.end(),
                      options.callable_options.fetch().begin(),
                      options.callable_options.fetch().end());
    item.fetch.insert(item.fetch.end(),
                      options.callable_options.target().begin(),
                      options.callable_options.target().end());

    for (const TensorConnection& tensor_connection :
         options.callable_options.tensor_connection()) {
      item.fetch.push_back(tensor_connection.from_tensor());
    }

    if (!(options.callable_options.feed().empty() &&
          options.callable_options.tensor_connection().empty())) {
      std::unordered_set<string> feeds;
      for (const string& feed : options.callable_options.feed()) {
        TensorId id = ParseTensorName(feed);
        if (id.second != 0) {
          return errors::InvalidArgument("Unsupported feed: ", feed);
        }
        feeds.emplace(id.first);
      }
      for (const TensorConnection& tensor_connection :
           options.callable_options.tensor_connection()) {
        TensorId id = ParseTensorName(tensor_connection.to_tensor());
        if (id.second != 0) {
          return errors::InvalidArgument("Unsupported feed: ",
                                         tensor_connection.to_tensor());
        }
        feeds.emplace(id.first);
      }
      for (const NodeDef& node : original_graph_def_.node()) {
        if (feeds.find(node.name()) == feeds.end()) {
          continue;
        }
        // Get the type and shape of the feed node.
        PartialTensorShape partial_shape;
        DataType type;
        TF_RETURN_IF_ERROR(
            GetFeedShapeAndTypeFromAttribute(node, &partial_shape, &type));
        // If the shape of the placeholder is only partially known, we are free
        // to set unknown dimensions of its shape to any value we desire. We
        // choose 0 to minimize the memory impact. Note that this only matters
        // if an optimizer chooses to run the graph.
        TensorShape shape;
        if (partial_shape.unknown_rank()) {
          shape = TensorShape({0});
        } else {
          for (int i = 0; i < partial_shape.dims(); ++i) {
            if (partial_shape.dim_size(i) < 0) {
              partial_shape.set_dim(i, 0);
            }
          }
          if (!partial_shape.AsTensorShape(&shape)) {
            return errors::InvalidArgument(
                "Could not derive shape for feed node: ", node.DebugString());
          }
        }

        Tensor fake_input(type, shape);
        item.feed.emplace_back(node.name(), fake_input);
      }
    }

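    // Find a host CPU device (id 0) with a usable allocator; Grappler uses it
    // to evaluate nodes on the host, e.g. during constant folding.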
    Device* cpu_device = nullptr;
    for (const auto& device : device_set_->devices()) {
      if (device->parsed_name().id == 0 &&
          StringPiece(device->parsed_name().type) == "CPU" &&
          device->GetAllocator(AllocatorAttributes()) != nullptr) {
        cpu_device = device;
      }
    }
    grappler::VirtualCluster cluster(device_set_);
    GraphDef new_graph;
    TF_RETURN_IF_ERROR(grappler::RunMetaOptimizer(
        item, session_options_->config, cpu_device, &cluster, &new_graph));

    // Merge the optimized graph's function library with the original library.
    // The optimized graph might have new functions specialized for its
    // instantiation context (see the Grappler function optimizer), as well as
    // modified function bodies for existing functions.
    optimized_flib->reset(new FunctionLibraryDefinition(*flib_def_));

    for (const FunctionDef& fdef : new_graph.library().function()) {
      const string& func_name = fdef.signature().name();

      if ((*optimized_flib)->Contains(func_name)) {
        VLOG(3) << "Replace function: name=" << func_name;
        TF_RETURN_IF_ERROR((*optimized_flib)->ReplaceFunction(func_name, fdef));
      } else {
        VLOG(3) << "Add new function: name=" << func_name;
        TF_RETURN_IF_ERROR((*optimized_flib)->AddFunctionDef(fdef));
      }
    }

    optimized_graph->reset(new Graph(OpRegistry::Global()));

    GraphConstructorOptions opts;
    opts.allow_internal_ops = true;
    TF_RETURN_IF_ERROR(
        ConvertGraphDefToGraph(opts, new_graph, optimized_graph->get()));
    // The graph conversion sets the requested device names but not the
    // assigned device names. However, since at this point the graph is placed,
    // TF expects an assigned device name for every node. Therefore we copy
    // the requested device into the assigned device field.
    for (Node* node : optimized_graph->get()->nodes()) {
      node->set_assigned_device_name(node->requested_device());
    }
    return Status::OK();
  } else {
    return errors::InvalidArgument("Meta Optimizer disabled");
  }
#else
  return errors::InvalidArgument("Mobile platforms not supported");
#endif  // IS_MOBILE_PLATFORM
}

Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
                                       std::unique_ptr<ClientGraph>* out) {
  VLOG(1) << "BuildGraph";
  const uint64 start_time_usecs = Env::Default()->NowMicros();
  if (!graph_) {
    // It is only valid to call this method directly when the original graph
    // was created with the option `place_pruned_graph == false`.
    return errors::Internal(
        "Attempted to prune a graph that has not been fully initialized.");
  }

  // Grappler optimization might change the structure of a graph itself, and
  // also it can add/prune functions to/from the library.
  std::unique_ptr<Graph> optimized_graph;
  std::unique_ptr<FunctionLibraryDefinition> optimized_flib;

  Status s = OptimizeGraph(options, &optimized_graph, &optimized_flib);
  if (!s.ok()) {
    VLOG(2) << "Grappler optimization failed. Error: " << s.error_message();
    // Simply copy the original graph and the function library if we couldn't
    // optimize it.
    optimized_graph.reset(new Graph(flib_def_.get()));
    CopyGraph(*graph_, optimized_graph.get());
    optimized_flib.reset(new FunctionLibraryDefinition(*flib_def_));
  }

  subgraph::RewriteGraphMetadata rewrite_metadata;
  if (session_options_ == nullptr ||
      !session_options_->config.graph_options().place_pruned_graph()) {
    TF_RETURN_IF_ERROR(
        PruneGraph(options, optimized_graph.get(), &rewrite_metadata));
  } else {
    // This GraphExecutionState represents a graph that was
    // pruned when this was constructed, so we copy the metadata from
    // a member variable.
    CHECK(rewrite_metadata_);
    rewrite_metadata = *rewrite_metadata_;
  }

  CHECK_EQ(options.callable_options.feed_size(),
           rewrite_metadata.feed_types.size());
  CHECK_EQ(options.callable_options.fetch_size(),
           rewrite_metadata.fetch_types.size());

  // TODO(andydavis): Clarify optimization pass requirements around CostModel.
  GraphOptimizationPassOptions optimization_options;
  optimization_options.session_options = session_options_;
  optimization_options.graph = &optimized_graph;
  optimization_options.flib_def = optimized_flib.get();
  optimization_options.device_set = device_set_;

  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));

  int64 collective_graph_key = options.collective_graph_key;
  if (collective_graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
    // BuildGraphOptions does not specify a collective_graph_key. Check all
    // nodes in the Graph and FunctionLibraryDefinition for collective ops and
    // if found, initialize a collective_graph_key as a hash of the ordered set
    // of instance keys.
    std::set<int32> instance_key_set;
    for (Node* node : optimized_graph->nodes()) {
      if (node->IsCollective()) {
        int32 instance_key;
        TF_RETURN_IF_ERROR(
            GetNodeAttr(node->attrs(), "instance_key", &instance_key));
        instance_key_set.emplace(instance_key);
      } else {
        const FunctionDef* fdef = optimized_flib->Find(node->def().op());
        if (fdef != nullptr) {
          for (const NodeDef& ndef : fdef->node_def()) {
            if (ndef.op() == "CollectiveReduce" ||
                ndef.op() == "CollectiveBcastSend" ||
                ndef.op() == "CollectiveBcastRecv") {
              int32 instance_key;
              TF_RETURN_IF_ERROR(
                  GetNodeAttr(ndef, "instance_key", &instance_key));
              instance_key_set.emplace(instance_key);
            }
          }
        }
      }
    }
    if (!instance_key_set.empty()) {
      uint64 hash = 0x8774aa605c729c72ULL;
      for (int32 instance_key : instance_key_set) {
        hash = Hash64Combine(instance_key, hash);
      }
      collective_graph_key = hash;
    }
  }

  // Make collective execution order deterministic if needed.
  if (options.collective_order != GraphCollectiveOrder::kNone) {
    TF_RETURN_IF_ERROR(
        OrderCollectives(optimized_graph.get(), options.collective_order));
  }

  // Copy the extracted graph in order to make its node ids dense,
  // since the local CostModel used to record its stats is sized by
  // the largest node id.
  std::unique_ptr<ClientGraph> dense_copy(
      new ClientGraph(std::move(optimized_flib), rewrite_metadata.feed_types,
                      rewrite_metadata.fetch_types, collective_graph_key));
  CopyGraph(*optimized_graph, &dense_copy->graph);

  // TODO(vrv): We should check invariants of the graph here.
  metrics::UpdateGraphBuildTime(Env::Default()->NowMicros() - start_time_usecs);
  *out = std::move(dense_copy);
  return Status::OK();
}

}  // namespace tensorflow