• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Loader implementation for SavedModel with hermetic, language-neutral exports.
16"""
17
18import os
19import sys
20
21from google.protobuf import message
22from google.protobuf import text_format
23
24from tensorflow.core.protobuf import graph_debug_info_pb2
25from tensorflow.core.protobuf import meta_graph_pb2
26from tensorflow.core.protobuf import saved_model_pb2
27from tensorflow.python.framework import ops
28from tensorflow.python.lib.io import file_io
29from tensorflow.python.ops import variables
30from tensorflow.python.platform import tf_logging
31from tensorflow.python.saved_model import constants
32from tensorflow.python.saved_model import signature_def_utils
33from tensorflow.python.saved_model import utils_impl as saved_model_utils
34from tensorflow.python.saved_model.pywrap_saved_model import metrics
35from tensorflow.python.training import saver as tf_saver
36from tensorflow.python.util import compat
37from tensorflow.python.util import deprecation
38from tensorflow.python.util.tf_export import tf_export
39
40# API label for SavedModel metrics.
41_LOADER_LABEL = "loader"
42
43
def parse_saved_model_with_debug_info(export_dir):
  """Reads the SavedModel proto together with its optional `GraphDebugInfo`.

  Args:
    export_dir: Directory containing the SavedModel and GraphDebugInfo files.

  Returns:
    A `(SavedModel, GraphDebugInfo)` tuple of protocol buffers. When no debug
    info file is present, the returned `GraphDebugInfo` is an empty proto.

  Raises:
    IOError: If the saved model file does not exist or cannot be parsed, or if
      a debug info file exists but cannot be parsed. A missing graph debug
      info file is fine.
  """
  saved_model = parse_saved_model(export_dir)

  debug_info_path = file_io.join(
      saved_model_utils.get_debug_dir(export_dir),
      constants.DEBUG_INFO_FILENAME_PB)
  debug_info = graph_debug_info_pb2.GraphDebugInfo()
  if file_io.file_exists(debug_info_path):
    # Read the bytes first so the file handle is released before parsing.
    with file_io.FileIO(debug_info_path, "rb") as debug_file:
      serialized = debug_file.read()
    try:
      debug_info.ParseFromString(serialized)
    except message.DecodeError as e:
      # Chain the decode error so the root cause stays in the traceback.
      raise IOError(f"Cannot parse file {debug_info_path}: {e}.") from e

  return (saved_model, debug_info)
71
72
@tf_export("__internal__.saved_model.parse_saved_model", v1=[])
def parse_saved_model(export_dir):
  """Reads the savedmodel.pb or savedmodel.pbtxt file containing `SavedModel`.

  The binary `.pb` file takes precedence; the text `.pbtxt` file is only read
  when no binary file exists.

  Args:
    export_dir: String or Pathlike, path to the directory containing the
    SavedModel file.

  Returns:
    A `SavedModel` protocol buffer.

  Raises:
    IOError: If the file does not exist, or cannot be successfully parsed
      (including a `.pbtxt` file that is not valid UTF-8).
  """
  # Normalize str / bytes / os.PathLike to bytes once for both candidates.
  export_dir_bytes = compat.as_bytes(compat.path_to_str(export_dir))
  path_to_pbtxt = file_io.join(
      export_dir_bytes, compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
  path_to_pb = file_io.join(
      export_dir_bytes, compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))

  saved_model = saved_model_pb2.SavedModel()

  if file_io.file_exists(path_to_pb):
    with file_io.FileIO(path_to_pb, "rb") as f:
      file_content = f.read()
    try:
      saved_model.ParseFromString(file_content)
    except message.DecodeError as e:
      raise IOError(f"Cannot parse file {path_to_pb}: {str(e)}.") from e
    return saved_model

  if file_io.file_exists(path_to_pbtxt):
    with file_io.FileIO(path_to_pbtxt, "rb") as f:
      file_content = f.read()
    try:
      text_format.Merge(file_content.decode("utf-8"), saved_model)
    except (text_format.ParseError, UnicodeDecodeError) as e:
      # A non-UTF-8 text file is just as unparsable as a malformed one, so it
      # is reported through the documented IOError as well.
      raise IOError(f"Cannot parse file {path_to_pbtxt}: {str(e)}.") from e
    return saved_model

  raise IOError(
      f"SavedModel file does not exist at: {export_dir}{os.path.sep}"
      f"{{{constants.SAVED_MODEL_FILENAME_PBTXT}|"
      f"{constants.SAVED_MODEL_FILENAME_PB}}}")
119
120
def get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None):
  """Gets the asset tensors, if defined in the meta graph def to load.

  Assets are read from `meta_graph_def_to_load.asset_file_def` when it is
  non-empty. Otherwise they are read from the `constants.ASSETS_KEY` collection
  as packed `AssetFileDef` protos.

  Args:
    export_dir: Directory where the SavedModel is located (str, bytes or
      os.PathLike).
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
    import_scope: Optional `string` -- if specified, prepend this followed by
        '/' to all returned asset tensor names.

  Returns:
    A dictionary of asset tensors, keyed by the name of the asset tensor. The
    value in the map is the path of the asset file inside the SavedModel's
    assets directory.
  """
  collection_def = meta_graph_def_to_load.collection_def

  asset_protos = []
  if meta_graph_def_to_load.asset_file_def:
    asset_protos = meta_graph_def_to_load.asset_file_def
  elif constants.ASSETS_KEY in collection_def:
    for asset_any_proto in collection_def[constants.ASSETS_KEY].any_list.value:
      asset_proto = meta_graph_pb2.AssetFileDef()
      asset_any_proto.Unpack(asset_proto)
      asset_protos.append(asset_proto)

  # Normalize os.PathLike the same way parse_saved_model does before
  # converting to bytes, so PathLike export dirs work here too.
  assets_directory = file_io.join(
      compat.as_bytes(compat.path_to_str(export_dir)),
      compat.as_bytes(constants.ASSETS_DIRECTORY))

  asset_tensor_dict = {}
  for asset_proto in asset_protos:
    tensor_name = asset_proto.tensor_info.name
    if import_scope:
      tensor_name = f"{import_scope}/{tensor_name}"
    asset_tensor_dict[tensor_name] = file_io.join(
        compat.as_bytes(assets_directory),
        compat.as_bytes(asset_proto.filename))

  return asset_tensor_dict
162
163
def _get_main_op_tensor(
    meta_graph_def_to_load, init_op_key=constants.MAIN_OP_KEY):
  """Returns the single op registered under `init_op_key`, or `None`.

  Args:
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
    init_op_key: Name of the collection to check; should be one of MAIN_OP_KEY
      or the deprecated LEGACY_INIT_OP_KEY.

  Returns:
    The first element of the `init_op_key` collection from `ops.get_collection`
    when the meta graph's collection def lists exactly one node, or `None`
    when the meta graph has no such collection.

  Raises:
    RuntimeError: If the meta graph's collection under `init_op_key` lists
      anything other than exactly one node.
  """
  # TODO(kathywu): Rename this method to _get_op_from_collection when
  # dependency from SavedModelEstimator is removed.
  all_collections = meta_graph_def_to_load.collection_def
  if init_op_key not in all_collections:
    return None

  node_names = all_collections[init_op_key].node_list.value
  if len(node_names) != 1:
    raise RuntimeError("Expected exactly one SavedModel init op. "
                       f"Found {len(node_names)}: {node_names}.")
  # NOTE(review): the op is fetched via ops.get_collection (presumably the
  # default graph), not from meta_graph_def_to_load itself; this assumes the
  # meta graph was already imported into that graph -- confirm.
  return ops.get_collection(init_op_key)[0]
191
192
def _get_op_from_collection(meta_graph_def, op_key):
  """Returns the op stored in collection `op_key`, or `None` if absent."""
  return _get_main_op_tensor(
      meta_graph_def_to_load=meta_graph_def, init_op_key=op_key)
195
196
def _get_op_from_signature_def(meta_graph_def, op_signature_key, import_scope):
  """Returns the op under `op_signature_key` in the signature defs, else None."""
  signature_defs = meta_graph_def.signature_def
  if op_signature_key not in signature_defs:
    return None
  return signature_def_utils.load_op_from_signature_def(
      signature_defs[op_signature_key], op_signature_key, import_scope)
205
206
def get_init_op(meta_graph_def, import_scope=None):
  """Returns the init op of `meta_graph_def`, or `None` if none is defined.

  Sources are tried in order, stopping at the first one that yields an op:
    1. the signature def under `constants.INIT_OP_SIGNATURE_KEY`,
    2. the `constants.MAIN_OP_KEY` collection,
    3. the deprecated `constants.LEGACY_INIT_OP_KEY` collection.

  Args:
    meta_graph_def: The MetaGraphDef being loaded.
    import_scope: Optional scope prefix forwarded to the signature-def lookup.

  Returns:
    The init op, or `None` if no source defines one.
  """
  # Use explicit `is None` checks instead of `a or b or c`: truth-testing a
  # graph element is unsafe (a Tensor raises on bool()), and this mirrors
  # get_train_op.
  init_op = _get_op_from_signature_def(
      meta_graph_def, constants.INIT_OP_SIGNATURE_KEY, import_scope)
  if init_op is None:
    init_op = _get_op_from_collection(meta_graph_def, constants.MAIN_OP_KEY)
  if init_op is None:
    init_op = _get_op_from_collection(
        meta_graph_def, constants.LEGACY_INIT_OP_KEY)
  return init_op
212
213
def get_train_op(meta_graph_def, import_scope=None):
  """Returns the train op of `meta_graph_def`, or `None` if none is defined.

  The signature def under `constants.TRAIN_OP_SIGNATURE_KEY` is preferred;
  the `constants.TRAIN_OP_KEY` collection is the fallback.
  """
  from_signature = _get_op_from_signature_def(
      meta_graph_def, constants.TRAIN_OP_SIGNATURE_KEY, import_scope)
  if from_signature is not None:
    return from_signature
  return _get_op_from_collection(meta_graph_def, constants.TRAIN_OP_KEY)
220
221
@tf_export(v1=[
    "saved_model.contains_saved_model",
    "saved_model.maybe_saved_model_directory",
    "saved_model.loader.maybe_saved_model_directory"
])
@deprecation.deprecated_endpoints(
    "saved_model.loader.maybe_saved_model_directory")
def maybe_saved_model_directory(export_dir):
  """Checks whether the provided export directory could contain a SavedModel.

  Only the presence of a SavedModel file is checked; no data is loaded. A
  `False` result means the directory definitely holds no SavedModel. A `True`
  result means it may hold one, with no guarantee that it can be loaded.

  Args:
    export_dir: Absolute string path to possible export location. For example,
                '/my/foo/model'.

  Returns:
    True if the export directory contains SavedModel files, False otherwise.
  """
  # The text form is probed first, then the binary form.
  candidate_names = (constants.SAVED_MODEL_FILENAME_PBTXT,
                     constants.SAVED_MODEL_FILENAME_PB)
  return any(
      file_io.file_exists(file_io.join(export_dir, name))
      for name in candidate_names)
247
248
@tf_export("saved_model.contains_saved_model", v1=[])
def contains_saved_model(export_dir):
  """Checks whether the provided export directory could contain a SavedModel.

  Only the presence of a SavedModel file is checked; no data is loaded. A
  `False` result means the directory definitely holds no SavedModel. A `True`
  result means it may hold one, with no guarantee that it can be loaded.

  Args:
    export_dir: Absolute path to possible export location. For example,
                '/my/foo/model'.

  Returns:
    True if the export directory contains SavedModel files, False otherwise.
  """
  # os.PathLike inputs are converted to their filesystem representation;
  # everything else is forwarded unchanged.
  normalized_dir = (
      os.fspath(export_dir) if isinstance(export_dir, os.PathLike)
      else export_dir)
  return maybe_saved_model_directory(normalized_dir)
268
269
@tf_export(v1=["saved_model.load", "saved_model.loader.load"])
@deprecation.deprecated(
    None,
    "Use `tf.saved_model.load` instead.")
def load(sess, tags, export_dir, import_scope=None, **saver_kwargs):
  """Loads the model from a SavedModel as specified by tags.

  Args:
    sess: The TensorFlow session to restore the variables.
    tags: Set of string tags to identify the required MetaGraphDef. These should
        correspond to the tags used when saving the variables using the
        SavedModel `save()` API.
    export_dir: Directory in which the SavedModel protocol buffer and variables
        to be loaded are located.
    import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
    **saver_kwargs: Optional keyword arguments passed through to Saver.

  Returns:
    The `MetaGraphDef` protocol buffer loaded in the provided session. This
    can be used to further extract signature-defs, collection-defs, etc.

  Raises:
    RuntimeError: MetaGraphDef associated with the tags cannot be found.

  @compatibility(TF2)

  `tf.compat.v1.saved_model.load` or `tf.compat.v1.saved_model.loader.load` is
  not compatible with eager execution. Please use `tf.saved_model.load` instead
  to load your model. You can refer to the [SavedModel guide]
  (https://www.tensorflow.org/guide/saved_model) for more information as well as
  "Importing SavedModels from TensorFlow 1.x" in the [`tf.saved_model.load`]
  (https://www.tensorflow.org/api_docs/python/tf/saved_model/load) docstring.

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `sess`                | Not supported   | -                          |
  | `tags`                | `tags`          | -                          |
  | `export_dir`          | `export_dir`    | -                          |
  | `import_scope`        | Not supported   | Name scopes are not needed.
  :                       :                 : By default, variables are  :
  :                       :                 : associated with the loaded :
  :                       :                 : object and function names  :
  :                       :                 : are deduped.               :
  | `saver_kwargs`        | Not supported   | -                          |

  #### Before & After Usage Example

  Before:

  ```
  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    tf.compat.v1.saved_model.loader.load(sess, ["foo-tag"], export_dir)
  ```

  After:

  ```
  model = tf.saved_model.load(export_dir, tags=["foo-tag"])
  ```
  @end_compatibility
  """
  # All of the work is delegated to a one-off SavedModelLoader.
  return SavedModelLoader(export_dir).load(
      sess, tags, import_scope, **saver_kwargs)
338
339
class SavedModelLoader(object):
  """Load graphs and restore variable values from a `SavedModel`.

  The `SavedModel` proto is parsed once, in `__init__`, and that parsed proto
  is reused by every method of the instance.
  """

  def __init__(self, export_dir):
    """Creates a `SavedModelLoader`.

    Args:
      export_dir: Directory in which the SavedModel protocol buffer and
        variables to be loaded are located.

    Raises:
      IOError: Propagated from `parse_saved_model` if the SavedModel file does
        not exist or cannot be parsed.
    """
    self._export_dir = export_dir
    self._variables_path = saved_model_utils.get_variables_path(export_dir)
    self._saved_model = parse_saved_model(export_dir)

  @property
  def export_dir(self):
    """Directory containing the SavedModel."""
    return self._export_dir

  @property
  def variables_path(self):
    """Path to variable checkpoint files."""
    return self._variables_path

  @property
  def saved_model(self):
    """SavedModel object parsed from the export directory."""
    return self._saved_model

  def get_meta_graph_def_from_tags(self, tags):
    """Return MetaGraphDef with the exact specified tags.

    Tags are compared as sets, so order and duplicates do not matter.

    Args:
      tags: A list or set of string tags that identify the MetaGraphDef.

    Returns:
      MetaGraphDef with the same tags. This is the proto held by
      `self.saved_model`, not a copy.

    Raises:
      RuntimeError: if no metagraphs were found with the associated tags.
    """
    # Build the wanted set once; this also makes one-shot iterables of tags
    # work, since they are consumed only here.
    wanted_tags = set(tags)
    available_tags = []
    for meta_graph_def in self._saved_model.meta_graphs:
      graph_tags = set(meta_graph_def.meta_info_def.tags)
      available_tags.append(graph_tags)
      if graph_tags == wanted_tags:
        return meta_graph_def

    raise RuntimeError(
        f"MetaGraphDef associated with tags {str(tags).strip('[]')} "
        "could not be found in SavedModel, with available tags "
        f"'{available_tags}'. To inspect available tag-sets in"
        " the SavedModel, please use the SavedModel CLI: `saved_model_cli`.")

  def load_graph(self, graph, tags, import_scope=None, **saver_kwargs):
    """Load ops and nodes from SavedModel MetaGraph into graph.

    Args:
      graph: tf.Graph object.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
      **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.

    Returns:
      A tuple of
        * Saver defined by the MetaGraph, which can be used to restore the
          variable values.
        * List of `Operation`/`Tensor` objects returned from
          `tf.import_graph_def` (may be `None`).
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    if sys.byteorder == "big":
      # NOTE(review): meta_graph_def is the proto cached in self._saved_model
      # and the swap's return value is unused, so it presumably swaps in
      # place. If so, calling load_graph twice on a big-endian host would
      # swap the cached content back -- confirm.
      saved_model_utils.swap_function_tensor_content(meta_graph_def, "little",
                                                     "big")
    with graph.as_default():
      return tf_saver._import_meta_graph_with_return_elements(  # pylint: disable=protected-access
          meta_graph_def, import_scope=import_scope, **saver_kwargs)

  def restore_variables(self, sess, saver, import_scope=None):
    """Restore SavedModel variable values into the session.

    Args:
      sess: tf.compat.v1.Session to restore variable values.
      saver: a tf.compat.v1.train.Saver object. Can be None if there are no
        variables in graph. This may be the saver returned by the load_graph()
        function, or a default `tf.compat.v1.train.Saver()`.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.

    Raises:
      ValueError: if no saver was passed to the saver argument, and there are
        variables in the graph.
    """
    with sess.graph.as_default():
      if (saver is None and
          not variables._all_saveable_objects(scope=import_scope)):  # pylint: disable=protected-access
        tf_logging.info("The specified SavedModel has no variables; no "
                        "checkpoints were restored.")
      elif isinstance(saver, tf_saver.Saver):
        saver.restore(sess, self._variables_path)
      else:
        raise ValueError(
            "No tf.train.Saver object was passed to the function "
            "`SavedModelLoader.restore_variables`. Since there are variables in"
            " the graph, a saver is required.")

  def run_init_ops(self, sess, tags, import_scope=None):
    """Run initialization ops defined in the `MetaGraphDef`.

    Asset file paths are fed into their asset tensors while the init op runs.

    Args:
      sess: tf.compat.v1.Session to restore variable values.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    with sess.graph.as_default():
      # Get asset tensors, if any.
      asset_tensors_dictionary = get_asset_tensors(
          self._export_dir, meta_graph_def, import_scope=import_scope)

      init_op = get_init_op(meta_graph_def, import_scope)
      if init_op is not None:
        sess.run(fetches=[init_op], feed_dict=asset_tensors_dictionary)

  def load(self, sess, tags, import_scope=None, **saver_kwargs):
    """Load the MetaGraphDef graph and restore variable values into the session.

    Args:
      sess: tf.compat.v1.Session to restore variable values.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
      **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.

    Returns:
      `MetagraphDef` proto of the graph that was loaded.
    """
    metrics.IncrementReadApi(_LOADER_LABEL)

    with sess.graph.as_default():
      saver, _ = self.load_graph(sess.graph, tags, import_scope,
                                 **saver_kwargs)
      self.restore_variables(sess, saver, import_scope)
      self.run_init_ops(sess, tags, import_scope)
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)

    # Reuse the proto parsed in __init__ rather than re-reading and
    # re-parsing the SavedModel file from disk just for this metric.
    meta_graphs = self._saved_model.meta_graphs
    if len(meta_graphs) == 1 and meta_graphs[0].HasField("object_graph_def"):
      metrics.IncrementRead(write_version="2")
    else:
      metrics.IncrementRead(write_version="1")

    return meta_graph_def
508