• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Loader implementation for SavedModel with hermetic, language-neutral exports.
16"""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import os
23
24from google.protobuf import message
25from google.protobuf import text_format
26
27from tensorflow.core.protobuf import graph_debug_info_pb2
28from tensorflow.core.protobuf import meta_graph_pb2
29from tensorflow.core.protobuf import saved_model_pb2
30from tensorflow.python.framework import ops
31from tensorflow.python.lib.io import file_io
32from tensorflow.python.ops import variables
33from tensorflow.python.platform import tf_logging
34from tensorflow.python.saved_model import constants
35from tensorflow.python.saved_model import signature_def_utils
36from tensorflow.python.saved_model import utils_impl as saved_model_utils
37from tensorflow.python.saved_model.pywrap_saved_model import metrics
38from tensorflow.python.training import saver as tf_saver
39from tensorflow.python.util import compat
40from tensorflow.python.util import deprecation
41from tensorflow.python.util.tf_export import tf_export
42
# API label for SavedModel metrics. `SavedModelLoader.load` passes it to
# `metrics.IncrementReadApi` so that reads can be attributed to this loader.
_LOADER_LABEL = "loader"
45
46
def parse_saved_model_with_debug_info(export_dir):
  """Reads the SavedModel proto together with its optional GraphDebugInfo.

  Args:
    export_dir: Directory containing the SavedModel and, optionally, the
      GraphDebugInfo file in its debug subdirectory.

  Returns:
    A `(SavedModel, GraphDebugInfo)` tuple of protocol buffers. When no debug
    info file exists, the second element is an empty `GraphDebugInfo`.

  Raises:
    IOError: If the SavedModel file does not exist or cannot be parsed, or if
      the debug info file exists but cannot be parsed.
  """
  saved_model = _parse_saved_model(export_dir)

  debug_info = graph_debug_info_pb2.GraphDebugInfo()
  debug_info_path = os.path.join(
      saved_model_utils.get_debug_dir(export_dir),
      constants.DEBUG_INFO_FILENAME_PB)

  # A missing debug info file is not an error; hand back the empty proto.
  if not file_io.file_exists(debug_info_path):
    return saved_model, debug_info

  with file_io.FileIO(debug_info_path, "rb") as debug_file:
    serialized_debug_info = debug_file.read()
  try:
    debug_info.ParseFromString(serialized_debug_info)
  except message.DecodeError as e:
    raise IOError("Cannot parse file %s: %s." % (debug_info_path, str(e)))
  return saved_model, debug_info
74
75
@tf_export("__internal__.saved_model.parse_saved_model", v1=[])
def parse_saved_model(export_dir):
  """Reads the `SavedModel` protocol buffer stored in `export_dir`.

  The binary serialization is tried first; the text serialization is used
  only when no binary file is present.

  Args:
    export_dir: String or Pathlike, path to the directory containing the
      SavedModel file.

  Returns:
    A `SavedModel` protocol buffer.

  Raises:
    IOError: If neither file exists, or the file found cannot be parsed.
  """
  # Both candidate paths share the same bytes-encoded directory prefix.
  export_dir_bytes = compat.as_bytes(compat.path_to_str(export_dir))
  path_to_pb = os.path.join(
      export_dir_bytes, compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
  path_to_pbtxt = os.path.join(
      export_dir_bytes, compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))

  saved_model = saved_model_pb2.SavedModel()

  if file_io.file_exists(path_to_pb):
    with file_io.FileIO(path_to_pb, "rb") as f:
      file_content = f.read()
    try:
      saved_model.ParseFromString(file_content)
    except message.DecodeError as e:
      raise IOError("Cannot parse file %s: %s." % (path_to_pb, str(e)))
    return saved_model

  if file_io.file_exists(path_to_pbtxt):
    with file_io.FileIO(path_to_pbtxt, "rb") as f:
      file_content = f.read()
    try:
      text_format.Merge(file_content.decode("utf-8"), saved_model)
    except text_format.ParseError as e:
      raise IOError("Cannot parse file %s: %s." % (path_to_pbtxt, str(e)))
    return saved_model

  raise IOError(
      "SavedModel file does not exist at: %s%s{%s|%s}" %
      (export_dir, os.path.sep, constants.SAVED_MODEL_FILENAME_PBTXT,
       constants.SAVED_MODEL_FILENAME_PB))


# TODO(b/120594573): Make this symbol also available as private, so that
# tensorflow_transform and tensorflow_estimator do not break.
# Private alias of `parse_saved_model`, kept for those external callers.
_parse_saved_model = parse_saved_model
127
128
def get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None):
  """Gets the asset tensors, if defined in the meta graph def to load.

  Asset definitions are taken from the `asset_file_def` field of the meta
  graph when it is non-empty; otherwise they are unpacked from the
  `constants.ASSETS_KEY` collection, if present.

  Args:
    export_dir: Directory where the SavedModel is located. May be a `str`,
      `bytes`, or `os.PathLike` object (the same inputs `parse_saved_model`
      accepts).
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
    import_scope: Optional `string` -- if specified, prepend this followed by
        '/' to all returned asset tensor names.

  Returns:
    A dictionary of asset tensors, keyed by the name of the asset tensor. Each
    value is the path (as bytes) of the asset file inside the SavedModel's
    assets directory under `export_dir`.
  """
  collection_def = meta_graph_def_to_load.collection_def

  if meta_graph_def_to_load.asset_file_def:
    asset_protos = meta_graph_def_to_load.asset_file_def
  elif constants.ASSETS_KEY in collection_def:
    # The collection stores each AssetFileDef packed inside an `Any` proto.
    asset_protos = []
    for asset_any_proto in collection_def[constants.ASSETS_KEY].any_list.value:
      asset_proto = meta_graph_pb2.AssetFileDef()
      asset_any_proto.Unpack(asset_proto)
      asset_protos.append(asset_proto)
  else:
    asset_protos = []

  # `compat.as_bytes` rejects os.PathLike objects, so convert first, exactly
  # as `parse_saved_model` does; otherwise a pathlib.Path export dir that
  # loads fine elsewhere would fail here with a TypeError.
  assets_directory = os.path.join(
      compat.as_bytes(compat.path_to_str(export_dir)),
      compat.as_bytes(constants.ASSETS_DIRECTORY))

  asset_tensor_dict = {}
  for asset_proto in asset_protos:
    tensor_name = asset_proto.tensor_info.name
    if import_scope:
      tensor_name = "%s/%s" % (import_scope, tensor_name)
    # `assets_directory` is already bytes (joined from bytes components).
    asset_tensor_dict[tensor_name] = os.path.join(
        assets_directory, compat.as_bytes(asset_proto.filename))

  return asset_tensor_dict
170
171
def _get_main_op_tensor(
    meta_graph_def_to_load, init_op_key=constants.MAIN_OP_KEY):
  """Returns the single op registered under `init_op_key`, if any.

  The meta graph's collection def is used only to check that the collection
  exists and holds exactly one node; the op itself is fetched from the
  default graph's collection of the same name.

  Args:
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
    init_op_key: name of the collection to check; should be one of MAIN_OP_KEY
      or the deprecated LEGACY_INIT_OP_KEY

  Returns:
    The op found in the default graph's collection, or `None` when the meta
    graph has no collection named `init_op_key`.

  Raises:
    RuntimeError: If the collection def corresponding to the main op key has
        other than exactly one tensor.
  """
  # TODO(kathywu): Rename this method to _get_op_from_collection when
  # dependency from SavedModelEstimator is removed.
  collection_def = meta_graph_def_to_load.collection_def
  if init_op_key not in collection_def:
    return None

  init_op_list = collection_def[init_op_key].node_list.value
  if len(init_op_list) != 1:
    raise RuntimeError("Expected exactly one SavedModel init op. "
                       "Found: {}".format(init_op_list))
  # NOTE(review): presumably the default graph's collection was populated by
  # importing this same meta graph beforehand -- callers should ensure that.
  return ops.get_collection(init_op_key)[0]
199
200
def _get_op_from_collection(meta_graph_def, op_key):
  """Returns the single op stored in collection `op_key`, or `None`.

  Thin forwarding wrapper around `_get_main_op_tensor`, which performs the
  lookup and the exactly-one-op validation.
  """
  return _get_main_op_tensor(meta_graph_def, init_op_key=op_key)
203
204
205def _get_op_from_signature_def(meta_graph_def, op_signature_key, import_scope):
206  """Retrieve op stored in the imported meta graph's signature def."""
207  if op_signature_key in meta_graph_def.signature_def:
208    return signature_def_utils.load_op_from_signature_def(
209        meta_graph_def.signature_def[op_signature_key], op_signature_key,
210        import_scope)
211  else:
212    return None
213
214
def get_init_op(meta_graph_def, import_scope=None):
  """Returns the init op of `meta_graph_def`, or `None` if it defines none.

  Sources are tried in order: the signature def keyed by
  `constants.INIT_OP_SIGNATURE_KEY`, then the `constants.MAIN_OP_KEY`
  collection, then the deprecated `constants.LEGACY_INIT_OP_KEY` collection.

  Args:
    meta_graph_def: The meta graph def from the SavedModel being loaded.
    import_scope: Optional scope prefix applied to the signature-def lookup.

  Returns:
    The first init op found (an op or tensor), or `None`.

  Raises:
    RuntimeError: Propagated from the collection lookup when a collection
      holds other than exactly one op.
  """
  # Explicit `is None` checks instead of chaining with `or`: the signature
  # def may yield a Tensor, and a symbolic Tensor raises TypeError when it is
  # evaluated as a Python bool in graph mode. This also matches get_train_op.
  init_op = _get_op_from_signature_def(
      meta_graph_def, constants.INIT_OP_SIGNATURE_KEY, import_scope)
  if init_op is None:
    init_op = _get_op_from_collection(meta_graph_def, constants.MAIN_OP_KEY)
  if init_op is None:
    init_op = _get_op_from_collection(
        meta_graph_def, constants.LEGACY_INIT_OP_KEY)
  return init_op
