/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/grappler_item_builder.h"

#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variable.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/inputs/utils.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace grappler {

namespace {

void InitializeTensor(DataType type, Tensor* tensor) {
  const int period = 7;
  if (type == DT_FLOAT) {
    auto flat = tensor->flat<float>();
    // Populate numbers 0, 0.1, 0.2, ..., 0.5, 0.6, 0, 0.1, 0.2, ...
    for (int i = 0; i < flat.size(); i++) {
      flat(i) = static_cast<float>(i % period) / 10.0f;
    }
  } else if (type == DT_INT64) {
    auto flat = tensor->flat<int64>();
    // Populate numbers 0, 1, 2, ..., 5, 6, 0, 1, 2, ...
    for (int i = 0; i < flat.size(); i++) {
      flat(i) = i % period;
    }
  } else if (type != DT_STRING && type != DT_RESOURCE && type != DT_VARIANT) {
    // DT_STRING, DT_RESOURCE and DT_VARIANT are not simple types according to
    // is_simple_type<> in tensorflow/core/framework/type_traits.h, and
    // Allocator will run non-trivial constructor/destructor for a Tensor with
    // one of these types, so we should not memset its buffer.
    memset(const_cast<char*>(tensor->tensor_data().data()), 0,
           tensor->tensor_data().size());
  }
}
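
// Example (an illustrative sketch):
//
//   Tensor t(DT_FLOAT, TensorShape({10}));
//   InitializeTensor(DT_FLOAT, &t);
//   // t now holds 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.0, 0.1, 0.2.
//
// Numeric types other than DT_FLOAT and DT_INT64 are zero-filled via memset,
// while DT_STRING, DT_RESOURCE and DT_VARIANT tensors are left
// default-initialized.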

// Applies the same graph pruning logic to the graph as Session.Run in TF.
// If the returned status is not OK, item state may be inconsistent.
Status PruneGraph(GrapplerItem* item) {
  ModelPruner pruner;
  GraphDef pruned_graph;
  Cluster* cluster = nullptr;  // ModelPruner doesn't check cluster.
  TF_RETURN_IF_ERROR(pruner.Optimize(cluster, *item, &pruned_graph));
  item->graph = std::move(pruned_graph);
  return Status::OK();
}

// Replace any unknown dimension in a shape with
// cfg.placeholder_unknown_output_shape_dim if that value is non-negative.
Status ReplaceUnknownShapeDim(const ItemConfig& cfg,
                              const TensorShapeProto& shape_pb_in,
                              TensorShapeProto* shape_pb_out,
                              TensorShape* shape_out) {
  std::vector<int32> dims;
  for (const auto& dim_proto : shape_pb_in.dim()) {
    if (cfg.placeholder_unknown_output_shape_dim >= 0 &&
        dim_proto.size() == -1) {
      dims.push_back(cfg.placeholder_unknown_output_shape_dim);
      shape_pb_out->add_dim()->set_size(
          cfg.placeholder_unknown_output_shape_dim);
    } else {
      dims.push_back(std::max<int32>(1, dim_proto.size()));
      shape_pb_out->add_dim()->set_size(dim_proto.size());
    }
  }
  return TensorShapeUtils::MakeShape(dims.data(), dims.size(), shape_out);
}
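
// Example (illustrative): with cfg.placeholder_unknown_output_shape_dim = 32,
// an input shape proto of [-1, 128] yields [32, 128] in both *shape_pb_out and
// *shape_out. If the config value is negative, the -1 is copied through to
// *shape_pb_out unchanged, while *shape_out clamps the dimension to 1 so that
// a valid TensorShape can still be built.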

// Replace unknown dimensions in Placeholder shape if
// cfg.placeholder_unknown_output_shape_dim is set or
// the Placeholder node has _output_shapes.
// Otherwise keep it intact to stay compatible with shape annotations
// (b/134092018).
Status UpdatePlaceholderShape(
    const ItemConfig& cfg,
    const std::unordered_set<string>& signature_feed_nodes,
    GrapplerItem* new_item, NodeDef* node) {
  if (node->attr().count("dtype") == 0) {
    return errors::Internal("Unknown type for placeholder ", node->name(),
                            ", skipping this input");
  }
  DataType type = node->attr().at("dtype").type();

  // TODO(andiryxu): Consider cfg.placeholder_unknown_output_shape_dim >= 0 and
  // _output_shapes is present case.
  if (node->attr().count("shape") == 0) {
    return errors::Internal("Unknown shape for placeholder ", node->name(),
                            ", skipping this input");
  }

  // Replace all unknown dimensions in the placeholder's tensorshape proto
  // with cfg.placeholder_unknown_output_shape_dim and create a tensorshape
  // from it. We do this because in newer protos, the input placeholder
  // shape is not empty if the shape is partially defined.
  TensorShape shape;
  TensorShapeProto shape_proto;
  Status make_shape_status = ReplaceUnknownShapeDim(
      cfg, node->attr().at("shape").shape(), &shape_proto, &shape);
  if (!make_shape_status.ok()) {
    return errors::Internal("Invalid shape for placeholder ", node->name(),
                            ": ", make_shape_status, ", skipping this input");
  }

  // Some placeholder nodes have a mismatch between the node
  // attribute "shape" and a different node attribute "_output_shapes".
  // Specifically, a shape with shape.dims() == 0 could indicate either
  // a scalar or an unknown shape. In those cases, we check _output_shapes
  // for additional information.
  // This case is observed in the bnmt graphs. We have not observed any
  // cases with more than one _output_shapes entry, so limit the handling
  // to cases where there is only one.
  // We only do this if cfg.placeholder_unknown_output_shape_dim has
  // been set to avoid crashing non-BNMT graphs.
  // TODO(andiryxu): Investigate if this is a bug in BNMT graph.
  if ((cfg.placeholder_unknown_output_shape_dim >= 0) && (shape.dims() == 0) &&
      (node->attr().count("_output_shapes") == 1)) {
    const auto& output_shapes =
        node->attr().at("_output_shapes").list().shape(0);

    if (output_shapes.dim_size() != 0) {
      shape.Clear();
      shape_proto.clear_dim();

      for (const auto& dim : output_shapes.dim()) {
        auto size = dim.size();
        if (size == -1) size = cfg.placeholder_unknown_output_shape_dim;
        shape.AddDim(size);
        shape_proto.add_dim()->set_size(size);
      }
    }
  }

  Tensor fake_input(type, shape);
  InitializeTensor(type, &fake_input);

  if (cfg.feed_nodes.empty()) {
    // No specific feed nodes were given. Assume all placeholders are fed.
    if (signature_feed_nodes.count(node->name()) == 0) {
      new_item->feed.emplace_back(node->name(), fake_input);
    }
  } else if (cfg.feed_nodes.count(node->name()) > 0) {
    // If specific feed nodes were given, only update their tensors.
    auto it = find_if(new_item->feed.begin(), new_item->feed.end(),
                      [&node](std::pair<string, Tensor>& f) {
                        return f.first == node->name();
                      });
    DCHECK(it != new_item->feed.end());
    it->second = fake_input;
  }

  // Set the shape of the node in the graph. This is needed for statically
  // inferring shapes and is a no-op when dynamically inferring shapes as
  // the Placeholder shape will match the shape passed from new_item->feed.
  // Only replace the node shape with a known shape; for an unknown shape keep
  // it intact (b/134092018).
  if (!shape_proto.dim().empty())
    *(node->mutable_attr()->at("shape").mutable_shape()) = shape_proto;

  return Status::OK();
}
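
// Example (illustrative): a Placeholder whose "shape" attribute is an empty
// (rank-0) proto but whose "_output_shapes" attribute records [-1, 10] will,
// with cfg.placeholder_unknown_output_shape_dim = 16, be fed a fake input of
// shape [16, 10], and its "shape" attribute is rewritten to match.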

}  // namespace

Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
                             GraphDef* output_graph_def,
                             const ItemConfig& cfg) {
  // This is a temporary change that optimizes the graph in the context of a
  // single-GPU machine. Down the line, we may want to make
  // grappler_item_builder aware of the cluster type (e.g. single CPU, multiple
  // GPUs, etc.) being simulated in order to get the correct session options
  // and environment, and to perform the correct optimizations.

  // Return the input as is if no graph-modifying config is set.
  if (!cfg.apply_optimizations && !cfg.inline_functions &&
      !cfg.erase_noinline_attributes) {
    if (output_graph_def != &graph_def_arg) {
      *output_graph_def = graph_def_arg;
    }
    return Status::OK();
  }

  // Create a session option for a single GPU device.
  SessionOptions options;

  // Make a local copy of graph def, because we need to change some things.
  GraphDef graph_def(graph_def_arg);

  if (cfg.erase_noinline_attributes) {
    // The TF optimizer doesn't inline functions with the "_noinline"
    // attribute, so go over the function library and erase it.
    for (auto& func : *graph_def.mutable_library()->mutable_function()) {
      func.mutable_attr()->erase("_noinline");
    }
  }

  // Instantiate all variables for function library runtime creation.
  std::vector<std::unique_ptr<Device>> devices;
  // Only the CPU device is used, so instead of calling
  // DeviceFactory::AddDevices() with a dummy session config (which would
  // conflict with user-defined options and create unwanted devices), call
  // cpu_factory->CreateDevices() to get CPU-only devices.
  DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
  TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
      options, "/job:localhost/replica:0/task:0", &devices));
  Device* cpu_device = devices[0].get();
  auto dvc_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
  FunctionLibraryDefinition function_library(OpRegistry::Global(),
                                             graph_def.library());
  Env* env = Env::Default();

  // Optimizer options: L1 and inlining. L1 is default.
  OptimizerOptions* optimizer_opts =
      options.config.mutable_graph_options()->mutable_optimizer_options();
  if (cfg.apply_optimizations) {
    optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions::L1);
  } else {
    optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions::L0);
  }
  optimizer_opts->set_do_function_inlining(cfg.inline_functions);

  // Create the function library runtime.
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
      new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env, &options.config,
                                        graph_def.versions().producer(),
                                        &function_library, *optimizer_opts));
  FunctionLibraryRuntime* flr = pflr->GetFLR(cpu_device->name());

  // Create the GraphOptimizer to optimize the graph def.
  GraphConstructorOptions graph_ctor_opts;
  graph_ctor_opts.allow_internal_ops = true;
  graph_ctor_opts.expect_device_spec = false;
  std::unique_ptr<Graph> graphptr(new Graph(function_library));

  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
      graph_ctor_opts, std::move(graph_def), graphptr.get()));

  // Optimize the graph.
  ::tensorflow::GraphOptimizer optimizer(*optimizer_opts);
  optimizer.Optimize(flr, env, cpu_device, &graphptr, /*shape_map=*/nullptr);
  graphptr->ToGraphDef(output_graph_def);

  // The default values of attributes might have been stripped by the
  // optimizer. Add them back.
  return AddDefaultAttrsToGraphDef(output_graph_def, *graphptr->op_registry(),
                                   0, true);
}
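
// Example usage (a sketch, inside a function returning Status): optimizing a
// GraphDef in place is supported, since the input is only copied when a config
// flag actually requires modifying it.
//
//   ItemConfig cfg;
//   cfg.apply_optimizations = true;  // run the L1 GraphOptimizer passes
//   cfg.inline_functions = true;     // enable function inlining
//   TF_RETURN_IF_ERROR(RuntimeGraphOptimizer(graph_def, &graph_def, cfg));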

std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
    const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg) {
  if (id.empty()) {
    LOG(ERROR) << "id must be non-empty.";
    return nullptr;
  }
  std::unique_ptr<GrapplerItem> new_item(new GrapplerItem());
  new_item->id = id;
  new_item->graph = meta_graph.graph_def();

  // Fill in feed nodes from config, if any provided.
  for (const auto& feed_node : cfg.feed_nodes) {
    const string feed_name = NodeName(feed_node);
    new_item->feed.emplace_back(feed_name, Tensor());
  }
  for (const auto& fetch_node : cfg.fetch_nodes) {
    new_item->fetch.emplace_back(NodeName(fetch_node));
  }

  // Attempt to detect the fetch node(s) if they were not set explicitly.
  if (new_item->fetch.empty() &&
      meta_graph.collection_def().count("train_op") > 0) {
    const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
    if (nodes.has_node_list()) {
      for (const auto& node : nodes.node_list().value()) {
        new_item->fetch.push_back(NodeName(node));
      }
    }
  }

  // Detect feed and fetch nodes from signature defs. Signatures may share the
  // same inputs or outputs.
  std::unordered_set<string> signature_feed_nodes;
  std::unordered_set<string> signature_fetch_nodes;
  for (const auto& name_and_signature : meta_graph.signature_def()) {
    for (const auto& name_and_input : name_and_signature.second.inputs()) {
      const TensorInfo& input = name_and_input.second;
      if (input.has_coo_sparse()) {
        // Define the shapes following the comment on CooSparse.
        // TODO(yuefengz): we probably want to use different dim values for the
        // three tensors of a SparseTensor.
        int64 dim = std::max(1, cfg.placeholder_unknown_output_shape_dim);
        TensorShape shape_1d({dim});
        TensorShape shape_2d({dim, dim});

        if (gtl::InsertIfNotPresent(
                &signature_feed_nodes,
                NodeName(input.coo_sparse().values_tensor_name()))) {
          Tensor value_tensor(input.dtype(), shape_1d);
          InitializeTensor(input.dtype(), &value_tensor);
          new_item->feed.emplace_back(
              NodeName(input.coo_sparse().values_tensor_name()), value_tensor);
        }
        if (gtl::InsertIfNotPresent(
                &signature_feed_nodes,
                NodeName(input.coo_sparse().indices_tensor_name()))) {
          Tensor indices_tensor(DT_INT64, shape_2d);
          InitializeTensor(DT_INT64, &indices_tensor);
          new_item->feed.emplace_back(
              NodeName(input.coo_sparse().indices_tensor_name()),
              indices_tensor);
        }
        if (gtl::InsertIfNotPresent(
                &signature_feed_nodes,
                NodeName(input.coo_sparse().dense_shape_tensor_name()))) {
          Tensor dense_shape_tensor(DT_INT64, shape_1d);
          InitializeTensor(DT_INT64, &dense_shape_tensor);
          new_item->feed.emplace_back(
              NodeName(input.coo_sparse().dense_shape_tensor_name()),
              dense_shape_tensor);
        }
      } else {
        if (gtl::InsertIfNotPresent(&signature_feed_nodes,
                                    NodeName(input.name()))) {
          TensorShape shape;
          TensorShapeProto shape_proto;
          Status s = ReplaceUnknownShapeDim(cfg, input.tensor_shape(),
                                            &shape_proto, &shape);
          if (!s.ok()) {
            LOG(ERROR) << "Invalid shape for signature input " << input.name()
                       << ": " << s << ", skipping this input";
            return nullptr;
          }

          Tensor fake_input(input.dtype(), shape);
          InitializeTensor(input.dtype(), &fake_input);
          new_item->feed.emplace_back(NodeName(input.name()), fake_input);
        }
      }
    }
    for (const auto& name_and_output : name_and_signature.second.outputs()) {
      const TensorInfo& output = name_and_output.second;
      if (output.has_coo_sparse()) {
        if (gtl::InsertIfNotPresent(
                &signature_fetch_nodes,
                NodeName(output.coo_sparse().values_tensor_name()))) {
          new_item->fetch.push_back(
              NodeName(output.coo_sparse().values_tensor_name()));
        }
        if (gtl::InsertIfNotPresent(
                &signature_fetch_nodes,
                NodeName(output.coo_sparse().indices_tensor_name()))) {
          new_item->fetch.push_back(
              NodeName(output.coo_sparse().indices_tensor_name()));
        }
        if (gtl::InsertIfNotPresent(
                &signature_fetch_nodes,
                NodeName(output.coo_sparse().dense_shape_tensor_name()))) {
          new_item->fetch.push_back(
              NodeName(output.coo_sparse().dense_shape_tensor_name()));
        }
      } else {
        if (gtl::InsertIfNotPresent(&signature_fetch_nodes,
                                    NodeName(output.name()))) {
          new_item->fetch.push_back(NodeName(output.name()));
        }
      }
    }
  }

  for (const auto& feed : new_item->feed) {
    if (feed.first.empty()) {
      LOG(ERROR) << "Invalid feed node name, skipping this input";
      return nullptr;
    } else {
      VLOG(1) << "Will use feed node " << feed.first;
    }
  }

  for (const auto& fetch : new_item->fetch) {
    if (fetch.empty()) {
      LOG(ERROR) << "Invalid fetch node name, skipping this input";
      return nullptr;
    } else {
      VLOG(1) << "Will use fetch node " << fetch;
    }
  }

  if (new_item->fetch.empty()) {
    LOG(ERROR) << "Failed to detect the fetch node(s), skipping this input";
    return nullptr;
  }

  // TODO(yuefengz): consider handling saved_model_main_op and legacy_init_op.
  // They are difficult to handle because they may not intend to initialize all
  // variables that are required to run the fetch nodes; we may have to run the
  // restore op first.

  // Try to find initializers from variables and tables as init ops.
  for (const string& var_collection :
       {"variables", "local_variables", "model_variables",
        "trainable_variables"}) {
    if (meta_graph.collection_def().count(var_collection) == 0) {
      continue;
    }
    const CollectionDef& vars = meta_graph.collection_def().at(var_collection);
    for (const auto& raw_var : vars.bytes_list().value()) {
      VariableDef var;
      var.ParseFromString(raw_var);
      if (!var.initializer_name().empty()) {
        new_item->init_ops.push_back(NodeName(var.initializer_name()));
      }
    }
  }

  if (meta_graph.collection_def().count("table_initializer") > 0) {
    const CollectionDef& inits =
        meta_graph.collection_def().at("table_initializer");
    if (inits.has_node_list()) {
      for (const auto& node : inits.node_list().value()) {
        new_item->init_ops.push_back(NodeName(node));
        // Tables are initialized from files, which can take a long time. Add
        // 30 minutes to the initialization time for each table to avoid
        // timing out.
        // TODO(bsteiner): adjust the timeout based on the file size.
        new_item->expected_init_time += 30 * 60;
      }
    }
  }

  // Keep the mapping from asset nodes to asset files. These would normally be
  // used as feeds, but since asset nodes are usually constant nodes, we fill
  // the values of these constant nodes with their actual asset file paths.
  std::unordered_map<string, string> asset_node_to_value;

  // Asset files may have changed directory, so we assemble their new paths
  // if assets_directory_override is set. We also make sure we can still
  // access these asset files.
  if (!cfg.assets_directory_override.empty()) {
    if (meta_graph.collection_def().count("saved_model_assets") > 0) {
      const CollectionDef& collection =
          meta_graph.collection_def().at("saved_model_assets");
      const auto& any_assets = collection.any_list().value();
      if (!any_assets.empty()) {
        if (std::is_base_of<protobuf::Message, AssetFileDef>()) {
          for (const auto& any_asset : any_assets) {
            AssetFileDef asset_file_def;
            if (!ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef")
                     .ok()) {
              LOG(ERROR) << "Failed to parse AssetFile.";
              continue;
            }
            string asset_filepath = io::JoinPath(cfg.assets_directory_override,
                                                 asset_file_def.filename());
            if (!FilesExist({asset_filepath}, nullptr)) {
              LOG(ERROR) << "Can't access one or more of the asset files "
                         << asset_filepath << ", skipping this input";
              return nullptr;
            }
            asset_node_to_value[NodeName(asset_file_def.tensor_info().name())] =
                asset_filepath;
          }
        } else {
          LOG(ERROR) << "Can't parse AssetFileDef when using lite protos.";
          return nullptr;
        }
      }
    }
  } else if (meta_graph.collection_def().count("asset_filepaths") > 0) {
    const CollectionDef& file_paths =
        meta_graph.collection_def().at("asset_filepaths");
    std::vector<string> paths;
    for (const auto& raw_path : file_paths.bytes_list().value()) {
      paths.push_back(raw_path);
    }
    if (!FilesExist(paths, nullptr)) {
      LOG(ERROR) << "Can't access one or more of the asset files, skipping "
                    "this input";
      return nullptr;
    }
  }

  if (meta_graph.collection_def().count("queue_runners") > 0) {
    const CollectionDef& vars = meta_graph.collection_def().at("queue_runners");
    for (const auto& raw : vars.bytes_list().value()) {
      QueueRunnerDef queue_runner;
      if (!queue_runner.ParseFromString(raw)) {
        LOG(ERROR) << "Could not parse queue_runners, skipping this input";
        return nullptr;
      }
      if (queue_runner.cancel_op_name().empty()) {
        LOG(ERROR) << "Queue without a cancel op, skipping this input";
        return nullptr;
      }
      new_item->queue_runners.push_back(queue_runner);
    }
  }

  // Add each node referenced in a collection to the list of nodes to keep.
  for (const auto& col : meta_graph.collection_def()) {
    const CollectionDef& collection = col.second;
    for (const string& node : collection.node_list().value()) {
      new_item->keep_ops.push_back(NodeName(node));
    }
  }

  for (auto& node : *new_item->graph.mutable_node()) {
    if (IsPlaceholder(node) && node.op() != "PlaceholderWithDefault") {
      Status s = UpdatePlaceholderShape(cfg, signature_feed_nodes,
                                        new_item.get(), &node);
      if (!s.ok()) return nullptr;
    } else if (IsConstant(node)) {
      auto it = asset_node_to_value.find(node.name());
      if (it != asset_node_to_value.end()) {
        auto iter = node.mutable_attr()->find("value");
        if (iter == node.attr().end()) {
          LOG(ERROR) << "Value attribute expected in const op for asset files";
          return nullptr;
        }
        if (!iter->second.has_tensor() ||
            iter->second.tensor().string_val_size() != 1) {
          LOG(INFO) << "Unexpected AttrValue proto: "
                    << iter->second.DebugString();
          return nullptr;
        }
        LOG(INFO) << "Using asset file " << it->second << " for node "
                  << node.name();
        *(iter->second.mutable_tensor()->mutable_string_val(0)) = it->second;
      }
    }

    // Erase the recorded result of any previous shape inference to start again
    // from scratch.
    node.mutable_attr()->erase("_output_shapes");

    // Delete user specified placement if requested.
    if (cfg.ignore_user_placement) {
      node.clear_device();
    }
    // Delete colocation constraints if requested.
    if (cfg.ignore_colocation) {
      auto attr = node.mutable_attr();
      auto it = attr->find("_class");
      if (it != attr->end()) {
        attr->erase(it);
      }
    }
  }

  if (meta_graph.collection_def().count("savers") > 0) {
    const CollectionDef& savers = meta_graph.collection_def().at("savers");
    for (const auto& raw : savers.bytes_list().value()) {
      SaverDef saver;
      // Skip bad savers since we don't need saves/restores to be able to run a
      // graph.
      if (!saver.ParseFromString(raw)) {
        continue;
      }
      if (saver.filename_tensor_name().empty()) {
        continue;
      }
      new_item->save_op = saver.save_tensor_name();
      new_item->restore_op = saver.restore_op_name();
      new_item->save_restore_loc_tensor = saver.filename_tensor_name();
      // Only use the first saver since it's not clear what to do if there's
      // more than one.
      break;
    }
  } else {
    const SaverDef& saver = meta_graph.saver_def();
    new_item->save_op = saver.save_tensor_name();
    new_item->restore_op = saver.restore_op_name();
    new_item->save_restore_loc_tensor = saver.filename_tensor_name();
  }

  // Instantiate all the missing attributes with their default values.
  Status attr_status = AddDefaultAttrsToGraphDef(
      &new_item->graph,
      FunctionLibraryDefinition(OpRegistry::Global(),
                                new_item->graph.library()),
      0, true);
  if (!attr_status.ok()) {
    LOG(ERROR) << "Failed to instantiate default attribute values: "
               << attr_status.error_message();
    return nullptr;
  }

  // Optimize the graph (function inlining, l1 optimizations, etc).
  VLOG(1) << "Number of nodes in graph before RuntimeGraphOptimizer: "
          << new_item->graph.node_size();
  Status optimize_status =
      RuntimeGraphOptimizer(new_item->graph, &new_item->graph, cfg);
  if (!optimize_status.ok()) {
    LOG(ERROR) << "Graph preprocessing failed: " << optimize_status;
    return nullptr;
  }
  VLOG(1) << "Number of nodes in graph after RuntimeGraphOptimizer: "
          << new_item->graph.node_size();

  if (cfg.prune_graph) {
    VLOG(1) << "Pruning graph...";
    auto status = PruneGraph(new_item.get());
    if (!status.ok()) {
      LOG(ERROR) << "Pruning failed: " << status.error_message();
      return nullptr;
    }
    VLOG(1) << "Number of nodes in graph after pruning: "
            << new_item->graph.node_size();
  }

  // Validate feed, fetch and init nodes.
  std::unordered_set<string> nodes;
  for (const auto& node : new_item->graph.node()) {
    nodes.insert(node.name());
  }
  for (const auto& feed : new_item->feed) {
    if (nodes.find(feed.first) == nodes.end()) {
      LOG(ERROR) << "Feed node " << feed.first << " doesn't exist in graph";
      return nullptr;
    }
  }
  for (const auto& fetch : new_item->fetch) {
    if (nodes.find(fetch) == nodes.end()) {
      LOG(ERROR) << "Fetch node " << fetch << " doesn't exist in graph";
      return nullptr;
    }
  }
  for (const auto& init : new_item->init_ops) {
    if (nodes.find(init) == nodes.end()) {
      LOG(ERROR) << "Init node " << init << " doesn't exist in graph";
      return nullptr;
    }
  }
  return new_item;
}

std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDefFile(
    const string& id, const string& meta_graph_file, const ItemConfig& cfg) {
  MetaGraphDef meta_graph;
  if (!ReadMetaGraphDefFromFile(meta_graph_file, &meta_graph).ok()) {
    LOG(ERROR) << "Failed to read " << meta_graph_file;
    return nullptr;
  }
  return GrapplerItemFromMetaGraphDef(id, meta_graph, cfg);
}
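
// Example usage (a sketch; "/tmp/model.meta" is a placeholder path):
//
//   ItemConfig cfg;
//   cfg.placeholder_unknown_output_shape_dim = 1;  // feed 1 for unknown dims
//   cfg.prune_graph = true;
//   std::unique_ptr<GrapplerItem> item =
//       GrapplerItemFromMetaGraphDefFile("my_item", "/tmp/model.meta", cfg);
//   if (item) {
//     // item->graph, item->feed, item->fetch and item->init_ops are now
//     // populated from the MetaGraphDef's collections and signature defs.
//   }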

}  // end namespace grappler
}  // end namespace tensorflow