/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/grappler_item_builder.h"

#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variable.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/inputs/utils.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace grappler {

namespace {

void InitializeTensor(DataType type, Tensor* tensor) {
  const int period = 7;
  if (type == DT_FLOAT) {
    auto flat = tensor->flat<float>();
    // Populate numbers 0, 0.1, 0.2, ..., 0.5, 0.6, 0, 0.1, 0.2, ...
    for (int i = 0; i < flat.size(); i++) {
      flat(i) = static_cast<float>(i % period) / 10.0f;
    }
  } else if (type == DT_INT64) {
    auto flat = tensor->flat<int64_t>();
    // Populate numbers 0, 1, 2, ..., 5, 6, 0, 1, 2, ...
    for (int i = 0; i < flat.size(); i++) {
      flat(i) = i % period;
    }
  } else if (type != DT_STRING && type != DT_RESOURCE && type != DT_VARIANT) {
    // DT_STRING, DT_RESOURCE and DT_VARIANT are not simple types according to
    // is_simple_type<> in tensorflow/core/framework/type_traits.h, and
    // Allocator will run non-trivial constructor/destructor for a Tensor with
    // one of these types, so we should not memset its buffer.
    memset(const_cast<char*>(tensor->tensor_data().data()), 0,
           tensor->tensor_data().size());
  }
}
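// A minimal illustrative sketch of the fill pattern above (hypothetical
// caller, not part of this file):
//
//   Tensor t(DT_FLOAT, TensorShape({10}));
//   InitializeTensor(DT_FLOAT, &t);
//   // t now holds 0.0, 0.1, ..., 0.6, 0.0, 0.1, 0.2.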

// Applies the same graph pruning logic to the graph as Session.Run in TF.
// If the returned status is not OK, item state may be inconsistent.
Status PruneGraph(GrapplerItem* item) {
  ModelPruner pruner;
  GraphDef pruned_graph;
  Cluster* cluster = nullptr;  // ModelPruner doesn't check cluster.
  TF_RETURN_IF_ERROR(pruner.Optimize(cluster, *item, &pruned_graph));
  item->graph = std::move(pruned_graph);
  return OkStatus();
}

// Replace any unknown dimensions in a shape with
// cfg.placeholder_unknown_output_shape_dim if it is no less than 0.
Status ReplaceUnknownShapeDim(const ItemConfig& cfg,
                              const TensorShapeProto& shape_pb_in,
                              TensorShapeProto* shape_pb_out,
                              TensorShape* shape_out) {
  std::vector<int32> dims;
  for (const auto& dim_proto : shape_pb_in.dim()) {
    if (cfg.placeholder_unknown_output_shape_dim >= 0 &&
        dim_proto.size() == -1) {
      dims.push_back(cfg.placeholder_unknown_output_shape_dim);
      shape_pb_out->add_dim()->set_size(
          cfg.placeholder_unknown_output_shape_dim);
    } else {
      dims.push_back(std::max<int32>(1, dim_proto.size()));
      shape_pb_out->add_dim()->set_size(dim_proto.size());
    }
  }
  return TensorShapeUtils::MakeShape(dims.data(), dims.size(), shape_out);
}
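// Illustrative example (values assumed for this note, not part of the code):
// with cfg.placeholder_unknown_output_shape_dim == 32, an input shape proto of
// [-1, 128] produces *shape_out == [32, 128] and the same in *shape_pb_out;
// with the option unset (< 0), the -1 is clamped to 1 in *shape_out but kept
// as -1 in *shape_pb_out.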

// Replace unknown dimensions in Placeholder shape if
// cfg.placeholder_unknown_output_shape_dim is set or
// the Placeholder node has _output_shapes.
// Otherwise keep it intact to stay compatible with shape annotation
// (b/134092018).
Status UpdatePlaceholderShape(
    const ItemConfig& cfg,
    const std::unordered_set<string>& signature_feed_nodes,
    GrapplerItem* new_item, NodeDef* node) {
  if (node->attr().count("dtype") == 0) {
    return errors::Internal("Unknown type for placeholder ", node->name(),
                            ", skipping this input");
  }
  DataType type = node->attr().at("dtype").type();

  // TODO(andiryxu): Consider the case where
  // cfg.placeholder_unknown_output_shape_dim >= 0 and _output_shapes is
  // present.
  if (node->attr().count("shape") == 0) {
    return errors::Internal("Unknown shape for placeholder ", node->name(),
                            ", skipping this input");
  }

  // Replace all unknown dimensions in the placeholder's tensorshape proto
  // with cfg.placeholder_unknown_output_shape_dim and create a tensorshape
  // from it. We do this because in newer protos, the input placeholder
  // shape is not empty if the shape is partially defined.
  TensorShape shape;
  TensorShapeProto shape_proto;
  Status make_shape_status = ReplaceUnknownShapeDim(
      cfg, node->attr().at("shape").shape(), &shape_proto, &shape);
  if (!make_shape_status.ok()) {
    return errors::Internal("Invalid shape for placeholder ", node->name(),
                            ": ", make_shape_status, ", skipping this input");
  }

  // Some placeholder nodes have a mismatch between the node
  // attribute "shape" and a different node attribute "_output_shapes".
  // Specifically, a shape with shape.dims() == 0 could indicate either
  // a scalar or an unknown shape. In those cases, we check _output_shapes
  // for additional information.
  // This case is observed in the BNMT graphs. We have not observed any
  // cases with more than one _output_shapes, so limit this to cases where
  // there is exactly one _output_shapes.
  // We only do this if cfg.placeholder_unknown_output_shape_dim has
  // been set to avoid crashing non-BNMT graphs.
  // TODO(andiryxu): Investigate if this is a bug in the BNMT graph.
  if ((cfg.placeholder_unknown_output_shape_dim >= 0) && (shape.dims() == 0) &&
      (node->attr().count("_output_shapes") == 1)) {
    const auto& output_shapes =
        node->attr().at("_output_shapes").list().shape(0);

    if (output_shapes.dim_size() != 0) {
      shape.Clear();
      shape_proto.clear_dim();

      for (const auto& dim : output_shapes.dim()) {
        auto size = dim.size();
        if (size == -1) size = cfg.placeholder_unknown_output_shape_dim;
        shape.AddDim(size);
        shape_proto.add_dim()->set_size(size);
      }
    }
  }

  Tensor fake_input(type, shape);
  InitializeTensor(type, &fake_input);

  if (cfg.feed_nodes.empty()) {
    // No specific feed nodes were given. Assume all placeholders are fed.
    if (signature_feed_nodes.count(node->name()) == 0) {
      new_item->feed.emplace_back(node->name(), fake_input);
    }
  } else if (cfg.feed_nodes.count(node->name()) > 0) {
    // If specific feed nodes were given, only update their tensors.
    auto it = find_if(new_item->feed.begin(), new_item->feed.end(),
                      [&node](std::pair<string, Tensor>& f) {
                        return f.first == node->name();
                      });
    DCHECK(it != new_item->feed.end());
    it->second = fake_input;
  }

  // Set the shape of the node in the graph. This is needed for statically
  // inferring shapes and is a no-op when dynamically inferring shapes as
  // the Placeholder shape will match the shape passed from new_item->feed.
  // Only replace the node shape with a known shape; keep an unknown shape
  // intact (b/134092018).
  if (!shape_proto.dim().empty())
    *(node->mutable_attr()->at("shape").mutable_shape()) = shape_proto;

  return OkStatus();
}
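// Illustrative effect (values assumed for this note): with
// cfg.placeholder_unknown_output_shape_dim == 16 and no explicit feed_nodes,
// a DT_FLOAT Placeholder with shape [?, 8] is added to new_item->feed as a
// 16x8 tensor filled by InitializeTensor, and its "shape" attr is rewritten
// to [16, 8].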

}  // namespace

Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
                             GraphDef* output_graph_def,
                             const ItemConfig& cfg) {
  // This is a temporary change that optimizes the graph in the context of a
  // single-GPU machine. Down the line, we may want to make
  // grappler_item_builder aware of the cluster type (e.g. single CPU, multiple
  // GPUs) being simulated, in order to get the correct session options and
  // environment and to perform the correct optimizations.

  // Return the input as is if no graph-modifying config is set.
  if (!cfg.apply_optimizations && !cfg.inline_functions &&
      !cfg.erase_noinline_attributes) {
    if (output_graph_def != &graph_def_arg) {
      *output_graph_def = graph_def_arg;
    }
    return OkStatus();
  }

  // Create a session option for a single GPU device.
  SessionOptions options;

  // Make a local copy of the graph def, because we need to change some things.
  GraphDef graph_def(graph_def_arg);

  if (cfg.erase_noinline_attributes) {
    // The TF optimizer doesn't inline functions with a "_noinline" attribute,
    // so let's go over the function library and erase it.
    for (auto& func : *graph_def.mutable_library()->mutable_function()) {
      func.mutable_attr()->erase("_noinline");
    }
  }

  // Instantiate all variables for function library runtime creation.
  std::vector<std::unique_ptr<Device>> devices;
  // Only the CPU device is used, so instead of calling
  // DeviceFactory::AddDevices() with a dummy session config, which would
  // conflict with user-defined options and create unwanted devices, call
  // cpu_factory->CreateDevices() to get CPU-only devices.
  DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
  TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
      options, "/job:localhost/replica:0/task:0", &devices));
  Device* cpu_device = devices[0].get();
  auto dvc_mgr = std::make_unique<StaticDeviceMgr>(std::move(devices));
  FunctionLibraryDefinition function_library(OpRegistry::Global(),
                                             graph_def.library());
  Env* env = Env::Default();

  // Optimizer options: L1 and inlining. L1 is the default.
  OptimizerOptions* optimizer_opts =
      options.config.mutable_graph_options()->mutable_optimizer_options();
  if (cfg.apply_optimizations) {
    optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions::L1);
  } else {
    optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions::L0);
  }
  optimizer_opts->set_do_function_inlining(cfg.inline_functions);

  // Create the function library runtime.
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
      new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env, &options.config,
                                        graph_def.versions().producer(),
                                        &function_library, *optimizer_opts));
  FunctionLibraryRuntime* flr = pflr->GetFLR(cpu_device->name());

  // Create the GraphOptimizer to optimize the graph def.
  GraphConstructorOptions graph_ctor_opts;
  graph_ctor_opts.allow_internal_ops = true;
  graph_ctor_opts.expect_device_spec = false;
  std::unique_ptr<Graph> graphptr(new Graph(function_library));

  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
      graph_ctor_opts, std::move(graph_def), graphptr.get()));

  // Optimize the graph.
  ::tensorflow::GraphOptimizer optimizer(*optimizer_opts);
  optimizer.Optimize(flr, env, cpu_device, &graphptr,
                     tensorflow::GraphOptimizer::Options());
  graphptr->ToGraphDef(output_graph_def);

  // The default values of attributes might have been stripped by the
  // optimizer. Add them back.
  return AddDefaultAttrsToGraphDef(output_graph_def, *graphptr->op_registry(),
                                   0, true);
}
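// Minimal usage sketch (hypothetical caller; graph_def is assumed to be an
// existing GraphDef, not part of this file):
//
//   ItemConfig cfg;
//   cfg.apply_optimizations = true;
//   GraphDef optimized;
//   TF_CHECK_OK(RuntimeGraphOptimizer(graph_def, &optimized, cfg));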

std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
    const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg) {
  if (id.empty()) {
    LOG(ERROR) << "id must be non-empty.";
    return nullptr;
  }
  std::unique_ptr<GrapplerItem> new_item(new GrapplerItem());
  new_item->id = id;
  new_item->graph = meta_graph.graph_def();

  // Fill in feed nodes from config, if any provided.
  for (const auto& feed_node : cfg.feed_nodes) {
    const string feed_name = NodeName(feed_node);
    new_item->feed.emplace_back(feed_name, Tensor());
  }
  for (const auto& fetch_node : cfg.fetch_nodes) {
    new_item->fetch.emplace_back(NodeName(fetch_node));
  }

  // Attempt to detect the fetch node(s) if they were not set explicitly.
  if (new_item->fetch.empty() &&
      meta_graph.collection_def().count("train_op") > 0) {
    const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
    if (nodes.has_node_list()) {
      for (const auto& node : nodes.node_list().value()) {
        new_item->fetch.push_back(NodeName(node));
      }
    }
  }

  // Detect feed and fetch nodes from signature defs. Signatures may share the
  // same inputs or outputs.
  std::unordered_set<string> signature_feed_nodes;
  std::unordered_set<string> signature_fetch_nodes;
  for (const auto& name_and_signature : meta_graph.signature_def()) {
    for (const auto& name_and_input : name_and_signature.second.inputs()) {
      const TensorInfo& input = name_and_input.second;
      if (input.has_coo_sparse()) {
        // Define the shapes following the comment of CooSparse.
        // TODO(yuefengz): we probably want to use different dim values for the
        // three tensors of a SparseTensor.
        int64_t dim = std::max(1, cfg.placeholder_unknown_output_shape_dim);
        TensorShape shape_1d({dim});
        TensorShape shape_2d({dim, dim});

        if (gtl::InsertIfNotPresent(
                &signature_feed_nodes,
                NodeName(input.coo_sparse().values_tensor_name()))) {
          Tensor value_tensor(input.dtype(), shape_1d);
          InitializeTensor(input.dtype(), &value_tensor);
          new_item->feed.emplace_back(
              NodeName(input.coo_sparse().values_tensor_name()), value_tensor);
        }
        if (gtl::InsertIfNotPresent(
                &signature_feed_nodes,
                NodeName(input.coo_sparse().indices_tensor_name()))) {
          Tensor indices_tensor(DT_INT64, shape_2d);
          // The indices tensor is always DT_INT64, so initialize it with that
          // type rather than input.dtype(), which describes the values tensor.
          InitializeTensor(DT_INT64, &indices_tensor);
          new_item->feed.emplace_back(
              NodeName(input.coo_sparse().indices_tensor_name()),
              indices_tensor);
        }
        if (gtl::InsertIfNotPresent(
                &signature_feed_nodes,
                NodeName(input.coo_sparse().dense_shape_tensor_name()))) {
          Tensor dense_shape_tensor(DT_INT64, shape_1d);
          // Same as above: the dense_shape tensor is DT_INT64.
          InitializeTensor(DT_INT64, &dense_shape_tensor);
          new_item->feed.emplace_back(
              NodeName(input.coo_sparse().dense_shape_tensor_name()),
              dense_shape_tensor);
        }
      } else {
        if (gtl::InsertIfNotPresent(&signature_feed_nodes,
                                    NodeName(input.name()))) {
          TensorShape shape;
          TensorShapeProto shape_proto;
          Status s = ReplaceUnknownShapeDim(cfg, input.tensor_shape(),
                                            &shape_proto, &shape);
          if (!s.ok()) {
            LOG(ERROR) << "Invalid shape for signature input " << input.name()
                       << ": " << s << ", skipping this input";
            return nullptr;
          }

          Tensor fake_input(input.dtype(), shape);
          InitializeTensor(input.dtype(), &fake_input);
          new_item->feed.emplace_back(NodeName(input.name()), fake_input);
        }
      }
    }
    for (const auto& name_and_output : name_and_signature.second.outputs()) {
      const TensorInfo& output = name_and_output.second;
      if (output.has_coo_sparse()) {
        if (gtl::InsertIfNotPresent(
                &signature_fetch_nodes,
                NodeName(output.coo_sparse().values_tensor_name()))) {
          new_item->fetch.push_back(
              NodeName(output.coo_sparse().values_tensor_name()));
        }
        if (gtl::InsertIfNotPresent(
                &signature_fetch_nodes,
                NodeName(output.coo_sparse().indices_tensor_name()))) {
          new_item->fetch.push_back(
              NodeName(output.coo_sparse().indices_tensor_name()));
        }
        if (gtl::InsertIfNotPresent(
                &signature_fetch_nodes,
                NodeName(output.coo_sparse().dense_shape_tensor_name()))) {
          new_item->fetch.push_back(
              NodeName(output.coo_sparse().dense_shape_tensor_name()));
        }
      } else {
        if (gtl::InsertIfNotPresent(&signature_fetch_nodes,
                                    NodeName(output.name()))) {
          new_item->fetch.push_back(NodeName(output.name()));
        }
      }
    }
  }

  for (const auto& feed : new_item->feed) {
    if (feed.first.empty()) {
      LOG(ERROR) << "Invalid feed node name, skipping this input";
      return nullptr;
    } else {
      VLOG(1) << "Will use feed node " << feed.first;
    }
  }

  for (const auto& fetch : new_item->fetch) {
    if (fetch.empty()) {
      LOG(ERROR) << "Invalid fetch node name, skipping this input";
      return nullptr;
    } else {
      VLOG(1) << "Will use fetch node " << fetch;
    }
  }

  if (new_item->fetch.empty()) {
    LOG(ERROR) << "Failed to detect the fetch node(s), skipping this input";
    return nullptr;
  }

  // TODO(yuefengz): consider handling saved_model_main_op and legacy_init_op.
  // They are difficult to handle because they may not intend to initialize all
  // the variables that are required to run the fetch nodes; we may have to run
  // the restore op first.

  // Try to find initializers from variables and tables as init ops.
  for (const string& var_collection :
       {"variables", "local_variables", "model_variables",
        "trainable_variables"}) {
    if (meta_graph.collection_def().count(var_collection) == 0) {
      continue;
    }
    const CollectionDef& vars = meta_graph.collection_def().at(var_collection);
    for (const auto& raw_var : vars.bytes_list().value()) {
      VariableDef var;
      var.ParseFromString(raw_var);
      if (!var.initializer_name().empty()) {
        new_item->init_ops.push_back(NodeName(var.initializer_name()));
      }
    }
  }

  if (meta_graph.collection_def().count("table_initializer") > 0) {
    const CollectionDef& inits =
        meta_graph.collection_def().at("table_initializer");
    if (inits.has_node_list()) {
      for (const auto& node : inits.node_list().value()) {
        new_item->init_ops.push_back(NodeName(node));
        // Tables are initialized from files, which can take a long time. Add
        // 30 minutes to the initialization time for each table to avoid
        // timing out.
        // TODO(bsteiner): adjust the timeout based on the file size.
        new_item->expected_init_time += 30 * 60;
      }
    }
  }

  // We keep the mapping from asset node to asset files. This could have been
  // used as a feed, but since an asset node is usually a constant node, we
  // instead fill the values of these constant nodes with their actual asset
  // file paths.
  std::unordered_map<string, string> asset_node_to_value;

  // Asset files may have changed their directory, so we assemble their new
  // paths if assets_directory_override is set. We also make sure we can still
  // access these asset files.
  if (!cfg.assets_directory_override.empty()) {
    if (meta_graph.collection_def().count("saved_model_assets") > 0) {
      const CollectionDef& collection =
          meta_graph.collection_def().at("saved_model_assets");
      const auto& any_assets = collection.any_list().value();
      if (!any_assets.empty()) {
        if (std::is_base_of<protobuf::Message, AssetFileDef>()) {
          for (const auto& any_asset : any_assets) {
            AssetFileDef asset_file_def;
            if (!ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef")
                     .ok()) {
              LOG(ERROR) << "Failed to parse AssetFile.";
              continue;
            }
            string asset_filepath = io::JoinPath(cfg.assets_directory_override,
                                                 asset_file_def.filename());
            if (!FilesExist({asset_filepath}, nullptr)) {
              LOG(ERROR) << "Can't access one or more of the asset files "
                         << asset_filepath << ", skipping this input";
              return nullptr;
            }
            asset_node_to_value[NodeName(asset_file_def.tensor_info().name())] =
                asset_filepath;
          }
        } else {
          LOG(ERROR) << "Can't parse AssetFileDef when using lite protos.";
          return nullptr;
        }
      }
    }
  } else if (meta_graph.collection_def().count("asset_filepaths") > 0) {
    const CollectionDef& file_paths =
        meta_graph.collection_def().at("asset_filepaths");
    std::vector<string> paths;
    for (const auto& raw_path : file_paths.bytes_list().value()) {
      paths.push_back(raw_path);
    }
    if (!FilesExist(paths, nullptr)) {
      LOG(ERROR) << "Can't access one or more of the asset files, skipping "
                    "this input";
      return nullptr;
    }
  }

  if (meta_graph.collection_def().count("queue_runners") > 0) {
    const CollectionDef& vars = meta_graph.collection_def().at("queue_runners");
    for (const auto& raw : vars.bytes_list().value()) {
      QueueRunnerDef queue_runner;
      if (!queue_runner.ParseFromString(raw)) {
        LOG(ERROR) << "Could not parse queue_runners, skipping this input";
        return nullptr;
      }
      if (queue_runner.cancel_op_name().empty()) {
        LOG(ERROR) << "Queue without a cancel op, skipping this input";
        return nullptr;
      }
      new_item->queue_runners.push_back(queue_runner);
    }
  }

  // Add each node referenced in a collection to the list of nodes to keep.
  for (const auto& col : meta_graph.collection_def()) {
    const CollectionDef& collection = col.second;
    for (const string& node : collection.node_list().value()) {
      new_item->keep_ops.push_back(NodeName(node));
    }
  }

  for (auto& node : *new_item->graph.mutable_node()) {
    if (IsPlaceholder(node) && node.op() != "PlaceholderWithDefault") {
      Status s = UpdatePlaceholderShape(cfg, signature_feed_nodes,
                                        new_item.get(), &node);
      if (!s.ok()) return nullptr;
    } else if (IsConstant(node)) {
      auto it = asset_node_to_value.find(node.name());
      if (it != asset_node_to_value.end()) {
        auto iter = node.mutable_attr()->find("value");
        if (iter == node.attr().end()) {
          LOG(ERROR) << "Value attribute expected in const op for asset files";
          return nullptr;
        }
        if (!iter->second.has_tensor() ||
            iter->second.tensor().string_val_size() != 1) {
          LOG(INFO) << "Unexpected AttrValue proto: "
                    << iter->second.DebugString();
          return nullptr;
        }
        LOG(INFO) << "Using asset file " << it->second << " for node "
                  << node.name();
        *(iter->second.mutable_tensor()->mutable_string_val(0)) = it->second;
      }
    }

    // Erase the recorded result of any previous shape inference to start again
    // from scratch.
    node.mutable_attr()->erase("_output_shapes");

    // Delete user specified placement if requested.
    if (cfg.ignore_user_placement) {
      node.clear_device();
    }
    // Delete colocation constraints if requested.
    if (cfg.ignore_colocation) {
      auto attr = node.mutable_attr();
      auto it = attr->find("_class");
      if (it != attr->end()) {
        attr->erase(it);
      }
    }
  }

  if (meta_graph.collection_def().count("savers") > 0) {
    const CollectionDef& savers = meta_graph.collection_def().at("savers");
    for (const auto& raw : savers.bytes_list().value()) {
      SaverDef saver;
      // Skip bad savers since we don't need saves/restores to be able to run a
      // graph.
      if (!saver.ParseFromString(raw)) {
        continue;
      }
      if (saver.filename_tensor_name().empty()) {
        continue;
      }
      new_item->save_op = saver.save_tensor_name();
      new_item->restore_op = saver.restore_op_name();
      new_item->save_restore_loc_tensor = saver.filename_tensor_name();
      // Only use the first saver since it's not clear what to do if there's
      // more than one.
      break;
    }
  } else {
    const SaverDef& saver = meta_graph.saver_def();
    new_item->save_op = saver.save_tensor_name();
    new_item->restore_op = saver.restore_op_name();
    new_item->save_restore_loc_tensor = saver.filename_tensor_name();
  }

  // Instantiate all the missing attributes with their default values.
  Status attr_status = AddDefaultAttrsToGraphDef(
      &new_item->graph,
      FunctionLibraryDefinition(OpRegistry::Global(),
                                new_item->graph.library()),
      0, true);
  if (!attr_status.ok()) {
    LOG(ERROR) << "Failed to instantiate default attribute values: "
               << attr_status.error_message();
    return nullptr;
  }

  // Optimize the graph (function inlining, l1 optimizations, etc).
  VLOG(1) << "Number of nodes in graph before RuntimeGraphOptimizer: "
          << new_item->graph.node_size();
  Status optimize_status =
      RuntimeGraphOptimizer(new_item->graph, &new_item->graph, cfg);
  if (!optimize_status.ok()) {
    LOG(ERROR) << "Graph preprocessing failed: " << optimize_status;
    return nullptr;
  }
  VLOG(1) << "Number of nodes in graph after RuntimeGraphOptimizer: "
          << new_item->graph.node_size();

  if (cfg.prune_graph) {
    VLOG(1) << "Pruning graph...";
    auto status = PruneGraph(new_item.get());
    if (!status.ok()) {
      LOG(ERROR) << "Pruning failed: " << status.error_message();
      return nullptr;
    }
    VLOG(1) << "Number of nodes in graph after pruning: "
            << new_item->graph.node_size();
  }

  // Validate feed, fetch and init nodes
  std::unordered_set<string> nodes;
  for (const auto& node : new_item->graph.node()) {
    nodes.insert(node.name());
  }
  for (const auto& feed : new_item->feed) {
    if (nodes.find(feed.first) == nodes.end()) {
      LOG(ERROR) << "Feed node " << feed.first << " doesn't exist in graph";
      return nullptr;
    }
  }
  for (const auto& fetch : new_item->fetch) {
    if (nodes.find(fetch) == nodes.end()) {
      LOG(ERROR) << "Fetch node " << fetch << " doesn't exist in graph";
      return nullptr;
    }
  }
  for (const auto& init : new_item->init_ops) {
    if (nodes.find(init) == nodes.end()) {
      LOG(ERROR) << "Init node " << init << " doesn't exist in graph";
      return nullptr;
    }
  }
  return new_item;
}

std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDefFile(
    const string& id, const string& meta_graph_file, const ItemConfig& cfg) {
  MetaGraphDef meta_graph;
  if (!ReadMetaGraphDefFromFile(meta_graph_file, &meta_graph).ok()) {
    LOG(ERROR) << "Failed to read " << meta_graph_file;
    return nullptr;
  }
  return GrapplerItemFromMetaGraphDef(id, meta_graph, cfg);
}
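// Minimal usage sketch (hypothetical caller and path, not part of this file):
//
//   ItemConfig cfg;
//   std::unique_ptr<GrapplerItem> item =
//       GrapplerItemFromMetaGraphDefFile("my_item", "/tmp/model.meta", cfg);
//   if (item == nullptr) { /* the MetaGraphDef was rejected */ }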

}  // end namespace grappler
}  // end namespace tensorflow