1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/grappler/grappler_item_builder.h"
16
#include <algorithm>
#include <cstring>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
21
22 #include "tensorflow/core/common_runtime/device.h"
23 #include "tensorflow/core/common_runtime/device_factory.h"
24 #include "tensorflow/core/common_runtime/device_mgr.h"
25 #include "tensorflow/core/common_runtime/function.h"
26 #include "tensorflow/core/common_runtime/graph_constructor.h"
27 #include "tensorflow/core/common_runtime/graph_optimizer.h"
28 #include "tensorflow/core/framework/attr_value.pb.h"
29 #include "tensorflow/core/framework/function.h"
30 #include "tensorflow/core/framework/function.pb.h"
31 #include "tensorflow/core/framework/graph_def_util.h"
32 #include "tensorflow/core/framework/node_def.pb.h"
33 #include "tensorflow/core/framework/op.h"
34 #include "tensorflow/core/framework/tensor.pb.h"
35 #include "tensorflow/core/framework/tensor_shape.pb.h"
36 #include "tensorflow/core/framework/types.pb.h"
37 #include "tensorflow/core/framework/variable.pb.h"
38 #include "tensorflow/core/framework/versions.pb.h"
39 #include "tensorflow/core/grappler/inputs/utils.h"
40 #include "tensorflow/core/grappler/op_types.h"
41 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
42 #include "tensorflow/core/grappler/utils.h"
43 #include "tensorflow/core/lib/gtl/map_util.h"
44 #include "tensorflow/core/lib/io/path.h"
45 #include "tensorflow/core/platform/protobuf_internal.h"
46 #include "tensorflow/core/protobuf/meta_graph.pb.h"
47 #include "tensorflow/core/protobuf/saver.pb.h"
48 #include "tensorflow/core/public/session_options.h"
49
50 namespace tensorflow {
51 namespace grappler {
52
53 namespace {
54
InitializeTensor(DataType type,Tensor * tensor)55 void InitializeTensor(DataType type, Tensor* tensor) {
56 const int period = 7;
57 if (type == DT_FLOAT) {
58 auto flat = tensor->flat<float>();
59 // Populate numbers 0, 0.1, 0.2, ..., 0.5, 0.6, 0, 0.1, 0.2, ...
60 for (int i = 0; i < flat.size(); i++) {
61 flat(i) = static_cast<float>(i % period) / 10.0f;
62 }
63 } else if (type == DT_INT64) {
64 auto flat = tensor->flat<int64>();
65 // Populate numbers 0, 1, 2, ..., 5, 6, 0, 1, 2, ...
66 for (int i = 0; i < flat.size(); i++) {
67 flat(i) = i % period;
68 }
69 } else if (type != DT_STRING && type != DT_RESOURCE && type != DT_VARIANT) {
70 // DT_STRING, DT_RESOURCE and DT_VARIANT are not simple types according to
71 // is_simple_type<> in tensorflow/core/framework/type_traits.h, and
72 // Allocator will run non-trivial constructor/destructor for a Tensor with
73 // one of these types, so we should not memset its buffer.
74 memset(const_cast<char*>(tensor->tensor_data().data()), 0,
75 tensor->tensor_data().size());
76 }
77 }
78
79 // Applies the same graph pruning logic to the graph as Session.Run in TF.
80 // If the returned status is not OK, item state may be inconsistent.
PruneGraph(GrapplerItem * item)81 Status PruneGraph(GrapplerItem* item) {
82 ModelPruner pruner;
83 GraphDef pruned_graph;
84 Cluster* cluster = nullptr; // ModelPruner doesn't check cluster.
85 TF_RETURN_IF_ERROR(pruner.Optimize(cluster, *item, &pruned_graph));
86 item->graph = std::move(pruned_graph);
87 return Status::OK();
88 }
89
90 // Replace any unknown dimensions in a shape with
91 // cfg.placeholder_unknown_output_shape_dim if it is no less than 0.
ReplaceUnknownShapeDim(const ItemConfig & cfg,const TensorShapeProto & shape_pb_in,TensorShapeProto * shape_pb_out,TensorShape * shape_out)92 Status ReplaceUnknownShapeDim(const ItemConfig& cfg,
93 const TensorShapeProto& shape_pb_in,
94 TensorShapeProto* shape_pb_out,
95 TensorShape* shape_out) {
96 std::vector<int32> dims;
97 for (const auto& dim_proto : shape_pb_in.dim()) {
98 if (cfg.placeholder_unknown_output_shape_dim >= 0 &&
99 dim_proto.size() == -1) {
100 dims.push_back(cfg.placeholder_unknown_output_shape_dim);
101 shape_pb_out->add_dim()->set_size(
102 cfg.placeholder_unknown_output_shape_dim);
103 } else {
104 dims.push_back(std::max<int32>(1, dim_proto.size()));
105 shape_pb_out->add_dim()->set_size(dim_proto.size());
106 }
107 }
108 return TensorShapeUtils::MakeShape(dims.data(), dims.size(), shape_out);
109 }
110
111 // Replace unknown dimensions in Placeholder shape if
112 // cfg.placeholder_unknown_output_shape_dim is set or
113 // the Placeholder node has _output_shapes.
114 // Otherwise keep it intact to keep compatible with shape annotation
115 // (b/134092018).
UpdatePlaceholderShape(const ItemConfig & cfg,const std::unordered_set<string> & signature_feed_nodes,GrapplerItem * new_item,NodeDef * node)116 Status UpdatePlaceholderShape(
117 const ItemConfig& cfg,
118 const std::unordered_set<string>& signature_feed_nodes,
119 GrapplerItem* new_item, NodeDef* node) {
120 if (node->attr().count("dtype") == 0) {
121 return errors::Internal("Unknown type for placeholder ", node->name(),
122 ", skipping this input");
123 }
124 DataType type = node->attr().at("dtype").type();
125
126 // TODO(andiryxu): Consider cfg.placeholder_unknown_output_shape_dim >= 0 and
127 // _output_shapes is present case.
128 if (node->attr().count("shape") == 0) {
129 return errors::Internal("Unknown shape for placeholder ", node->name(),
130 ", skipping this input");
131 }
132
133 // Replace all unknown dimensions in the placeholder's tensorshape proto
134 // with cfg.placeholder_unknown_output_shape_dim and create a tensorshape
135 // from it. We do this because in newer protos, the input placeholder
136 // shape is not empty if the shape is partially defined.
137 TensorShape shape;
138 TensorShapeProto shape_proto;
139 Status make_shape_status = ReplaceUnknownShapeDim(
140 cfg, node->attr().at("shape").shape(), &shape_proto, &shape);
141 if (!make_shape_status.ok()) {
142 return errors::Internal("Invalid shape for placeholder ", node->name(),
143 ": ", make_shape_status, ", skipping this input");
144 }
145
146 // Some placeholder nodes have a mis-match between the node
147 // attribute "shape" and a different node attribute "_output_shapes".
148 // Specifically, a shape with shape.dims() == 0 could indicate either
149 // a scalar or an unknown shape. In those cases, we check _output_shapes
150 // for additional information.
151 // This case is observed in the bnmt graphs. Have not observed any
152 // cases where there was more than 1 _output_shapes, so limit it
153 // to cases where there is only 1 _output_shapes.
154 // We only do this if cfg.placeholder_unknown_output_shape_dim has
155 // been set to avoid crashing non-BNMT graphs.
156 // TODO(andiryxu): Investigate if this is a bug in BNMT graph.
157 if ((cfg.placeholder_unknown_output_shape_dim >= 0) && (shape.dims() == 0) &&
158 (node->attr().count("_output_shapes") == 1)) {
159 const auto& output_shapes =
160 node->attr().at("_output_shapes").list().shape(0);
161
162 if (output_shapes.dim_size() != 0) {
163 shape.Clear();
164 shape_proto.clear_dim();
165
166 for (const auto& dim : output_shapes.dim()) {
167 auto size = dim.size();
168 if (size == -1) size = cfg.placeholder_unknown_output_shape_dim;
169 shape.AddDim(size);
170 shape_proto.add_dim()->set_size(size);
171 }
172 }
173 }
174
175 Tensor fake_input(type, shape);
176 InitializeTensor(type, &fake_input);
177
178 if (cfg.feed_nodes.empty()) {
179 // No specific feed nodes were given. Assume all placeholders are fed.
180 if (signature_feed_nodes.count(node->name()) == 0) {
181 new_item->feed.emplace_back(node->name(), fake_input);
182 }
183 } else if (cfg.feed_nodes.count(node->name()) > 0) {
184 // If specific feed nodes were given, only update their tensors.
185 auto it = find_if(new_item->feed.begin(), new_item->feed.end(),
186 [&node](std::pair<string, Tensor>& f) {
187 return f.first == node->name();
188 });
189 DCHECK(it != new_item->feed.end());
190 it->second = fake_input;
191 }
192
193 // Set the shape of the node in the graph. This is needed for statically
194 // inferring shapes and is a no-op when dynamically inferring shapes as
195 // the Placeholder shape will match the shape passed from new_item->feed.
196 // Only replace node shape with known shape. For unknown shape keep it intact
197 // (b/134092018).
198 if (!shape_proto.dim().empty())
199 *(node->mutable_attr()->at("shape").mutable_shape()) = shape_proto;
200
201 return Status::OK();
202 }
203
204 } // namespace
205
RuntimeGraphOptimizer(const GraphDef & graph_def_arg,GraphDef * output_graph_def,const ItemConfig & cfg)206 Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
207 GraphDef* output_graph_def,
208 const ItemConfig& cfg) {
209 // This is a temporary change that optimizes the graph in context of a single
210 // gpu machine. Down the line, we may want to make grappler_item_builder aware
211 // of the cluster type (E.g: single cpu, multiple gpu, etc) being simulated
212 // in order to get the correct session options and environment, and performing
213 // the correct optimizations.
214
215 // Return input as is if no graph-modifying config is set.
216 if (!cfg.apply_optimizations && !cfg.inline_functions &&
217 !cfg.erase_noinline_attributes) {
218 if (output_graph_def != &graph_def_arg) {
219 *output_graph_def = graph_def_arg;
220 }
221 return Status::OK();
222 }
223
224 // Create a session option for a single GPU device.
225 SessionOptions options;
226
227 // Make a local copy of graph def, because we need to change some things.
228 GraphDef graph_def(graph_def_arg);
229
230 if (cfg.erase_noinline_attributes) {
231 // TF optimizer doesn't inline functions with "_noinline" attribute,
232 // so let's go over the function library and erase it.
233 for (auto& func : *graph_def.mutable_library()->mutable_function()) {
234 func.mutable_attr()->erase("_noinline");
235 }
236 }
237
238 // Instantiate all variables for function library runtime creation.
239 std::vector<std::unique_ptr<Device>> devices;
240 // Only CPU device is used so instead of calling DeviceFactory::AddDevices()
241 // with dummy session config, which will conflict with user defined options
242 // and create unwanted devices, call cpu_factory->CreateDevices() to get CPU
243 // only devices.
244 DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
245 TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
246 options, "/job:localhost/replica:0/task:0", &devices));
247 Device* cpu_device = devices[0].get();
248 auto dvc_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
249 FunctionLibraryDefinition function_library(OpRegistry::Global(),
250 graph_def.library());
251 Env* env = Env::Default();
252
253 // Optimizer options: L1 and inlining. L1 is default.
254 OptimizerOptions* optimizer_opts =
255 options.config.mutable_graph_options()->mutable_optimizer_options();
256 if (cfg.apply_optimizations) {
257 optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions::L1);
258 } else {
259 optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions::L0);
260 }
261 optimizer_opts->set_do_function_inlining(cfg.inline_functions);
262
263 // Create the function library runtime.
264 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
265 new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env, &options.config,
266 graph_def.versions().producer(),
267 &function_library, *optimizer_opts));
268 FunctionLibraryRuntime* flr = pflr->GetFLR(cpu_device->name());
269
270 // Create the GraphOptimizer to optimize the graph def.
271 GraphConstructorOptions graph_ctor_opts;
272 graph_ctor_opts.allow_internal_ops = true;
273 graph_ctor_opts.expect_device_spec = false;
274 std::unique_ptr<Graph> graphptr(new Graph(function_library));
275
276 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
277 graph_ctor_opts, std::move(graph_def), graphptr.get()));
278
279 // Optimize the graph.
280 ::tensorflow::GraphOptimizer optimizer(*optimizer_opts);
281 optimizer.Optimize(flr, env, cpu_device, &graphptr, /*shape_map=*/nullptr);
282 graphptr->ToGraphDef(output_graph_def);
283
284 // The default values of attributes might have been stripped by the optimizer.
285 // Add them back.
286 return AddDefaultAttrsToGraphDef(output_graph_def, *graphptr->op_registry(),
287 0, true);
288 }
289
GrapplerItemFromMetaGraphDef(const string & id,const MetaGraphDef & meta_graph,const ItemConfig & cfg)290 std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
291 const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg) {
292 if (id.empty()) {
293 LOG(ERROR) << "id must be non-empty.";
294 return nullptr;
295 }
296 std::unique_ptr<GrapplerItem> new_item(new GrapplerItem());
297 new_item->id = id;
298 new_item->graph = meta_graph.graph_def();
299
300 // Fill in feed nodes from config, if any provided.
301 for (const auto& feed_node : cfg.feed_nodes) {
302 const string feed_name = NodeName(feed_node);
303 new_item->feed.emplace_back(feed_name, Tensor());
304 }
305 for (const auto& fetch_node : cfg.fetch_nodes) {
306 new_item->fetch.emplace_back(NodeName(fetch_node));
307 }
308
309 // Attempt to detect the fetch node(s) if they were not set explicitly.
310 if (new_item->fetch.empty() &&
311 meta_graph.collection_def().count("train_op") > 0) {
312 const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
313 if (nodes.has_node_list()) {
314 for (const auto& node : nodes.node_list().value()) {
315 new_item->fetch.push_back(NodeName(node));
316 }
317 }
318 }
319
320 // Detect feed and fetch nodes from signature defs. Signatures may share same
321 // inputs or outputs.
322 std::unordered_set<string> signature_feed_nodes;
323 std::unordered_set<string> signature_fetch_nodes;
324 for (const auto& name_and_signature : meta_graph.signature_def()) {
325 for (const auto& name_and_input : name_and_signature.second.inputs()) {
326 const TensorInfo& input = name_and_input.second;
327 if (input.has_coo_sparse()) {
328 // Define the shapes following the comment of CooSparse.
329 // TODO(yuefengz): we probably want to use different dim values for the
330 // three tensors of a SparseTensor.
331 int64 dim = std::max(1, cfg.placeholder_unknown_output_shape_dim);
332 TensorShape shape_1d({dim});
333 TensorShape shape_2d({dim, dim});
334
335 if (gtl::InsertIfNotPresent(
336 &signature_feed_nodes,
337 NodeName(input.coo_sparse().values_tensor_name()))) {
338 Tensor value_tensor(input.dtype(), shape_1d);
339 InitializeTensor(input.dtype(), &value_tensor);
340 new_item->feed.emplace_back(
341 NodeName(input.coo_sparse().values_tensor_name()), value_tensor);
342 }
343 if (gtl::InsertIfNotPresent(
344 &signature_feed_nodes,
345 NodeName(input.coo_sparse().indices_tensor_name()))) {
346 Tensor indices_tensor(DT_INT64, shape_2d);
347 InitializeTensor(input.dtype(), &indices_tensor);
348 new_item->feed.emplace_back(
349 NodeName(input.coo_sparse().indices_tensor_name()),
350 indices_tensor);
351 }
352 if (gtl::InsertIfNotPresent(
353 &signature_feed_nodes,
354 NodeName(input.coo_sparse().dense_shape_tensor_name()))) {
355 Tensor dense_shape_tensor(DT_INT64, shape_1d);
356 InitializeTensor(input.dtype(), &dense_shape_tensor);
357 new_item->feed.emplace_back(
358 NodeName(input.coo_sparse().dense_shape_tensor_name()),
359 dense_shape_tensor);
360 }
361 } else {
362 if (gtl::InsertIfNotPresent(&signature_feed_nodes,
363 NodeName(input.name()))) {
364 TensorShape shape;
365 TensorShapeProto shape_proto;
366 Status s = ReplaceUnknownShapeDim(cfg, input.tensor_shape(),
367 &shape_proto, &shape);
368 if (!s.ok()) {
369 LOG(ERROR) << "Invalid shape for signature input " << input.name()
370 << ": " << s << ", skipping this input";
371 return nullptr;
372 }
373
374 Tensor fake_input(input.dtype(), shape);
375 InitializeTensor(input.dtype(), &fake_input);
376 new_item->feed.emplace_back(NodeName(input.name()), fake_input);
377 }
378 }
379 }
380 for (const auto& name_and_output : name_and_signature.second.outputs()) {
381 const TensorInfo& output = name_and_output.second;
382 if (output.has_coo_sparse()) {
383 if (gtl::InsertIfNotPresent(
384 &signature_fetch_nodes,
385 NodeName(output.coo_sparse().values_tensor_name()))) {
386 new_item->fetch.push_back(
387 NodeName(output.coo_sparse().values_tensor_name()));
388 }
389 if (gtl::InsertIfNotPresent(
390 &signature_fetch_nodes,
391 NodeName(output.coo_sparse().indices_tensor_name()))) {
392 new_item->fetch.push_back(
393 NodeName(output.coo_sparse().indices_tensor_name()));
394 }
395 if (gtl::InsertIfNotPresent(
396 &signature_fetch_nodes,
397 NodeName(output.coo_sparse().dense_shape_tensor_name()))) {
398 new_item->fetch.push_back(
399 NodeName(output.coo_sparse().dense_shape_tensor_name()));
400 }
401 } else {
402 if (gtl::InsertIfNotPresent(&signature_fetch_nodes,
403 NodeName(output.name()))) {
404 new_item->fetch.push_back(NodeName(output.name()));
405 }
406 }
407 }
408 }
409
410 for (const auto& feed : new_item->feed) {
411 if (feed.first.empty()) {
412 LOG(ERROR) << "Invalid feed node name skipping this input";
413 return nullptr;
414 } else {
415 VLOG(1) << "Will use feed node " << feed.first;
416 }
417 }
418
419 for (const auto& fetch : new_item->fetch) {
420 if (fetch.empty()) {
421 LOG(ERROR) << "Invalid fetch node name skipping this input";
422 return nullptr;
423 } else {
424 VLOG(1) << "Will use fetch node " << fetch;
425 }
426 }
427
428 if (new_item->fetch.empty()) {
429 LOG(ERROR) << "Failed to detect the fetch node(s), skipping this input";
430 return nullptr;
431 }
432
433 // TODO(yuefengz): consider handling saved_model_main_op and legacy_init_op.
434 // The reason why they are difficult to handle is because they may not intend
435 // to initialize all variables that are required to run fetch nodes. We may
436 // have to run restore op first.
437
438 // Try to find initializers from variables and tables as init ops.
439 for (const string& var_collection :
440 {"variables", "local_variables", "model_variables",
441 "trainable_variables"}) {
442 if (meta_graph.collection_def().count(var_collection) == 0) {
443 continue;
444 }
445 const CollectionDef& vars = meta_graph.collection_def().at(var_collection);
446 for (const auto& raw_var : vars.bytes_list().value()) {
447 VariableDef var;
448 var.ParseFromString(raw_var);
449 if (!var.initializer_name().empty()) {
450 new_item->init_ops.push_back(NodeName(var.initializer_name()));
451 }
452 }
453 }
454
455 if (meta_graph.collection_def().count("table_initializer") > 0) {
456 const CollectionDef& inits =
457 meta_graph.collection_def().at("table_initializer");
458 if (inits.has_node_list()) {
459 for (const auto& node : inits.node_list().value()) {
460 new_item->init_ops.push_back(NodeName(node));
461 // Tables are initialized from files, which can take a long time. Add
462 // 30 minutes to the initialization time for each table to avoid
463 // timing out.
464 // TODO(bsteiner): adjust the timeout based on the file size.
465 new_item->expected_init_time += 30 * 60;
466 }
467 }
468 }
469
470 // We keep the mapping from asset node to asset files. This should have been
471 // used as feed but since asset node is usually a constant node, we will fill
472 // the values of these constant nodes with their actual asset file paths.
473 std::unordered_map<string, string> asset_node_to_value;
474
475 // Assets file may have changed their directory, we assemble their new paths
476 // if assets_directory_override is set. We also make sure we still can
477 // access these asset files.
478 if (!cfg.assets_directory_override.empty()) {
479 if (meta_graph.collection_def().count("saved_model_assets") > 0) {
480 const CollectionDef& collection =
481 meta_graph.collection_def().at("saved_model_assets");
482 const auto& any_assets = collection.any_list().value();
483 if (!any_assets.empty()) {
484 if (std::is_base_of<protobuf::Message, AssetFileDef>()) {
485 for (const auto& any_asset : any_assets) {
486 AssetFileDef asset_file_def;
487 if (!ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef")
488 .ok()) {
489 LOG(ERROR) << "Failed to parse AssetFile.";
490 continue;
491 }
492 string asset_filepath = io::JoinPath(cfg.assets_directory_override,
493 asset_file_def.filename());
494 if (!FilesExist({asset_filepath}, nullptr)) {
495 LOG(ERROR) << "Can't access one or more of the asset files "
496 << asset_filepath << ", skipping this input";
497 return nullptr;
498 }
499 asset_node_to_value[NodeName(asset_file_def.tensor_info().name())] =
500 asset_filepath;
501 }
502 } else {
503 LOG(ERROR) << "Can't parse AssetFileDef when using lite protos.";
504 return nullptr;
505 }
506 }
507 }
508 } else if (meta_graph.collection_def().count("asset_filepaths") > 0) {
509 const CollectionDef& file_paths =
510 meta_graph.collection_def().at("asset_filepaths");
511 std::vector<string> paths;
512 for (const auto& raw_path : file_paths.bytes_list().value()) {
513 paths.push_back(raw_path);
514 }
515 if (!FilesExist(paths, nullptr)) {
516 LOG(ERROR) << "Can't access one or more of the asset files, skipping "
517 "this input";
518 return nullptr;
519 }
520 }
521
522 if (meta_graph.collection_def().count("queue_runners") > 0) {
523 const CollectionDef& vars = meta_graph.collection_def().at("queue_runners");
524 for (const auto& raw : vars.bytes_list().value()) {
525 QueueRunnerDef queue_runner;
526 if (!queue_runner.ParseFromString(raw)) {
527 LOG(ERROR) << "Could not parse queue_runners, skipping this input";
528 return nullptr;
529 }
530 if (queue_runner.cancel_op_name().empty()) {
531 LOG(ERROR) << "Queue without a cancel op, skipping this input";
532 return nullptr;
533 }
534 new_item->queue_runners.push_back(queue_runner);
535 }
536 }
537
538 // Add each node referenced in a collection to the list of nodes to keep.
539 for (const auto& col : meta_graph.collection_def()) {
540 const CollectionDef& collection = col.second;
541 for (const string& node : collection.node_list().value()) {
542 new_item->keep_ops.push_back(NodeName(node));
543 }
544 }
545
546 for (auto& node : *new_item->graph.mutable_node()) {
547 if (IsPlaceholder(node) && node.op() != "PlaceholderWithDefault") {
548 Status s = UpdatePlaceholderShape(cfg, signature_feed_nodes,
549 new_item.get(), &node);
550 if (!s.ok()) return nullptr;
551 } else if (IsConstant(node)) {
552 auto it = asset_node_to_value.find(node.name());
553 if (it != asset_node_to_value.end()) {
554 auto iter = node.mutable_attr()->find("value");
555 if (iter == node.attr().end()) {
556 LOG(ERROR) << "Value attribute expected in const op for asset files";
557 return nullptr;
558 }
559 if (!iter->second.has_tensor() ||
560 iter->second.tensor().string_val_size() != 1) {
561 LOG(INFO) << "Unexpected AttrValue proto: "
562 << iter->second.DebugString();
563 return nullptr;
564 }
565 LOG(INFO) << "Using asset file " << it->second << " for node "
566 << node.name();
567 *(iter->second.mutable_tensor()->mutable_string_val(0)) = it->second;
568 }
569 }
570
571 // Erase the recorded result of any previous shape inference to start again
572 // from scratch.
573 node.mutable_attr()->erase("_output_shapes");
574
575 // Delete user specified placement if requested.
576 if (cfg.ignore_user_placement) {
577 node.clear_device();
578 }
579 // Delete colocation constraints if requested.
580 if (cfg.ignore_colocation) {
581 auto attr = node.mutable_attr();
582 auto it = attr->find("_class");
583 if (it != attr->end()) {
584 attr->erase(it);
585 }
586 }
587 }
588
589 if (meta_graph.collection_def().count("savers") > 0) {
590 const CollectionDef& savers = meta_graph.collection_def().at("savers");
591 for (const auto& raw : savers.bytes_list().value()) {
592 SaverDef saver;
593 // Skip bad savers since we don't need saves/restores to be able to run a
594 // graph.
595 if (!saver.ParseFromString(raw)) {
596 continue;
597 }
598 if (saver.filename_tensor_name().empty()) {
599 continue;
600 }
601 new_item->save_op = saver.save_tensor_name();
602 new_item->restore_op = saver.restore_op_name();
603 new_item->save_restore_loc_tensor = saver.filename_tensor_name();
604 // Only use the first saver since it's not clear what to do if there's
605 // more than one.
606 break;
607 }
608 } else {
609 const SaverDef& saver = meta_graph.saver_def();
610 new_item->save_op = saver.save_tensor_name();
611 new_item->restore_op = saver.restore_op_name();
612 new_item->save_restore_loc_tensor = saver.filename_tensor_name();
613 }
614
615 // Instantiate all the missing attributes with their default values.
616 Status attr_status = AddDefaultAttrsToGraphDef(
617 &new_item->graph,
618 FunctionLibraryDefinition(OpRegistry::Global(),
619 new_item->graph.library()),
620 0, true);
621 if (!attr_status.ok()) {
622 LOG(ERROR) << "Failed to instantiate default attribute values: "
623 << attr_status.error_message();
624 return nullptr;
625 }
626
627 // Optimize the graph (function inlining, l1 optimizations, etc).
628 VLOG(1) << "Number of nodes in graph before RuntimeGraphOptimizer: "
629 << new_item->graph.node_size();
630 Status optimize_status =
631 RuntimeGraphOptimizer(new_item->graph, &new_item->graph, cfg);
632 if (!optimize_status.ok()) {
633 LOG(ERROR) << "Graph preprocessing failed: " << optimize_status;
634 return nullptr;
635 }
636 VLOG(1) << "Number of nodes in graph after RuntimeGraphOptimizer: "
637 << new_item->graph.node_size();
638
639 if (cfg.prune_graph) {
640 VLOG(1) << "Pruning graph...";
641 auto status = PruneGraph(new_item.get());
642 if (!status.ok()) {
643 LOG(ERROR) << "Pruning failed: " << status.error_message();
644 return nullptr;
645 }
646 VLOG(1) << "Number of nodes in graph after pruning: "
647 << new_item->graph.node_size();
648 }
649
650 // Validate feed, fetch and init nodes
651 std::unordered_set<string> nodes;
652 for (const auto& node : new_item->graph.node()) {
653 nodes.insert(node.name());
654 }
655 for (const auto& feed : new_item->feed) {
656 if (nodes.find(feed.first) == nodes.end()) {
657 LOG(ERROR) << "Feed node " << feed.first << " doesn't exist in graph";
658 return nullptr;
659 }
660 }
661 for (const auto& fetch : new_item->fetch) {
662 if (nodes.find(fetch) == nodes.end()) {
663 LOG(ERROR) << "Fetch node " << fetch << " doesn't exist in graph";
664 return nullptr;
665 }
666 }
667 for (const auto& init : new_item->init_ops) {
668 if (nodes.find(init) == nodes.end()) {
669 LOG(ERROR) << "Init node " << init << " doesn't exist in graph";
670 return nullptr;
671 }
672 }
673 return new_item;
674 }
675
GrapplerItemFromMetaGraphDefFile(const string & id,const string & meta_graph_file,const ItemConfig & cfg)676 std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDefFile(
677 const string& id, const string& meta_graph_file, const ItemConfig& cfg) {
678 MetaGraphDef meta_graph;
679 if (!ReadMetaGraphDefFromFile(meta_graph_file, &meta_graph).ok()) {
680 LOG(ERROR) << "Failed to read " << meta_graph_file;
681 return nullptr;
682 }
683 return GrapplerItemFromMetaGraphDef(id, meta_graph, cfg);
684 }
685
686 } // end namespace grappler
687 } // end namespace tensorflow
688