220
221
def get_train_op(meta_graph_def, import_scope=None):
  """Returns the train op of `meta_graph_def`, or `None` if it defines none.

  The signature def keyed by `constants.TRAIN_OP_SIGNATURE_KEY` takes
  precedence; the `constants.TRAIN_OP_KEY` collection is the fallback.
  """
  signature_train_op = _get_op_from_signature_def(
      meta_graph_def, constants.TRAIN_OP_SIGNATURE_KEY, import_scope)
  if signature_train_op is not None:
    return signature_train_op
  return _get_op_from_collection(meta_graph_def, constants.TRAIN_OP_KEY)
228
229
@tf_export(v1=[
    "saved_model.contains_saved_model",
    "saved_model.maybe_saved_model_directory",
    "saved_model.loader.maybe_saved_model_directory"
])
@deprecation.deprecated_endpoints(
    "saved_model.loader.maybe_saved_model_directory")
def maybe_saved_model_directory(export_dir):
  """Checks whether the provided export directory could contain a SavedModel.

  Only the presence of a SavedModel file (text or binary serialization) is
  checked; nothing is read or parsed. A `false` result means the directory
  definitely does not contain a SavedModel. A `true` result means it may
  contain one, with no guarantee that it can be loaded.

  Args:
    export_dir: Absolute string path to possible export location. For example,
                '/my/foo/model'.

  Returns:
    True if the export directory contains SavedModel files, False otherwise.
  """
  # Text format is probed first, then binary, stopping at the first hit.
  candidate_filenames = (constants.SAVED_MODEL_FILENAME_PBTXT,
                         constants.SAVED_MODEL_FILENAME_PB)
  return any(
      file_io.file_exists(os.path.join(export_dir, filename))
      for filename in candidate_filenames)
