/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/core/grappler/grappler_item_builder.h"
16
17 #include <unordered_map>
18 #include <unordered_set>
19 #include <vector>
20
21 #include "tensorflow/core/common_runtime/device.h"
22 #include "tensorflow/core/common_runtime/device_factory.h"
23 #include "tensorflow/core/common_runtime/device_mgr.h"
24 #include "tensorflow/core/common_runtime/function.h"
25 #include "tensorflow/core/common_runtime/graph_optimizer.h"
26 #include "tensorflow/core/framework/attr_value.pb.h"
27 #include "tensorflow/core/framework/function.h"
28 #include "tensorflow/core/framework/function.pb.h"
29 #include "tensorflow/core/framework/graph_def_util.h"
30 #include "tensorflow/core/framework/node_def.pb.h"
31 #include "tensorflow/core/framework/op.h"
32 #include "tensorflow/core/framework/tensor.pb.h"
33 #include "tensorflow/core/framework/tensor_shape.pb.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/framework/variable.pb.h"
36 #include "tensorflow/core/framework/versions.pb.h"
37 #include "tensorflow/core/graph/graph_constructor.h"
38 #include "tensorflow/core/grappler/inputs/utils.h"
39 #include "tensorflow/core/grappler/op_types.h"
40 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
41 #include "tensorflow/core/grappler/utils.h"
42 #include "tensorflow/core/lib/gtl/map_util.h"
43 #include "tensorflow/core/lib/io/path.h"
44 #include "tensorflow/core/platform/protobuf_internal.h"
45 #include "tensorflow/core/protobuf/meta_graph.pb.h"
46 #include "tensorflow/core/protobuf/saver.pb.h"
47 #include "tensorflow/core/public/session_options.h"
48
49 namespace tensorflow {
50 namespace grappler {
51
52 namespace {
53
InitializeTensor(DataType type,Tensor * tensor)54 void InitializeTensor(DataType type, Tensor* tensor) {
55 const int period = 7;
56 if (type == DT_FLOAT) {
57 auto flat = tensor->flat<float>();
58 // Populate numbers 0, 0.1, 0.2, ..., 0.5, 0.6, 0, 0.1, 0.2, ...
59 for (int i = 0; i < flat.size(); i++) {
60 flat(i) = static_cast<float>(i % period) / 10.0f;
61 }
62 } else if (type == DT_INT64) {
63 auto flat = tensor->flat<int64>();
64 // Populate numbers 0, 1, 2, ..., 5, 6, 0, 1, 2, ...
65 for (int i = 0; i < flat.size(); i++) {
66 flat(i) = i % period;
67 }
68 } else if (type != DT_STRING && type != DT_RESOURCE && type != DT_VARIANT) {
69 // DT_STRING, DT_RESOURCE and DT_VARIANT are not simple types according to
70 // is_simple_type<> in tensorflow/core/framework/type_traits.h, and
71 // Allocator will run non-trivial constructor/destructor for a Tensor with
72 // one of these types, so we should not memset its buffer.
73 memset(const_cast<char*>(tensor->tensor_data().data()), 0,
74 tensor->tensor_data().size());
75 }
76 }
77
78 // Applies the same graph pruning logic to the graph as Session.Run in TF.
79 // If the returned status is not OK, item state may be inconsistent.
PruneGraph(GrapplerItem * item)80 Status PruneGraph(GrapplerItem* item) {
81 ModelPruner pruner;
82 GraphDef pruned_graph;
83 Cluster* cluster = nullptr; // ModelPruner doesn't check cluster.
84 TF_RETURN_IF_ERROR(pruner.Optimize(cluster, *item, &pruned_graph));
85 item->graph = std::move(pruned_graph);
86 return Status::OK();
87 }
88
89 // Replace any unknown dimensions in a shape with
90 // cfg.placeholder_unknown_output_shape_dim if it is no less than 0.
ReplaceUnknownShapeDim(const ItemConfig & cfg,const TensorShapeProto & shape_pb_in,TensorShapeProto * shape_pb_out,TensorShape * shape_out)91 Status ReplaceUnknownShapeDim(const ItemConfig& cfg,
92 const TensorShapeProto& shape_pb_in,
93 TensorShapeProto* shape_pb_out,
94 TensorShape* shape_out) {
95 std::vector<int32> dims;
96 for (const auto& dim_proto : shape_pb_in.dim()) {
97 if (cfg.placeholder_unknown_output_shape_dim >= 0 &&
98 dim_proto.size() == -1) {
99 dims.push_back(cfg.placeholder_unknown_output_shape_dim);
100 shape_pb_out->add_dim()->set_size(
101 cfg.placeholder_unknown_output_shape_dim);
102 } else {
103 dims.push_back(std::max<int32>(1, dim_proto.size()));
104 shape_pb_out->add_dim()->set_size(dim_proto.size());
105 }
106 }
107 return TensorShapeUtils::MakeShape(dims.data(), dims.size(), shape_out);
108 }
109
110 // Replace unknown dimensions in Placeholder shape if
111 // cfg.placeholder_unknown_output_shape_dim is set or
112 // the Placeholder node has _output_shapes.
113 // Otherwise keep it intact to keep compatible with shape annotation
114 // (b/134092018).
Status UpdatePlaceholderShape(
    const ItemConfig& cfg,
    const std::unordered_set<string>& signature_feed_nodes,
    GrapplerItem* new_item, NodeDef* node) {
  // Without a dtype we cannot build a fake feed tensor for this placeholder.
  if (node->attr().count("dtype") == 0) {
    return errors::Internal("Unknown type for placeholder ", node->name(),
                            ", skipping this input");
  }
  DataType type = node->attr().at("dtype").type();

  // TODO(andiryxu): Consider cfg.placeholder_unknown_output_shape_dim >= 0 and
  // _output_shapes is present case.
  if (node->attr().count("shape") == 0) {
    return errors::Internal("Unknown shape for placeholder ", node->name(),
                            ", skipping this input");
  }

  // Replace all unknown dimensions in the placeholder's tensorshape proto
  // with cfg.placeholder_unknown_output_shape_dim and create a tensorshape
  // from it. We do this because in newer protos, the input placeholder
  // shape is not empty if the shape is partially defined.
  TensorShape shape;
  TensorShapeProto shape_proto;
  Status make_shape_status = ReplaceUnknownShapeDim(
      cfg, node->attr().at("shape").shape(), &shape_proto, &shape);
  if (!make_shape_status.ok()) {
    return errors::Internal("Invalid shape for placeholder ", node->name(),
                            ": ", make_shape_status, ", skipping this input");
  }

  // Some placeholder nodes have a mis-match between the node
  // attribute "shape" and a different node attribute "_output_shapes".
  // Specifically, a shape with shape.dims() == 0 could indicate either
  // a scalar or an unknown shape. In those cases, we check _output_shapes
  // for additional information.
  // This case is observed in the bnmt graphs. Have not observed any
  // cases where there was more than 1 _output_shapes, so limit it
  // to cases where there is only 1 _output_shapes.
  // We only do this if cfg.placeholder_unknown_output_shape_dim has
  // been set to avoid crashing non-BNMT graphs.
  // TODO(andiryxu): Investigate if this is a bug in BNMT graph.
  if ((cfg.placeholder_unknown_output_shape_dim >= 0) && (shape.dims() == 0) &&
      (node->attr().count("_output_shapes") == 1)) {
    const auto& output_shapes =
        node->attr().at("_output_shapes").list().shape(0);

    if (output_shapes.dim_size() != 0) {
      // _output_shapes carries real rank information: rebuild both the
      // concrete shape and the proto from it, substituting the configured
      // value for any unknown (-1) dimension.
      shape.Clear();
      shape_proto.clear_dim();

      for (const auto& dim : output_shapes.dim()) {
        auto size = dim.size();
        if (size == -1) size = cfg.placeholder_unknown_output_shape_dim;
        shape.AddDim(size);
        shape_proto.add_dim()->set_size(size);
      }
    }
  }

  // Build a deterministic fake tensor of the resolved shape to use as feed
  // data for this placeholder.
  Tensor fake_input(type, shape);
  InitializeTensor(type, &fake_input);

  if (cfg.feed_nodes.empty()) {
    // No specific feed nodes were given. Assume all placeholders are fed,
    // except those already registered from signature defs.
    if (signature_feed_nodes.count(node->name()) == 0) {
      new_item->feed.emplace_back(node->name(), fake_input);
    }
  } else if (cfg.feed_nodes.count(node->name()) > 0) {
    // If specific feed nodes were given, only update their tensors.
    // The entry is expected to exist already (inserted from cfg.feed_nodes).
    auto it = find_if(new_item->feed.begin(), new_item->feed.end(),
                      [&node](std::pair<string, Tensor>& f) {
                        return f.first == node->name();
                      });
    DCHECK(it != new_item->feed.end());
    it->second = fake_input;
  }

  // Set the shape of the node in the graph. This is needed for statically
  // inferring shapes and is a no-op when dynamically inferring shapes as
  // the Placeholder shape will match the shape passed from new_item->feed.
  // Only replace node shape with known shape. For unknown shape keep it intact
  // (b/134092018).
  if (!shape_proto.dim().empty())
    *(node->mutable_attr()->at("shape").mutable_shape()) = shape_proto;

  return Status::OK();
}
202
203 } // namespace
204
// Runs TF's runtime graph optimizations (function inlining, L1 passes) over
// `graph_def_arg` on a CPU-only device setup and writes the result to
// `output_graph_def`. The two may alias. Returns non-OK if device creation,
// graph construction, or default-attr restoration fails.
Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
                             GraphDef* output_graph_def,
                             const ItemConfig& cfg) {
  // This is a temporary change that optimizes the graph in context of a single
  // gpu machine. Down the line, we may want to make grappler_item_builder aware
  // of the cluster type (E.g: single cpu, multiple gpu, etc) being simulated
  // in order to get the correct session options and environment, and performing
  // the correct optimizations.

  // Return input as is if no graph-modifying config is set.
  if (!cfg.apply_optimizations && !cfg.inline_functions &&
      !cfg.erase_noinline_attributes) {
    // Avoid a self-copy when caller passed the same proto for input and
    // output.
    if (output_graph_def != &graph_def_arg) {
      *output_graph_def = graph_def_arg;
    }
    return Status::OK();
  }

  // Create a session option for a single GPU device.
  SessionOptions options;

  // Make a local copy of graph def, because we need to change some things.
  GraphDef graph_def(graph_def_arg);

  if (cfg.erase_noinline_attributes) {
    // TF optimizer doesn't inline functions with "_noinline" attribute,
    // so let's go over the function library and erase it.
    for (auto& func : *graph_def.mutable_library()->mutable_function()) {
      func.mutable_attr()->erase("_noinline");
    }
  }

  // Instantiate all variables for function library runtime creation.
  std::vector<std::unique_ptr<Device>> devices;
  // Only CPU device is used so instead of calling DeviceFactory::AddDevices()
  // with dummy session config, which will conflict with user defined options
  // and create unwanted devices, call cpu_factory->CreateDevices() to get CPU
  // only devices.
  DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
  TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
      options, "/job:localhost/replica:0/task:0", &devices));
  Device* cpu_device = devices[0].get();
  // The device manager takes ownership of the devices; cpu_device remains a
  // borrowed pointer into dvc_mgr.
  auto dvc_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
  FunctionLibraryDefinition function_library(OpRegistry::Global(),
                                             graph_def.library());
  Env* env = Env::Default();

  // Optimizer options: L1 and inlining. L1 is default.
  OptimizerOptions* optimizer_opts =
      options.config.mutable_graph_options()->mutable_optimizer_options();
  if (cfg.apply_optimizations) {
    optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions::L1);
  } else {
    optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions::L0);
  }
  optimizer_opts->set_do_function_inlining(cfg.inline_functions);

  // Create the function library runtime.
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
      new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env, &options.config,
                                        graph_def.versions().producer(),
                                        &function_library, *optimizer_opts));
  FunctionLibraryRuntime* flr = pflr->GetFLR(cpu_device->name());

  // Create the GraphOptimizer to optimize the graph def.
  GraphConstructorOptions graph_ctor_opts;
  graph_ctor_opts.allow_internal_ops = true;
  graph_ctor_opts.expect_device_spec = false;
  std::unique_ptr<Graph> graphptr(new Graph(function_library));

  // graph_def is consumed here; it must not be used below this point.
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
      graph_ctor_opts, std::move(graph_def), graphptr.get()));

  // Optimize the graph.
  ::tensorflow::GraphOptimizer optimizer(*optimizer_opts);
  optimizer.Optimize(flr, env, cpu_device, &graphptr, /*shape_map=*/nullptr);
  graphptr->ToGraphDef(output_graph_def);

  // The default values of attributes might have been stripped by the optimizer.
  // Add them back.
  return AddDefaultAttrsToGraphDef(output_graph_def, *graphptr->op_registry(),
                                   0, true);
}
288
GrapplerItemFromMetaGraphDef(const string & id,const MetaGraphDef & meta_graph,const ItemConfig & cfg)289 std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
290 const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg) {
291 if (id.empty()) {
292 LOG(ERROR) << "id must be non-empty.";
293 return nullptr;
294 }
295 std::unique_ptr<GrapplerItem> new_item(new GrapplerItem());
296 new_item->id = id;
297 new_item->graph = meta_graph.graph_def();
298
299 // Fill in feed nodes from config, if any provided.
300 for (const auto& feed_node : cfg.feed_nodes) {
301 const string feed_name = NodeName(feed_node);
302 new_item->feed.emplace_back(feed_name, Tensor());
303 }
304 for (const auto& fetch_node : cfg.fetch_nodes) {
305 new_item->fetch.emplace_back(NodeName(fetch_node));
306 }
307
308 // Attempt to detect the fetch node(s) if they were not set explicitly.
309 if (new_item->fetch.empty() &&
310 meta_graph.collection_def().count("train_op") > 0) {
311 const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
312 if (nodes.has_node_list()) {
313 for (const auto& node : nodes.node_list().value()) {
314 new_item->fetch.push_back(NodeName(node));
315 }
316 }
317 }
318
319 // Detect feed and fetch nodes from signature defs. Signatures may share same
320 // inputs or outputs.
321 std::unordered_set<string> signature_feed_nodes;
322 std::unordered_set<string> signature_fetch_nodes;
323 for (const auto& name_and_signature : meta_graph.signature_def()) {
324 for (const auto& name_and_input : name_and_signature.second.inputs()) {
325 const TensorInfo& input = name_and_input.second;
326 if (input.has_coo_sparse()) {
327 // Define the shapes following the comment of CooSparse.
328 // TODO(yuefengz): we probably want to use different dim values for the
329 // three tensors of a SparseTensor.
330 int64 dim = std::max(1, cfg.placeholder_unknown_output_shape_dim);
331 TensorShape shape_1d({dim});
332 TensorShape shape_2d({dim, dim});
333
334 if (gtl::InsertIfNotPresent(
335 &signature_feed_nodes,
336 NodeName(input.coo_sparse().values_tensor_name()))) {
337 Tensor value_tensor(input.dtype(), shape_1d);
338 InitializeTensor(input.dtype(), &value_tensor);
339 new_item->feed.emplace_back(
340 NodeName(input.coo_sparse().values_tensor_name()), value_tensor);
341 }
342 if (gtl::InsertIfNotPresent(
343 &signature_feed_nodes,
344 NodeName(input.coo_sparse().indices_tensor_name()))) {
345 Tensor indices_tensor(DT_INT64, shape_2d);
346 InitializeTensor(input.dtype(), &indices_tensor);
347 new_item->feed.emplace_back(
348 NodeName(input.coo_sparse().indices_tensor_name()),
349 indices_tensor);
350 }
351 if (gtl::InsertIfNotPresent(
352 &signature_feed_nodes,
353 NodeName(input.coo_sparse().dense_shape_tensor_name()))) {
354 Tensor dense_shape_tensor(DT_INT64, shape_1d);
355 InitializeTensor(input.dtype(), &dense_shape_tensor);
356 new_item->feed.emplace_back(
357 NodeName(input.coo_sparse().dense_shape_tensor_name()),
358 dense_shape_tensor);
359 }
360 } else {
361 if (gtl::InsertIfNotPresent(&signature_feed_nodes,
362 NodeName(input.name()))) {
363 TensorShape shape;
364 TensorShapeProto shape_proto;
365 Status s = ReplaceUnknownShapeDim(cfg, input.tensor_shape(),
366 &shape_proto, &shape);
367 if (!s.ok()) {
368 LOG(ERROR) << "Invalid shape for signature input " << input.name()
369 << ": " << s << ", skipping this input";
370 return nullptr;
371 }
372
373 Tensor fake_input(input.dtype(), shape);
374 InitializeTensor(input.dtype(), &fake_input);
375 new_item->feed.emplace_back(NodeName(input.name()), fake_input);
376 }
377 }
378 }
379 for (const auto& name_and_output : name_and_signature.second.outputs()) {
380 const TensorInfo& output = name_and_output.second;
381 if (output.has_coo_sparse()) {
382 if (gtl::InsertIfNotPresent(
383 &signature_fetch_nodes,
384 NodeName(output.coo_sparse().values_tensor_name()))) {
385 new_item->fetch.push_back(
386 NodeName(output.coo_sparse().values_tensor_name()));
387 }
388 if (gtl::InsertIfNotPresent(
389 &signature_fetch_nodes,
390 NodeName(output.coo_sparse().indices_tensor_name()))) {
391 new_item->fetch.push_back(
392 NodeName(output.coo_sparse().indices_tensor_name()));
393 }
394 if (gtl::InsertIfNotPresent(
395 &signature_fetch_nodes,
396 NodeName(output.coo_sparse().dense_shape_tensor_name()))) {
397 new_item->fetch.push_back(
398 NodeName(output.coo_sparse().dense_shape_tensor_name()));
399 }
400 } else {
401 if (gtl::InsertIfNotPresent(&signature_fetch_nodes,
402 NodeName(output.name()))) {
403 new_item->fetch.push_back(NodeName(output.name()));
404 }
405 }
406 }
407 }
408
409 for (const auto& feed : new_item->feed) {
410 if (feed.first.empty()) {
411 LOG(ERROR) << "Invalid feed node name skipping this input";
412 return nullptr;
413 } else {
414 VLOG(1) << "Will use feed node " << feed.first;
415 }
416 }
417
418 for (const auto& fetch : new_item->fetch) {
419 if (fetch.empty()) {
420 LOG(ERROR) << "Invalid fetch node name skipping this input";
421 return nullptr;
422 } else {
423 VLOG(1) << "Will use fetch node " << fetch;
424 }
425 }
426
427 if (new_item->fetch.empty()) {
428 LOG(ERROR) << "Failed to detect the fetch node(s), skipping this input";
429 return nullptr;
430 }
431
432 // TODO(yuefengz): consider handling saved_model_main_op and legacy_init_op.
433 // The reason why they are difficult to handle is because they may not intend
434 // to initialize all variables that are required to run fetch nodes. We may
435 // have to run restore op first.
436
437 // Try to find initializers from variables and tables as init ops.
438 for (const string& var_collection :
439 {"variables", "local_variables", "model_variables",
440 "trainable_variables"}) {
441 if (meta_graph.collection_def().count(var_collection) == 0) {
442 continue;
443 }
444 const CollectionDef& vars = meta_graph.collection_def().at(var_collection);
445 for (const auto& raw_var : vars.bytes_list().value()) {
446 VariableDef var;
447 var.ParseFromString(raw_var);
448 if (!var.initializer_name().empty()) {
449 new_item->init_ops.push_back(NodeName(var.initializer_name()));
450 }
451 }
452 }
453
454 if (meta_graph.collection_def().count("table_initializer") > 0) {
455 const CollectionDef& inits =
456 meta_graph.collection_def().at("table_initializer");
457 if (inits.has_node_list()) {
458 for (const auto& node : inits.node_list().value()) {
459 new_item->init_ops.push_back(NodeName(node));
460 // Tables are initialized from files, which can take a long time. Add
461 // 30 minutes to the initialization time for each table to avoid
462 // timing out.
463 // TODO(bsteiner): adjust the timeout based on the file size.
464 new_item->expected_init_time += 30 * 60;
465 }
466 }
467 }
468
469 // We keep the mapping from asset node to asset files. This should have been
470 // used as feed but since asset node is usually a constant node, we will fill
471 // the values of these constant nodes with their actual asset file paths.
472 std::unordered_map<string, string> asset_node_to_value;
473
474 // Assets file may have changed their directory, we assemble their new paths
475 // if assets_directory_override is set. We also make sure we still can
476 // access these asset files.
477 if (!cfg.assets_directory_override.empty()) {
478 if (meta_graph.collection_def().count("saved_model_assets") > 0) {
479 const CollectionDef& collection =
480 meta_graph.collection_def().at("saved_model_assets");
481 const auto& any_assets = collection.any_list().value();
482 if (!any_assets.empty()) {
483 #ifndef TENSORFLOW_LITE_PROTOS
484 for (const auto& any_asset : any_assets) {
485 AssetFileDef asset_file_def;
486 if (!ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef")
487 .ok()) {
488 LOG(ERROR) << "Failed to parse AssetFile.";
489 continue;
490 }
491 string asset_filepath = io::JoinPath(cfg.assets_directory_override,
492 asset_file_def.filename());
493 if (!FilesExist({asset_filepath}, nullptr)) {
494 LOG(ERROR) << "Can't access one or more of the asset files "
495 << asset_filepath << ", skipping this input";
496 return nullptr;
497 }
498 asset_node_to_value[NodeName(asset_file_def.tensor_info().name())] =
499 asset_filepath;
500 }
501 #else
502 LOG(ERROR) << "Can't parse AssetFileDef on mobile.";
503 return nullptr;
504 #endif // TENSORFLOW_LITE_PROTOS
505 }
506 }
507 } else if (meta_graph.collection_def().count("asset_filepaths") > 0) {
508 const CollectionDef& file_paths =
509 meta_graph.collection_def().at("asset_filepaths");
510 std::vector<string> paths;
511 for (const auto& raw_path : file_paths.bytes_list().value()) {
512 paths.push_back(raw_path);
513 }
514 if (!FilesExist(paths, nullptr)) {
515 LOG(ERROR) << "Can't access one or more of the asset files, skipping "
516 "this input";
517 return nullptr;
518 }
519 }
520
521 if (meta_graph.collection_def().count("queue_runners") > 0) {
522 const CollectionDef& vars = meta_graph.collection_def().at("queue_runners");
523 for (const auto& raw : vars.bytes_list().value()) {
524 QueueRunnerDef queue_runner;
525 if (!queue_runner.ParseFromString(raw)) {
526 LOG(ERROR) << "Could not parse queue_runners, skipping this input";
527 return nullptr;
528 }
529 if (queue_runner.cancel_op_name().empty()) {
530 LOG(ERROR) << "Queue without a cancel op, skipping this input";
531 return nullptr;
532 }
533 new_item->queue_runners.push_back(queue_runner);
534 }
535 }
536
537 // Add each node referenced in a collection to the list of nodes to keep.
538 for (const auto& col : meta_graph.collection_def()) {
539 const CollectionDef& collection = col.second;
540 for (const string& node : collection.node_list().value()) {
541 new_item->keep_ops.push_back(NodeName(node));
542 }
543 }
544
545 for (auto& node : *new_item->graph.mutable_node()) {
546 if (IsPlaceholder(node) && node.op() != "PlaceholderWithDefault") {
547 Status s = UpdatePlaceholderShape(cfg, signature_feed_nodes,
548 new_item.get(), &node);
549 if (!s.ok()) return nullptr;
550 } else if (IsConstant(node)) {
551 auto it = asset_node_to_value.find(node.name());
552 if (it != asset_node_to_value.end()) {
553 auto iter = node.mutable_attr()->find("value");
554 if (iter == node.attr().end()) {
555 LOG(ERROR) << "Value attribute expected in const op for asset files";
556 return nullptr;
557 }
558 if (!iter->second.has_tensor() ||
559 iter->second.tensor().string_val_size() != 1) {
560 LOG(INFO) << "Unexpected AttrValue proto: "
561 << iter->second.DebugString();
562 return nullptr;
563 }
564 LOG(INFO) << "Using asset file " << it->second << " for node "
565 << node.name();
566 *(iter->second.mutable_tensor()->mutable_string_val(0)) = it->second;
567 }
568 }
569
570 // Erase the recorded result of any previous shape inference to start again
571 // from scratch.
572 node.mutable_attr()->erase("_output_shapes");
573
574 // Delete user specified placement if requested.
575 if (cfg.ignore_user_placement) {
576 node.clear_device();
577 }
578 // Delete colocation constraints if requested.
579 if (cfg.ignore_colocation) {
580 auto attr = node.mutable_attr();
581 auto it = attr->find("_class");
582 if (it != attr->end()) {
583 attr->erase(it);
584 }
585 }
586 }
587
588 if (meta_graph.collection_def().count("savers") > 0) {
589 const CollectionDef& savers = meta_graph.collection_def().at("savers");
590 for (const auto& raw : savers.bytes_list().value()) {
591 SaverDef saver;
592 // Skip bad savers since we don't need saves/restores to be able to run a
593 // graph.
594 if (!saver.ParseFromString(raw)) {
595 continue;
596 }
597 if (saver.filename_tensor_name().empty()) {
598 continue;
599 }
600 new_item->save_op = saver.save_tensor_name();
601 new_item->restore_op = saver.restore_op_name();
602 new_item->save_restore_loc_tensor = saver.filename_tensor_name();
603 // Only use the first saver since it's not clear what to do if there's
604 // more than one.
605 break;
606 }
607 } else {
608 const SaverDef& saver = meta_graph.saver_def();
609 new_item->save_op = saver.save_tensor_name();
610 new_item->restore_op = saver.restore_op_name();
611 new_item->save_restore_loc_tensor = saver.filename_tensor_name();
612 }
613
614 // Instantiate all the missing attributes with their default values.
615 Status attr_status = AddDefaultAttrsToGraphDef(
616 &new_item->graph,
617 FunctionLibraryDefinition(OpRegistry::Global(),
618 new_item->graph.library()),
619 0, true);
620 if (!attr_status.ok()) {
621 LOG(ERROR) << "Failed to instantiate default attribute values: "
622 << attr_status.error_message();
623 return nullptr;
624 }
625
626 // Optimize the graph (function inlining, l1 optimizations, etc).
627 VLOG(1) << "Number of nodes in graph before RuntimeGraphOptimizer: "
628 << new_item->graph.node_size();
629 Status optimize_status =
630 RuntimeGraphOptimizer(new_item->graph, &new_item->graph, cfg);
631 if (!optimize_status.ok()) {
632 LOG(ERROR) << "Graph preprocessing failed: " << optimize_status;
633 return nullptr;
634 }
635 VLOG(1) << "Number of nodes in graph after RuntimeGraphOptimizer: "
636 << new_item->graph.node_size();
637
638 if (cfg.prune_graph) {
639 VLOG(1) << "Pruning graph...";
640 auto status = PruneGraph(new_item.get());
641 if (!status.ok()) {
642 LOG(ERROR) << "Pruning failed: " << status.error_message();
643 return nullptr;
644 }
645 VLOG(1) << "Number of nodes in graph after pruning: "
646 << new_item->graph.node_size();
647 }
648
649 // Validate feed, fetch and init nodes
650 std::unordered_set<string> nodes;
651 for (const auto& node : new_item->graph.node()) {
652 nodes.insert(node.name());
653 }
654 for (const auto& feed : new_item->feed) {
655 if (nodes.find(feed.first) == nodes.end()) {
656 LOG(ERROR) << "Feed node " << feed.first << " doesn't exist in graph";
657 return nullptr;
658 }
659 }
660 for (const auto& fetch : new_item->fetch) {
661 if (nodes.find(fetch) == nodes.end()) {
662 LOG(ERROR) << "Fetch node " << fetch << " doesn't exist in graph";
663 return nullptr;
664 }
665 }
666 for (const auto& init : new_item->init_ops) {
667 if (nodes.find(init) == nodes.end()) {
668 LOG(ERROR) << "Init node " << init << " doesn't exist in graph";
669 return nullptr;
670 }
671 }
672 return new_item;
673 }
674
GrapplerItemFromMetaGraphDefFile(const string & id,const string & meta_graph_file,const ItemConfig & cfg)675 std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDefFile(
676 const string& id, const string& meta_graph_file, const ItemConfig& cfg) {
677 MetaGraphDef meta_graph;
678 if (!ReadMetaGraphDefFromFile(meta_graph_file, &meta_graph).ok()) {
679 LOG(ERROR) << "Failed to read " << meta_graph_file;
680 return nullptr;
681 }
682 return GrapplerItemFromMetaGraphDef(id, meta_graph, cfg);
683 }
684
685 } // end namespace grappler
686 } // end namespace tensorflow
687