• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/grappler/grappler_item_builder.h"
16 
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <vector>
20 
21 #include "tensorflow/core/common_runtime/device.h"
22 #include "tensorflow/core/common_runtime/device_factory.h"
23 #include "tensorflow/core/common_runtime/device_mgr.h"
24 #include "tensorflow/core/common_runtime/function.h"
25 #include "tensorflow/core/common_runtime/graph_optimizer.h"
26 #include "tensorflow/core/framework/attr_value.pb.h"
27 #include "tensorflow/core/framework/function.h"
28 #include "tensorflow/core/framework/function.pb.h"
29 #include "tensorflow/core/framework/graph_def_util.h"
30 #include "tensorflow/core/framework/node_def.pb.h"
31 #include "tensorflow/core/framework/op.h"
32 #include "tensorflow/core/framework/tensor.pb.h"
33 #include "tensorflow/core/framework/tensor_shape.pb.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/framework/variable.pb.h"
36 #include "tensorflow/core/framework/versions.pb.h"
37 #include "tensorflow/core/graph/graph_constructor.h"
38 #include "tensorflow/core/grappler/inputs/utils.h"
39 #include "tensorflow/core/grappler/op_types.h"
40 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
41 #include "tensorflow/core/grappler/utils.h"
42 #include "tensorflow/core/lib/gtl/map_util.h"
43 #include "tensorflow/core/lib/io/path.h"
44 #include "tensorflow/core/platform/protobuf_internal.h"
45 #include "tensorflow/core/protobuf/meta_graph.pb.h"
46 #include "tensorflow/core/protobuf/saver.pb.h"
47 #include "tensorflow/core/public/session_options.h"
48 
49 namespace tensorflow {
50 namespace grappler {
51 
52 namespace {
53 
InitializeTensor(DataType type,Tensor * tensor)54 void InitializeTensor(DataType type, Tensor* tensor) {
55   const int period = 7;
56   if (type == DT_FLOAT) {
57     auto flat = tensor->flat<float>();
58     // Populate numbers 0, 0.1, 0.2, ..., 0.5, 0.6, 0, 0.1, 0.2, ...
59     for (int i = 0; i < flat.size(); i++) {
60       flat(i) = static_cast<float>(i % period) / 10.0f;
61     }
62   } else if (type == DT_INT64) {
63     auto flat = tensor->flat<int64>();
64     // Populate numbers 0, 1, 2, ..., 5, 6, 0, 1, 2, ...
65     for (int i = 0; i < flat.size(); i++) {
66       flat(i) = i % period;
67     }
68   } else if (type != DT_STRING && type != DT_RESOURCE && type != DT_VARIANT) {
69     // DT_STRING, DT_RESOURCE and DT_VARIANT are not simple types according to
70     // is_simple_type<> in tensorflow/core/framework/type_traits.h, and
71     // Allocator will run non-trivial constructor/destructor for a Tensor with
72     // one of these types, so we should not memset its buffer.
73     memset(const_cast<char*>(tensor->tensor_data().data()), 0,
74            tensor->tensor_data().size());
75   }
76 }
77 
78 // Applies the same graph pruning logic to the graph as Session.Run in TF.
79 // If the returned status is not OK, item state may be inconsistent.
PruneGraph(GrapplerItem * item)80 Status PruneGraph(GrapplerItem* item) {
81   ModelPruner pruner;
82   GraphDef pruned_graph;
83   Cluster* cluster = nullptr;  // ModelPruner doesn't check cluster.
84   TF_RETURN_IF_ERROR(pruner.Optimize(cluster, *item, &pruned_graph));
85   item->graph = std::move(pruned_graph);
86   return Status::OK();
87 }
88 
89 // Replace any unknown dimensions in a shape with
90 // cfg.placeholder_unknown_output_shape_dim if it is no less than 0.
ReplaceUnknownShapeDim(const ItemConfig & cfg,const TensorShapeProto & shape_pb_in,TensorShapeProto * shape_pb_out,TensorShape * shape_out)91 Status ReplaceUnknownShapeDim(const ItemConfig& cfg,
92                               const TensorShapeProto& shape_pb_in,
93                               TensorShapeProto* shape_pb_out,
94                               TensorShape* shape_out) {
95   std::vector<int32> dims;
96   for (const auto& dim_proto : shape_pb_in.dim()) {
97     if (cfg.placeholder_unknown_output_shape_dim >= 0 &&
98         dim_proto.size() == -1) {
99       dims.push_back(cfg.placeholder_unknown_output_shape_dim);
100       shape_pb_out->add_dim()->set_size(
101           cfg.placeholder_unknown_output_shape_dim);
102     } else {
103       dims.push_back(std::max<int32>(1, dim_proto.size()));
104       shape_pb_out->add_dim()->set_size(dim_proto.size());
105     }
106   }
107   return TensorShapeUtils::MakeShape(dims.data(), dims.size(), shape_out);
108 }
109 
110 // Replace unknown dimensions in Placeholder shape if
111 // cfg.placeholder_unknown_output_shape_dim is set or
112 // the Placeholder node has _output_shapes.
113 // Otherwise keep it intact to keep compatible with shape annotation
114 // (b/134092018).
UpdatePlaceholderShape(const ItemConfig & cfg,const std::unordered_set<string> & signature_feed_nodes,GrapplerItem * new_item,NodeDef * node)115 Status UpdatePlaceholderShape(
116     const ItemConfig& cfg,
117     const std::unordered_set<string>& signature_feed_nodes,
118     GrapplerItem* new_item, NodeDef* node) {
119   if (node->attr().count("dtype") == 0) {
120     return errors::Internal("Unknown type for placeholder ", node->name(),
121                             ", skipping this input");
122   }
123   DataType type = node->attr().at("dtype").type();
124 
125   // TODO(andiryxu): Consider cfg.placeholder_unknown_output_shape_dim >= 0 and
126   // _output_shapes is present case.
127   if (node->attr().count("shape") == 0) {
128     return errors::Internal("Unknown shape for placeholder ", node->name(),
129                             ", skipping this input");
130   }
131 
132   // Replace all unknown dimensions in the placeholder's tensorshape proto
133   // with cfg.placeholder_unknown_output_shape_dim and create a tensorshape
134   // from it. We do this because in newer protos, the input placeholder
135   // shape is not empty if the shape is partially defined.
136   TensorShape shape;
137   TensorShapeProto shape_proto;
138   Status make_shape_status = ReplaceUnknownShapeDim(
139       cfg, node->attr().at("shape").shape(), &shape_proto, &shape);
140   if (!make_shape_status.ok()) {
141     return errors::Internal("Invalid shape for placeholder ", node->name(),
142                             ": ", make_shape_status, ", skipping this input");
143   }
144 
145   // Some placeholder nodes have a mis-match between the node
146   // attribute "shape" and a different node attribute "_output_shapes".
147   // Specifically, a shape with shape.dims() == 0 could indicate either
148   // a scalar or an unknown shape. In those cases, we check _output_shapes
149   // for additional information.
150   // This case is observed in the bnmt graphs. Have not observed any
151   // cases where there was more than 1 _output_shapes, so limit it
152   // to cases where there is only 1 _output_shapes.
153   // We only do this if cfg.placeholder_unknown_output_shape_dim has
154   // been set to avoid crashing non-BNMT graphs.
155   // TODO(andiryxu): Investigate if this is a bug in BNMT graph.
156   if ((cfg.placeholder_unknown_output_shape_dim >= 0) && (shape.dims() == 0) &&
157       (node->attr().count("_output_shapes") == 1)) {
158     const auto& output_shapes =
159         node->attr().at("_output_shapes").list().shape(0);
160 
161     if (output_shapes.dim_size() != 0) {
162       shape.Clear();
163       shape_proto.clear_dim();
164 
165       for (const auto& dim : output_shapes.dim()) {
166         auto size = dim.size();
167         if (size == -1) size = cfg.placeholder_unknown_output_shape_dim;
168         shape.AddDim(size);
169         shape_proto.add_dim()->set_size(size);
170       }
171     }
172   }
173 
174   Tensor fake_input(type, shape);
175   InitializeTensor(type, &fake_input);
176 
177   if (cfg.feed_nodes.empty()) {
178     // No specific feed nodes were given. Assume all placeholders are fed.
179     if (signature_feed_nodes.count(node->name()) == 0) {
180       new_item->feed.emplace_back(node->name(), fake_input);
181     }
182   } else if (cfg.feed_nodes.count(node->name()) > 0) {
183     // If specific feed nodes were given, only update their tensors.
184     auto it = find_if(new_item->feed.begin(), new_item->feed.end(),
185                       [&node](std::pair<string, Tensor>& f) {
186                         return f.first == node->name();
187                       });
188     DCHECK(it != new_item->feed.end());
189     it->second = fake_input;
190   }
191 
192   // Set the shape of the node in the graph. This is needed for statically
193   // inferring shapes and is a no-op when dynamically inferring shapes as
194   // the Placeholder shape will match the shape passed from new_item->feed.
195   // Only replace node shape with known shape. For unknown shape keep it intact
196   // (b/134092018).
197   if (!shape_proto.dim().empty())
198     *(node->mutable_attr()->at("shape").mutable_shape()) = shape_proto;
199 
200   return Status::OK();
201 }
202 
203 }  // namespace
204 
// Runs TensorFlow's runtime GraphOptimizer over `graph_def_arg` according to
// `cfg` (L1 optimizations and/or function inlining) and writes the result to
// `output_graph_def`. `output_graph_def` may alias `graph_def_arg`; when no
// graph-modifying option is enabled the input is copied through untouched.
Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
                             GraphDef* output_graph_def,
                             const ItemConfig& cfg) {
  // This is a temporary change that optimizes the graph in context of a single
  // gpu machine. Down the line, we may want to make grappler_item_builder aware
  // of the cluster type (E.g: single cpu, multiple gpu, etc)  being simulated
  // in order to get the correct session options and environment, and performing
  // the correct optimizations.

  // Return input as is if no graph-modifying config is set.
  if (!cfg.apply_optimizations && !cfg.inline_functions &&
      !cfg.erase_noinline_attributes) {
    // Skip the copy when caller passed the same proto for input and output.
    if (output_graph_def != &graph_def_arg) {
      *output_graph_def = graph_def_arg;
    }
    return Status::OK();
  }

  // Create a session option for a single GPU device.
  SessionOptions options;

  // Make a local copy of graph def, because we need to change some things.
  GraphDef graph_def(graph_def_arg);

  if (cfg.erase_noinline_attributes) {
    // TF optimizer doesn't inline functions with "_noinline" attribute,
    // so let's go over the function library and erase it.
    for (auto& func : *graph_def.mutable_library()->mutable_function()) {
      func.mutable_attr()->erase("_noinline");
    }
  }

  // Instantiate all variables for function library runtime creation.
  std::vector<std::unique_ptr<Device>> devices;
  // Only CPU device is used so instead of calling DeviceFactory::AddDevices()
  // with dummy session config, which will conflict with user defined options
  // and create unwanted devices, call cpu_factory->CreateDevices() to get CPU
  // only devices.
  DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
  TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
      options, "/job:localhost/replica:0/task:0", &devices));
  Device* cpu_device = devices[0].get();
  auto dvc_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
  // NOTE: function_library copies graph_def.library() here, before graph_def
  // is consumed by ConvertGraphDefToGraph below — ordering matters.
  FunctionLibraryDefinition function_library(OpRegistry::Global(),
                                             graph_def.library());
  Env* env = Env::Default();

  // Optimizer options: L1 and inlining. L1 is default.
  OptimizerOptions* optimizer_opts =
      options.config.mutable_graph_options()->mutable_optimizer_options();
  if (cfg.apply_optimizations) {
    optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions::L1);
  } else {
    optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions::L0);
  }
  optimizer_opts->set_do_function_inlining(cfg.inline_functions);

  // Create the function library runtime.
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
      new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env, &options.config,
                                        graph_def.versions().producer(),
                                        &function_library, *optimizer_opts));
  FunctionLibraryRuntime* flr = pflr->GetFLR(cpu_device->name());

  // Create the GraphOptimizer to optimize the graph def.
  GraphConstructorOptions graph_ctor_opts;
  graph_ctor_opts.allow_internal_ops = true;
  graph_ctor_opts.expect_device_spec = false;
  std::unique_ptr<Graph> graphptr(new Graph(function_library));

  // graph_def is moved (consumed) here; do not use it afterwards.
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
      graph_ctor_opts, std::move(graph_def), graphptr.get()));

  // Optimize the graph.
  ::tensorflow::GraphOptimizer optimizer(*optimizer_opts);
  optimizer.Optimize(flr, env, cpu_device, &graphptr, /*shape_map=*/nullptr);
  graphptr->ToGraphDef(output_graph_def);

  // The default values of attributes might have been stripped by the optimizer.
  // Add them back.
  return AddDefaultAttrsToGraphDef(output_graph_def, *graphptr->op_registry(),
                                   0, true);
}
288 
GrapplerItemFromMetaGraphDef(const string & id,const MetaGraphDef & meta_graph,const ItemConfig & cfg)289 std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
290     const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg) {
291   if (id.empty()) {
292     LOG(ERROR) << "id must be non-empty.";
293     return nullptr;
294   }
295   std::unique_ptr<GrapplerItem> new_item(new GrapplerItem());
296   new_item->id = id;
297   new_item->graph = meta_graph.graph_def();
298 
299   // Fill in feed nodes from config, if any provided.
300   for (const auto& feed_node : cfg.feed_nodes) {
301     const string feed_name = NodeName(feed_node);
302     new_item->feed.emplace_back(feed_name, Tensor());
303   }
304   for (const auto& fetch_node : cfg.fetch_nodes) {
305     new_item->fetch.emplace_back(NodeName(fetch_node));
306   }
307 
308   // Attempt to detect the fetch node(s) if they were not set explicitly.
309   if (new_item->fetch.empty() &&
310       meta_graph.collection_def().count("train_op") > 0) {
311     const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
312     if (nodes.has_node_list()) {
313       for (const auto& node : nodes.node_list().value()) {
314         new_item->fetch.push_back(NodeName(node));
315       }
316     }
317   }
318 
319   // Detect feed and fetch nodes from signature defs. Signatures may share same
320   // inputs or outputs.
321   std::unordered_set<string> signature_feed_nodes;
322   std::unordered_set<string> signature_fetch_nodes;
323   for (const auto& name_and_signature : meta_graph.signature_def()) {
324     for (const auto& name_and_input : name_and_signature.second.inputs()) {
325       const TensorInfo& input = name_and_input.second;
326       if (input.has_coo_sparse()) {
327         // Define the shapes following the comment of CooSparse.
328         // TODO(yuefengz): we probably want to use different dim values for the
329         // three tensors of a SparseTensor.
330         int64 dim = std::max(1, cfg.placeholder_unknown_output_shape_dim);
331         TensorShape shape_1d({dim});
332         TensorShape shape_2d({dim, dim});
333 
334         if (gtl::InsertIfNotPresent(
335                 &signature_feed_nodes,
336                 NodeName(input.coo_sparse().values_tensor_name()))) {
337           Tensor value_tensor(input.dtype(), shape_1d);
338           InitializeTensor(input.dtype(), &value_tensor);
339           new_item->feed.emplace_back(
340               NodeName(input.coo_sparse().values_tensor_name()), value_tensor);
341         }
342         if (gtl::InsertIfNotPresent(
343                 &signature_feed_nodes,
344                 NodeName(input.coo_sparse().indices_tensor_name()))) {
345           Tensor indices_tensor(DT_INT64, shape_2d);
346           InitializeTensor(input.dtype(), &indices_tensor);
347           new_item->feed.emplace_back(
348               NodeName(input.coo_sparse().indices_tensor_name()),
349               indices_tensor);
350         }
351         if (gtl::InsertIfNotPresent(
352                 &signature_feed_nodes,
353                 NodeName(input.coo_sparse().dense_shape_tensor_name()))) {
354           Tensor dense_shape_tensor(DT_INT64, shape_1d);
355           InitializeTensor(input.dtype(), &dense_shape_tensor);
356           new_item->feed.emplace_back(
357               NodeName(input.coo_sparse().dense_shape_tensor_name()),
358               dense_shape_tensor);
359         }
360       } else {
361         if (gtl::InsertIfNotPresent(&signature_feed_nodes,
362                                     NodeName(input.name()))) {
363           TensorShape shape;
364           TensorShapeProto shape_proto;
365           Status s = ReplaceUnknownShapeDim(cfg, input.tensor_shape(),
366                                             &shape_proto, &shape);
367           if (!s.ok()) {
368             LOG(ERROR) << "Invalid shape for signature input " << input.name()
369                        << ": " << s << ", skipping this input";
370             return nullptr;
371           }
372 
373           Tensor fake_input(input.dtype(), shape);
374           InitializeTensor(input.dtype(), &fake_input);
375           new_item->feed.emplace_back(NodeName(input.name()), fake_input);
376         }
377       }
378     }
379     for (const auto& name_and_output : name_and_signature.second.outputs()) {
380       const TensorInfo& output = name_and_output.second;
381       if (output.has_coo_sparse()) {
382         if (gtl::InsertIfNotPresent(
383                 &signature_fetch_nodes,
384                 NodeName(output.coo_sparse().values_tensor_name()))) {
385           new_item->fetch.push_back(
386               NodeName(output.coo_sparse().values_tensor_name()));
387         }
388         if (gtl::InsertIfNotPresent(
389                 &signature_fetch_nodes,
390                 NodeName(output.coo_sparse().indices_tensor_name()))) {
391           new_item->fetch.push_back(
392               NodeName(output.coo_sparse().indices_tensor_name()));
393         }
394         if (gtl::InsertIfNotPresent(
395                 &signature_fetch_nodes,
396                 NodeName(output.coo_sparse().dense_shape_tensor_name()))) {
397           new_item->fetch.push_back(
398               NodeName(output.coo_sparse().dense_shape_tensor_name()));
399         }
400       } else {
401         if (gtl::InsertIfNotPresent(&signature_fetch_nodes,
402                                     NodeName(output.name()))) {
403           new_item->fetch.push_back(NodeName(output.name()));
404         }
405       }
406     }
407   }
408 
409   for (const auto& feed : new_item->feed) {
410     if (feed.first.empty()) {
411       LOG(ERROR) << "Invalid feed node name skipping this input";
412       return nullptr;
413     } else {
414       VLOG(1) << "Will use feed node " << feed.first;
415     }
416   }
417 
418   for (const auto& fetch : new_item->fetch) {
419     if (fetch.empty()) {
420       LOG(ERROR) << "Invalid fetch node name skipping this input";
421       return nullptr;
422     } else {
423       VLOG(1) << "Will use fetch node " << fetch;
424     }
425   }
426 
427   if (new_item->fetch.empty()) {
428     LOG(ERROR) << "Failed to detect the fetch node(s), skipping this input";
429     return nullptr;
430   }
431 
432   // TODO(yuefengz): consider handling saved_model_main_op and legacy_init_op.
433   // The reason why they are difficult to handle is because they may not intend
434   // to initialize all variables that are required to run fetch nodes. We may
435   // have to run restore op first.
436 
437   // Try to find initializers from variables and tables as init ops.
438   for (const string& var_collection :
439        {"variables", "local_variables", "model_variables",
440         "trainable_variables"}) {
441     if (meta_graph.collection_def().count(var_collection) == 0) {
442       continue;
443     }
444     const CollectionDef& vars = meta_graph.collection_def().at(var_collection);
445     for (const auto& raw_var : vars.bytes_list().value()) {
446       VariableDef var;
447       var.ParseFromString(raw_var);
448       if (!var.initializer_name().empty()) {
449         new_item->init_ops.push_back(NodeName(var.initializer_name()));
450       }
451     }
452   }
453 
454   if (meta_graph.collection_def().count("table_initializer") > 0) {
455     const CollectionDef& inits =
456         meta_graph.collection_def().at("table_initializer");
457     if (inits.has_node_list()) {
458       for (const auto& node : inits.node_list().value()) {
459         new_item->init_ops.push_back(NodeName(node));
460         // Tables are initialized from files, which can take a long time. Add
461         // 30 minutes to the initialization time for each table to avoid
462         // timing out.
463         // TODO(bsteiner): adjust the timeout based on the file size.
464         new_item->expected_init_time += 30 * 60;
465       }
466     }
467   }
468 
469   // We keep the mapping from asset node to asset files. This should have been
470   // used as feed but since asset node is usually a constant node, we will fill
471   // the values of these constant nodes with their actual asset file paths.
472   std::unordered_map<string, string> asset_node_to_value;
473 
474   // Assets file may have changed their directory, we assemble their new paths
475   // if assets_directory_override is set. We also make sure we still can
476   // access these asset files.
477   if (!cfg.assets_directory_override.empty()) {
478     if (meta_graph.collection_def().count("saved_model_assets") > 0) {
479       const CollectionDef& collection =
480           meta_graph.collection_def().at("saved_model_assets");
481       const auto& any_assets = collection.any_list().value();
482       if (!any_assets.empty()) {
483 #ifndef TENSORFLOW_LITE_PROTOS
484         for (const auto& any_asset : any_assets) {
485           AssetFileDef asset_file_def;
486           if (!ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef")
487                    .ok()) {
488             LOG(ERROR) << "Failed to parse AssetFile.";
489             continue;
490           }
491           string asset_filepath = io::JoinPath(cfg.assets_directory_override,
492                                                asset_file_def.filename());
493           if (!FilesExist({asset_filepath}, nullptr)) {
494             LOG(ERROR) << "Can't access one or more of the asset files "
495                        << asset_filepath << ", skipping this input";
496             return nullptr;
497           }
498           asset_node_to_value[NodeName(asset_file_def.tensor_info().name())] =
499               asset_filepath;
500         }
501 #else
502         LOG(ERROR) << "Can't parse AssetFileDef on mobile.";
503         return nullptr;
504 #endif  // TENSORFLOW_LITE_PROTOS
505       }
506     }
507   } else if (meta_graph.collection_def().count("asset_filepaths") > 0) {
508     const CollectionDef& file_paths =
509         meta_graph.collection_def().at("asset_filepaths");
510     std::vector<string> paths;
511     for (const auto& raw_path : file_paths.bytes_list().value()) {
512       paths.push_back(raw_path);
513     }
514     if (!FilesExist(paths, nullptr)) {
515       LOG(ERROR) << "Can't access one or more of the asset files, skipping "
516                     "this input";
517       return nullptr;
518     }
519   }
520 
521   if (meta_graph.collection_def().count("queue_runners") > 0) {
522     const CollectionDef& vars = meta_graph.collection_def().at("queue_runners");
523     for (const auto& raw : vars.bytes_list().value()) {
524       QueueRunnerDef queue_runner;
525       if (!queue_runner.ParseFromString(raw)) {
526         LOG(ERROR) << "Could not parse queue_runners, skipping this input";
527         return nullptr;
528       }
529       if (queue_runner.cancel_op_name().empty()) {
530         LOG(ERROR) << "Queue without a cancel op, skipping this input";
531         return nullptr;
532       }
533       new_item->queue_runners.push_back(queue_runner);
534     }
535   }
536 
537   // Add each node referenced in a collection to the list of nodes to keep.
538   for (const auto& col : meta_graph.collection_def()) {
539     const CollectionDef& collection = col.second;
540     for (const string& node : collection.node_list().value()) {
541       new_item->keep_ops.push_back(NodeName(node));
542     }
543   }
544 
545   for (auto& node : *new_item->graph.mutable_node()) {
546     if (IsPlaceholder(node) && node.op() != "PlaceholderWithDefault") {
547       Status s = UpdatePlaceholderShape(cfg, signature_feed_nodes,
548                                         new_item.get(), &node);
549       if (!s.ok()) return nullptr;
550     } else if (IsConstant(node)) {
551       auto it = asset_node_to_value.find(node.name());
552       if (it != asset_node_to_value.end()) {
553         auto iter = node.mutable_attr()->find("value");
554         if (iter == node.attr().end()) {
555           LOG(ERROR) << "Value attribute expected in const op for asset files";
556           return nullptr;
557         }
558         if (!iter->second.has_tensor() ||
559             iter->second.tensor().string_val_size() != 1) {
560           LOG(INFO) << "Unexpected AttrValue proto: "
561                     << iter->second.DebugString();
562           return nullptr;
563         }
564         LOG(INFO) << "Using asset file " << it->second << " for node "
565                   << node.name();
566         *(iter->second.mutable_tensor()->mutable_string_val(0)) = it->second;
567       }
568     }
569 
570     // Erase the recorded result of any previous shape inference to start again
571     // from scratch.
572     node.mutable_attr()->erase("_output_shapes");
573 
574     // Delete user specified placement if requested.
575     if (cfg.ignore_user_placement) {
576       node.clear_device();
577     }
578     // Delete colocation constraints if requested.
579     if (cfg.ignore_colocation) {
580       auto attr = node.mutable_attr();
581       auto it = attr->find("_class");
582       if (it != attr->end()) {
583         attr->erase(it);
584       }
585     }
586   }
587 
588   if (meta_graph.collection_def().count("savers") > 0) {
589     const CollectionDef& savers = meta_graph.collection_def().at("savers");
590     for (const auto& raw : savers.bytes_list().value()) {
591       SaverDef saver;
592       // Skip bad savers since we don't need saves/restores to be able to run a
593       // graph.
594       if (!saver.ParseFromString(raw)) {
595         continue;
596       }
597       if (saver.filename_tensor_name().empty()) {
598         continue;
599       }
600       new_item->save_op = saver.save_tensor_name();
601       new_item->restore_op = saver.restore_op_name();
602       new_item->save_restore_loc_tensor = saver.filename_tensor_name();
603       // Only use the first saver since it's not clear what to do if there's
604       // more than one.
605       break;
606     }
607   } else {
608     const SaverDef& saver = meta_graph.saver_def();
609     new_item->save_op = saver.save_tensor_name();
610     new_item->restore_op = saver.restore_op_name();
611     new_item->save_restore_loc_tensor = saver.filename_tensor_name();
612   }
613 
614   // Instantiate all the missing attributes with their default values.
615   Status attr_status = AddDefaultAttrsToGraphDef(
616       &new_item->graph,
617       FunctionLibraryDefinition(OpRegistry::Global(),
618                                 new_item->graph.library()),
619       0, true);
620   if (!attr_status.ok()) {
621     LOG(ERROR) << "Failed to instantiate default attribute values: "
622                << attr_status.error_message();
623     return nullptr;
624   }
625 
626   // Optimize the graph (function inlining, l1 optimizations, etc).
627   VLOG(1) << "Number of nodes in graph before RuntimeGraphOptimizer: "
628           << new_item->graph.node_size();
629   Status optimize_status =
630       RuntimeGraphOptimizer(new_item->graph, &new_item->graph, cfg);
631   if (!optimize_status.ok()) {
632     LOG(ERROR) << "Graph preprocessing failed: " << optimize_status;
633     return nullptr;
634   }
635   VLOG(1) << "Number of nodes in graph after RuntimeGraphOptimizer: "
636           << new_item->graph.node_size();
637 
638   if (cfg.prune_graph) {
639     VLOG(1) << "Pruning graph...";
640     auto status = PruneGraph(new_item.get());
641     if (!status.ok()) {
642       LOG(ERROR) << "Pruning failed: " << status.error_message();
643       return nullptr;
644     }
645     VLOG(1) << "Number of nodes in graph after pruning: "
646             << new_item->graph.node_size();
647   }
648 
649   // Validate feed, fetch and init nodes
650   std::unordered_set<string> nodes;
651   for (const auto& node : new_item->graph.node()) {
652     nodes.insert(node.name());
653   }
654   for (const auto& feed : new_item->feed) {
655     if (nodes.find(feed.first) == nodes.end()) {
656       LOG(ERROR) << "Feed node " << feed.first << " doesn't exist in graph";
657       return nullptr;
658     }
659   }
660   for (const auto& fetch : new_item->fetch) {
661     if (nodes.find(fetch) == nodes.end()) {
662       LOG(ERROR) << "Fetch node " << fetch << " doesn't exist in graph";
663       return nullptr;
664     }
665   }
666   for (const auto& init : new_item->init_ops) {
667     if (nodes.find(init) == nodes.end()) {
668       LOG(ERROR) << "Init node " << init << " doesn't exist in graph";
669       return nullptr;
670     }
671   }
672   return new_item;
673 }
674 
GrapplerItemFromMetaGraphDefFile(const string & id,const string & meta_graph_file,const ItemConfig & cfg)675 std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDefFile(
676     const string& id, const string& meta_graph_file, const ItemConfig& cfg) {
677   MetaGraphDef meta_graph;
678   if (!ReadMetaGraphDefFromFile(meta_graph_file, &meta_graph).ok()) {
679     LOG(ERROR) << "Failed to read " << meta_graph_file;
680     return nullptr;
681   }
682   return GrapplerItemFromMetaGraphDef(id, meta_graph, cfg);
683 }
684 
685 }  // end namespace grappler
686 }  // end namespace tensorflow
687