255
256
@tf_export("saved_model.contains_saved_model", v1=[])
def contains_saved_model(export_dir):
  """Checks whether the provided export directory could contain a SavedModel.

  This performs only a file-existence check and does not load any data. A
  `false` result means the directory definitely does not contain a
  SavedModel. A `true` result means it may contain one, with no guarantee
  that it can be loaded.

  Args:
    export_dir: Absolute string path to possible export location. For example,
                '/my/foo/model'.

  Returns:
    True if the export directory contains SavedModel files, False otherwise.
  """
  # Same check as the v1 endpoint; only the exported API name differs.
  looks_like_saved_model = maybe_saved_model_directory(export_dir)
  return looks_like_saved_model
274
275
@tf_export(v1=["saved_model.load", "saved_model.loader.load"])
@deprecation.deprecated(
    None,
    "This function will only be available through the v1 compatibility "
    "library as tf.compat.v1.saved_model.loader.load or "
    "tf.compat.v1.saved_model.load. There will be a new function for importing "
    "SavedModels in Tensorflow 2.0.")
def load(sess, tags, export_dir, import_scope=None, **saver_kwargs):
  """Loads the model from a SavedModel as specified by tags.

  This is a convenience wrapper that builds a `SavedModelLoader` for
  `export_dir` and calls its `load` method.

  Args:
    sess: The TensorFlow session to restore the variables.
    tags: Set of string tags to identify the required MetaGraphDef. These should
        correspond to the tags used when saving the variables using the
        SavedModel `save()` API.
    export_dir: Directory in which the SavedModel protocol buffer and variables
        to be loaded are located.
    import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
    **saver_kwargs: Optional keyword arguments passed through to Saver.

  Returns:
    The `MetaGraphDef` protocol buffer loaded in the provided session. This
    can be used to further extract signature-defs, collection-defs, etc.

  Raises:
    RuntimeError: MetaGraphDef associated with the tags cannot be found.
    IOError: The SavedModel file does not exist or cannot be parsed.

  @compatibility(TF2)

  `tf.compat.v1.saved_model.load` or `tf.compat.v1.saved_model.loader.load` is
  not compatible with eager execution. Please use `tf.saved_model.load` instead
  to load your model. You can refer to the [SavedModel guide]
  (https://www.tensorflow.org/guide/saved_model) for more information as well as
  "Importing SavedModels from TensorFlow 1.x" in the [`tf.saved_model.load`]
  (https://www.tensorflow.org/api_docs/python/tf/saved_model/load) docstring.

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `sess`                | Not supported   | -                          |
  | `tags`                | `tags`          | -                          |
  | `export_dir`          | `export_dir`    | -                          |
  | `import_scope`        | Not supported   | Name scopes are not needed.
  :                       :                 : By default, variables are  :
  :                       :                 : associated with the loaded :
  :                       :                 : object and function names  :
  :                       :                 : are deduped.               :
  | `saver_kwargs`        | Not supported   | -                          |

  #### Before & After Usage Example

  Before:

  ```
  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    tf.compat.v1.saved_model.loader.load(sess, ["foo-tag"], export_dir)
  ```

  After:

  ```
  model = tf.saved_model.load(export_dir, tags=["foo-tag"])
  ```
  @end_compatibility
  """
  return SavedModelLoader(export_dir).load(
      sess, tags, import_scope=import_scope, **saver_kwargs)
347
348
class SavedModelLoader(object):
  """Load graphs and restore variable values from a `SavedModel`.

  The SavedModel protocol buffer is parsed once, in `__init__`, and that
  parsed proto is reused by the other methods.
  """

  def __init__(self, export_dir):
    """Creates a `SavedModelLoader`.

    Args:
      export_dir: Directory in which the SavedModel protocol buffer and
        variables to be loaded are located.

    Raises:
      IOError: If the SavedModel file does not exist or cannot be parsed.
    """
    self._export_dir = export_dir
    self._variables_path = saved_model_utils.get_variables_path(export_dir)
    self._saved_model = parse_saved_model(export_dir)

  @property
  def export_dir(self):
    """Directory containing the SavedModel."""
    return self._export_dir

  @property
  def variables_path(self):
    """Path to variable checkpoint files."""
    return self._variables_path

  @property
  def saved_model(self):
    """SavedModel object parsed from the export directory."""
    return self._saved_model

  def get_meta_graph_def_from_tags(self, tags):
    """Return MetaGraphDef with the exact specified tags.

    Tags are compared as sets, so order and duplicates do not matter; the
    first meta graph whose tag set equals `set(tags)` is returned.

    Args:
      tags: A list or set of string tags that identify the MetaGraphDef.

    Returns:
      MetaGraphDef with the same tags.

    Raises:
      RuntimeError: if no metagraphs were found with the associated tags.
    """
    # Build the requested set once: rebuilding it per iteration is wasteful
    # and silently yields an empty set on later iterations when `tags` is a
    # one-shot iterable such as a generator.
    requested_tags = set(tags)
    available_tags = []
    for meta_graph_def in self._saved_model.meta_graphs:
      graph_tags = set(meta_graph_def.meta_info_def.tags)
      if graph_tags == requested_tags:
        return meta_graph_def
      available_tags.append(graph_tags)

    raise RuntimeError(
        "MetaGraphDef associated with tags " + str(tags).strip("[]") +
        " could not be found in SavedModel. To inspect available tag-sets in"
        " the SavedModel, please use the SavedModel CLI: `saved_model_cli`"
        "\navailable_tags: " + str(available_tags))

  def load_graph(self, graph, tags, import_scope=None, **saver_kwargs):
    """Load ops and nodes from SavedModel MetaGraph into graph.

    Args:
      graph: tf.Graph object.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
      **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.

    Returns:
      A tuple of
        * Saver defined by the MetaGraph, which can be used to restore the
          variable values.
        * List of `Operation`/`Tensor` objects returned from
          `tf.import_graph_def` (may be `None`).

    Raises:
      RuntimeError: if no MetaGraphDef matches `tags`.
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    with graph.as_default():
      return tf_saver._import_meta_graph_with_return_elements(  # pylint: disable=protected-access
          meta_graph_def, import_scope=import_scope, **saver_kwargs)

  def restore_variables(self, sess, saver, import_scope=None):
    """Restore SavedModel variable values into the session.

    Args:
      sess: tf.compat.v1.Session into which variable values are restored.
      saver: a tf.compat.v1.train.Saver object. Can be None if there are no
        variables in graph. This may be the saver returned by the load_graph()
        function, or a default `tf.compat.v1.train.Saver()`.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.

    Raises:
      ValueError: if no saver was passed to the saver argument, and there are
        variables in the graph.
    """
    with sess.graph.as_default():
      # No saver is fine only when the (scoped) graph has nothing to restore.
      if (saver is None and
          not variables._all_saveable_objects(scope=import_scope)):  # pylint: disable=protected-access
        tf_logging.info("The specified SavedModel has no variables; no "
                        "checkpoints were restored.")
      elif isinstance(saver, tf_saver.Saver):
        saver.restore(sess, self._variables_path)
      else:
        raise ValueError(
            "No tf.train.Saver object was passed to the function "
            "SavedModelLoader.restore_variables. Since there are variables in "
            "the graph, a saver is required.")

  def run_init_ops(self, sess, tags, import_scope=None):
    """Run initialization ops defined in the `MetaGraphDef`.

    The asset tensors are fed with their file paths while the init op runs.

    Args:
      sess: tf.compat.v1.Session in which the init op is run.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.

    Raises:
      RuntimeError: if no MetaGraphDef matches `tags`.
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    with sess.graph.as_default():
      # Get asset tensors, if any.
      asset_tensors_dictionary = get_asset_tensors(
          self._export_dir, meta_graph_def, import_scope=import_scope)

      init_op = get_init_op(meta_graph_def, import_scope)
      if init_op is not None:
        sess.run(fetches=[init_op], feed_dict=asset_tensors_dictionary)

  def load(self, sess, tags, import_scope=None, **saver_kwargs):
    """Load the MetaGraphDef graph and restore variable values into the session.

    Imports the graph, restores variables, runs the init op, and records
    read metrics.

    Args:
      sess: tf.compat.v1.Session to restore variable values.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
      **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.

    Returns:
      `MetagraphDef` proto of the graph that was loaded.

    Raises:
      RuntimeError: if no MetaGraphDef matches `tags`.
    """
    metrics.IncrementReadApi(_LOADER_LABEL)

    with sess.graph.as_default():
      saver, _ = self.load_graph(sess.graph, tags, import_scope,
                                 **saver_kwargs)
      self.restore_variables(sess, saver, import_scope)
      self.run_init_ops(sess, tags, import_scope)
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)

    # Reuse the proto parsed in __init__ rather than re-reading and
    # re-parsing the SavedModel file from disk just for metrics.
    saved_model_proto = self._saved_model
    # A single meta graph carrying an object graph is counted as write
    # version "2"; anything else as version "1".
    if (len(saved_model_proto.meta_graphs) == 1 and
        saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
      metrics.IncrementRead(write_version="2")
    else:
      metrics.IncrementRead(write_version="1")

    return meta_graph_def
514