• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""SavedModel builder implementation."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22import os
23
24from google.protobuf.any_pb2 import Any
25
26from tensorflow.core.framework import types_pb2
27from tensorflow.core.protobuf import meta_graph_pb2
28from tensorflow.core.protobuf import saved_model_pb2
29from tensorflow.core.protobuf import saver_pb2
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import ops
32from tensorflow.python.lib.io import file_io
33from tensorflow.python.ops import variables
34from tensorflow.python.platform import tf_logging
35from tensorflow.python.saved_model import constants
36from tensorflow.python.saved_model import signature_def_utils
37from tensorflow.python.saved_model import utils_impl as saved_model_utils
38from tensorflow.python.saved_model.pywrap_saved_model import metrics
39from tensorflow.python.training import saver as tf_saver
40from tensorflow.python.util import compat
41from tensorflow.python.util.deprecation import deprecated_args
42from tensorflow.python.util.tf_export import tf_export
43
# API label for SavedModel metrics: passed to `metrics.IncrementWriteApi` by
# `_SavedModelBuilder.save` to attribute writes to this v1 builder API.
_SAVE_BUILDER_LABEL = "save_v1_builder"
46
47
# Base class for the SavedModelBuilder that is only used by Tensorflow
# internally. Please use tf.compat.v1.saved_model.SavedModelBuilder instead.
@tf_export("__internal__.saved_model.SavedModelBuilder", v1=[])
class _SavedModelBuilder(object):
  """Builds the `SavedModel` protocol buffer and saves variables and assets.

  The `SavedModelBuilder` class provides the functionality to build a
  `SavedModel` protocol buffer. Specifically, this allows multiple meta
  graphs to be saved as part of a single language-neutral `SavedModel`,
  while sharing variables and assets.

  To build a SavedModel, the first meta graph must be saved with variables.
  Subsequent meta graphs will simply be saved with their graph definitions. If
  assets need to be saved and written or copied to disk, they can be provided
  when the meta graph def is added. If multiple meta graph defs are associated
  with an asset of the same name, only the first version is retained.

  Each meta graph added to the SavedModel must be annotated with tags. The tags
  provide a means to identify the specific meta graph to load and restore, along
  with the shared set of variables and assets.

  Typical usage for the `SavedModelBuilder`:

  ```python
  ...
  builder = tf.compat.v1.saved_model.Builder(export_dir)

  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    ...
    builder.add_meta_graph_and_variables(sess,
                                    ["foo-tag"],
                                    signature_def_map=foo_signatures,
                                    assets_list=foo_assets)
  ...

  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    ...
    builder.add_meta_graph(["bar-tag", "baz-tag"])
  ...

  builder.save()
  ```

  Note: This class will only be available through the v1 compatibility
  library as tf.compat.v1.saved_model.builder.SavedModelBuilder or
  tf.compat.v1.saved_model.Builder. Tensorflow 2.0 will introduce a new
  object-based method of creating SavedModels.
  """

  def __init__(self, export_dir):
    """Creates a builder that writes the SavedModel into `export_dir`.

    Args:
      export_dir: Directory the SavedModel is written to. It is created when it
        does not exist yet.

    Raises:
      AssertionError: If `export_dir` already exists and is not empty.
    """
    self._saved_model = saved_model_pb2.SavedModel()
    self._saved_model.saved_model_schema_version = (
        constants.SAVED_MODEL_SCHEMA_VERSION)

    self._export_dir = export_dir
    if file_io.file_exists(export_dir):
      if file_io.list_directory(export_dir):
        raise AssertionError(
            "Export directory already exists, and isn't empty. Please choose "
            "a different export directory, or delete all the contents of the "
            "specified directory: %s" % export_dir)
    else:
      file_io.recursive_create_dir(self._export_dir)

    # Boolean to track whether variables and assets corresponding to the
    # SavedModel have been saved. Specifically, the first meta graph to be added
    # MUST use the add_meta_graph_and_variables() API. Subsequent add operations
    # on the SavedModel MUST use the add_meta_graph() API which does not save
    # weights.
    self._has_saved_variables = False

  def _save_and_write_assets(self, meta_graph_def, assets_list=None):
    """Saves asset to the meta graph and writes asset files to disk.

    Args:
      meta_graph_def: The meta graph def to which the assets will be added.
      assets_list: The list where the asset paths are setup.
    """
    # Creates a function that adds assets into the meta graph def.
    write_fn = functools.partial(_add_asset_to_metagraph, meta_graph_def)
    asset_filename_map = _maybe_save_assets(write_fn, assets_list)

    # Return if there are no assets to write.
    if not asset_filename_map:
      tf_logging.info("No assets to write.")
      return

    # Copy assets from source path to destination path.
    copy_assets_to_destination_dir(asset_filename_map, self._export_dir)

  def _tag_and_add_meta_graph(self, meta_graph_def, tags, signature_def_map):
    """Tags the meta graph def and adds it to the SavedModel.

    Tags the meta graph def with the supplied tags, adds signature defs to it if
    provided and appends a copy of the meta graph def to the SavedModel proto.

    Args:
      meta_graph_def: The meta graph def to add to the SavedModel.
      tags: The set of tags to annotate the meta graph def with.
      signature_def_map: The map of signature defs to be added to the meta graph
          def, or None.
    """
    for tag in tags:
      meta_graph_def.meta_info_def.tags.append(tag)

    if signature_def_map is not None:
      for key in signature_def_map:
        meta_graph_def.signature_def[key].CopyFrom(signature_def_map[key])

    proto_meta_graph_def = self._saved_model.meta_graphs.add()
    proto_meta_graph_def.CopyFrom(meta_graph_def)

  def _validate_tensor_info(self, tensor_info):
    """Validates the `TensorInfo` proto.

    Checks if the `encoding` (`name` or `coo_sparse` or `type_spec`) and
    `dtype` fields exist and are non-empty. For the `composite_tensor` encoding
    the component `TensorInfo` protos are validated recursively instead of the
    top-level `dtype`.

    Args:
      tensor_info: `TensorInfo` protocol buffer to validate.

    Raises:
      AssertionError: If the `encoding` or `dtype` fields of the supplied
          `TensorInfo` proto are not populated.
    """
    if tensor_info is None:
      raise AssertionError(
          "All TensorInfo protos used in the SignatureDefs must have the name "
          "and dtype fields set.")
    if tensor_info.WhichOneof("encoding") is None:
      # TODO(soergel) validate each of the fields of coo_sparse
      raise AssertionError(
          "All TensorInfo protos used in the SignatureDefs must have one of "
          "the 'encoding' fields (e.g., name or coo_sparse) set: %s"
          % tensor_info)
    if tensor_info.WhichOneof("encoding") == "composite_tensor":
      for component in tensor_info.composite_tensor.components:
        self._validate_tensor_info(component)
    elif tensor_info.dtype == types_pb2.DT_INVALID:
      raise AssertionError(
          "All TensorInfo protos used in the SignatureDefs must have the dtype "
          "field set: %s" % tensor_info)

  def _validate_signature_def_map(self, signature_def_map):
    """Validates the `SignatureDef` entries in the signature def map.

    Validation of entries in the signature def map includes ensuring that the
    `name` and `dtype` fields of the TensorInfo protos of the `inputs` and
    `outputs` of each `SignatureDef` are populated. Also ensures that reserved
    SignatureDef keys for the initialization and train ops are not used.

    Args:
      signature_def_map: The map of signature defs to be validated.

    Raises:
      AssertionError: If a TensorInfo is not valid.
      KeyError: If a reserved signature key is used in the map.
    """
    for signature_def_key in signature_def_map:
      signature_def = signature_def_map[signature_def_key]
      inputs = signature_def.inputs
      outputs = signature_def.outputs
      for inputs_key in inputs:
        self._validate_tensor_info(inputs[inputs_key])
      for outputs_key in outputs:
        self._validate_tensor_info(outputs[outputs_key])
    if constants.INIT_OP_SIGNATURE_KEY in signature_def_map:
      raise KeyError(
          "SignatureDef map key \"{}\" is reserved for initialization. Please "
          "use a different key.".format(constants.INIT_OP_SIGNATURE_KEY))
    if constants.TRAIN_OP_SIGNATURE_KEY in signature_def_map:
      raise KeyError(
          "SignatureDef map key \"{}\" is reserved for the train op. Please "
          "use a different key.".format(constants.TRAIN_OP_SIGNATURE_KEY))

  def _maybe_create_saver(self, saver=None):
    """Returns `saver`, or a new sharded V2 Saver when none is given.

    The created Saver covers all saveable objects in the current scope and is
    allowed to be empty.
    """
    if not saver:
      # Initialize a saver to generate a sharded output for all saveables in the
      # current scope.
      saver = tf_saver.Saver(
          variables._all_saveable_objects(),  # pylint: disable=protected-access
          sharded=True,
          write_version=saver_pb2.SaverDef.V2,
          allow_empty=True)
    return saver

  def _prepare_signature_def_map(self, signature_def_map, init_op, train_op):
    """Returns a validated copy of `signature_def_map` with op signatures added.

    The caller's mapping is copied rather than modified. SignatureDefs for
    `init_op` and `train_op` are inserted under reserved keys. Writing them
    into the caller's dict would make a later call that reuses the same dict
    fail the reserved-key check in `_validate_signature_def_map`.

    Args:
      signature_def_map: Mapping of signature keys to SignatureDefs, or None.
      init_op: Op (or None) to store under the init op signature key.
      train_op: Op (or None) to store under the train op signature key.

    Returns:
      A new dict holding the user signatures plus any op signatures.

    Raises:
      AssertionError: If a TensorInfo in the map is not valid.
      KeyError: If the map uses a reserved signature key.
    """
    signature_def_map = dict(signature_def_map or {})
    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated.
    self._validate_signature_def_map(signature_def_map)

    # Create SignatureDefs pointing to the graph initialization op and the
    # train op, which will be added to the MetaGraphDef.
    _add_op_to_signature_def_map(signature_def_map, init_op,
                                 constants.INIT_OP_SIGNATURE_KEY)
    _add_op_to_signature_def_map(signature_def_map, train_op,
                                 constants.TRAIN_OP_SIGNATURE_KEY)
    return signature_def_map

  def add_meta_graph(self,
                     tags,
                     signature_def_map=None,
                     assets_list=None,
                     clear_devices=False,
                     init_op=None,
                     train_op=None,
                     saver=None):
    """Adds the current meta graph to the SavedModel.

    Creates a Saver in the current scope and uses the Saver to export the meta
    graph def. Invoking this API requires the `add_meta_graph_and_variables()`
    API to have been invoked before.

    Args:
      tags: The set of tags to annotate the meta graph def with.
      signature_def_map: The map of signature defs to be added to the meta graph
          def. The mapping itself is not modified.
      assets_list: Assets to be saved with SavedModel. Note
          that this list should be a subset of the assets saved as part of
          the first meta graph in the SavedModel.
      clear_devices: Set to true if the device info on the default graph should
          be cleared.
      init_op: Op or group of ops to execute when the graph is loaded. Note
          that when the init_op is specified it is run after the restore op at
          load-time.
      train_op: Op or group of ops that trains the model when run. This will
        not be run automatically when the graph is loaded, instead saved in
        a SignatureDef accessible through the exported MetaGraph.
      saver: An instance of tf.compat.v1.train.Saver that will be used to export
        the metagraph. If None, a sharded Saver that restores all variables will
        be used.

    Raises:
      AssertionError: If the variables for the SavedModel have not been saved
          yet, or if a TensorInfo in `signature_def_map` is not valid.
      KeyError: If `signature_def_map` uses a reserved signature key.
    """
    if not self._has_saved_variables:
      raise AssertionError(
          "Graph state including variables and assets has not been saved yet. "
          "Please invoke `add_meta_graph_and_variables()` first.")

    signature_def_map = self._prepare_signature_def_map(
        signature_def_map, init_op, train_op)

    saver = self._maybe_create_saver(saver)

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=True)

    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(meta_graph_def, assets_list)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

  def add_meta_graph_and_variables(self,
                                   sess,
                                   tags,
                                   signature_def_map=None,
                                   assets_list=None,
                                   clear_devices=False,
                                   init_op=None,
                                   train_op=None,
                                   strip_default_attrs=False,
                                   saver=None):
    # pylint: disable=line-too-long
    """Adds the current meta graph to the SavedModel and saves variables.

    Creates a Saver to save the variables from the provided session. Exports the
    corresponding meta graph def. This function assumes that the variables to be
    saved have been initialized. For a given `SavedModelBuilder`, this API must
    be called exactly once and for the first meta graph to save. For subsequent
    meta graph defs to be added, the `add_meta_graph()` API must be used.

    Args:
      sess: The TensorFlow session from which to save the meta graph and
        variables.
      tags: The set of tags with which to save the meta graph.
      signature_def_map: The map of signature def map to add to the meta graph
        def. The mapping itself is not modified.
      assets_list: Assets to be saved with SavedModel.
      clear_devices: Set to true if the device info on the default graph should
          be cleared.
      init_op: Op or group of ops to execute when the graph is loaded. Note
          that when the init_op is specified it is run after the restore op at
          load-time.
      train_op: Op or group of ops that trains the model when run. This will
        not be run automatically when the graph is loaded, instead saved in
        a SignatureDef accessible through the exported MetaGraph.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      saver: An instance of tf.compat.v1.train.Saver that will be used to export the
        metagraph and save variables. If None, a sharded Saver that restores
        all variables will be used.

    Raises:
      AssertionError: If variables have already been saved by this builder, or
        if a TensorInfo in `signature_def_map` is not valid.
      KeyError: If `signature_def_map` uses a reserved signature key.
    """
    # pylint: enable=line-too-long
    if self._has_saved_variables:
      raise AssertionError("Graph state including variables and assets has "
                           "already been saved. Please invoke "
                           "`add_meta_graph()` instead.")

    signature_def_map = self._prepare_signature_def_map(
        signature_def_map, init_op, train_op)

    saved_model_utils.get_or_create_variables_dir(self._export_dir)
    variables_path = saved_model_utils.get_variables_path(self._export_dir)

    saver = self._maybe_create_saver(saver)

    # Save the variables. Also, disable writing the checkpoint state proto. The
    # file is not used during SavedModel loading. In addition, since a
    # SavedModel can be copied or moved, this avoids the checkpoint state to
    # become outdated.
    saver.save(sess, variables_path, write_meta_graph=False, write_state=False)

    # Export the meta graph def.

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(meta_graph_def, assets_list)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

    # Mark this instance of SavedModel as having saved variables, such that
    # subsequent attempts to save variables will fail.
    self._has_saved_variables = True

  def save(self, as_text=False):
    """Writes a `SavedModel` protocol buffer to disk.

    The function writes the SavedModel protocol buffer to the export directory
    in a serialized format.

    Args:
      as_text: Writes the SavedModel protocol buffer in text format to
        disk. Protocol buffers in text format are useful for debugging, but
        parsing fails when it encounters an unknown field and so is not forward
        compatible. This means changes to TensorFlow may prevent deployment of
        new text format SavedModels to existing serving binaries. Do not deploy
        `as_text` SavedModels to production.

    Returns:
      The path to which the SavedModel protocol buffer was written.
    """
    metrics.IncrementWriteApi(_SAVE_BUILDER_LABEL)
    if not file_io.file_exists(self._export_dir):
      file_io.recursive_create_dir(self._export_dir)

    if as_text:
      path = os.path.join(
          compat.as_bytes(self._export_dir),
          compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
      file_io.write_string_to_file(path, str(self._saved_model))
    else:
      path = os.path.join(
          compat.as_bytes(self._export_dir),
          compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
      file_io.write_string_to_file(
          path, self._saved_model.SerializeToString(deterministic=True))
    tf_logging.info("SavedModel written to: %s", compat.as_text(path))
    metrics.IncrementWrite(write_version="1")
    return path
435
436
@tf_export(v1=["saved_model.Builder", "saved_model.builder.SavedModelBuilder"])  # pylint: disable=missing-docstring
class SavedModelBuilder(_SavedModelBuilder):
  __doc__ = _SavedModelBuilder.__doc__.replace("assets_list",
                                               "assets_collection")

  def __init__(self, export_dir):
    super(SavedModelBuilder, self).__init__(export_dir=export_dir)

  def _add_collections(self, assets_collection, main_op, train_op):
    """Adds asset and op collections to be saved.

    Args:
      assets_collection: Asset path tensors to record in the graph and copy
        into the export directory, or None.
      main_op: Operation to register under the main op collection, or None.
      train_op: Tensor or Operation to register under the train op collection,
        or None.
    """
    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(assets_collection)

    self._maybe_add_main_op(main_op)

    self._add_train_op(train_op)

  def _save_and_write_assets(self, assets_collection_to_add=None):
    """Saves asset to the meta graph and writes asset files to disk.

    Args:
      assets_collection_to_add: The collection where the asset paths are setup.
    """
    # Add assets to the collection with key `saved_model.ASSETS_KEY`, in the
    # graph.
    asset_filename_map = _maybe_save_assets(_add_asset_to_collection,
                                            assets_collection_to_add)

    # Return if there are no assets to write.
    if not asset_filename_map:
      tf_logging.info("No assets to write.")
      return

    # Copy assets from source path to destination path.
    copy_assets_to_destination_dir(asset_filename_map, self._export_dir)

  def _maybe_add_main_op(self, main_op):
    """Adds main op to the SavedModel.

    Args:
      main_op: Main op to run as part of graph initialization. If None, no main
        op will be added to the graph.

    Raises:
      TypeError: If the main op is provided but is not of type `Operation`.
      ValueError: if the Graph already contains an init op.
    """
    if main_op is None:
      return

    if not isinstance(main_op, ops.Operation):
      raise TypeError("main_op needs to be an Operation: %r" % main_op)

    # Validate that no other init ops have been added to this graph already.
    # We check main_op and legacy_init_op for thoroughness and explicitness.
    for init_op_key in (constants.MAIN_OP_KEY, constants.LEGACY_INIT_OP_KEY):
      if ops.get_collection(init_op_key):
        raise ValueError(
            "Graph already contains one or more main ops under the "
            "collection {}.".format(init_op_key))

    ops.add_to_collection(constants.MAIN_OP_KEY, main_op)

  def _add_train_op(self, train_op):
    """Add train op to the SavedModel.

    Note that this functionality is in development, and liable to be
    moved elsewhere.

    Args:
      train_op: Op or group of ops that are used for training. These are stored
        as a collection with key TRAIN_OP_KEY, but not executed.

    Raises:
      TypeError: If the train op is neither a `Tensor` nor an `Operation`.
    """
    if train_op is not None:
      if (not isinstance(train_op, ops.Tensor) and
          not isinstance(train_op, ops.Operation)):
        raise TypeError("train_op needs to be a Tensor or Op: %r" % train_op)
      ops.add_to_collection(constants.TRAIN_OP_KEY, train_op)

  @deprecated_args(None,
                   "Pass your op to the equivalent parameter main_op instead.",
                   "legacy_init_op")
  def add_meta_graph(self,
                     tags,
                     signature_def_map=None,
                     assets_collection=None,
                     legacy_init_op=None,
                     clear_devices=False,
                     main_op=None,
                     strip_default_attrs=False,
                     saver=None):
    """Adds the current meta graph to the SavedModel.

    Exports the meta graph def with `saver`, or with a newly created sharded
    Saver when `saver` is None. Invoking this API requires the
    `add_meta_graph_and_variables()` API to have been invoked before.

    Args:
      tags: The set of tags to annotate the meta graph def with.
      signature_def_map: The map of signature defs to be added to the meta graph
        def.
      assets_collection: Assets to be saved with SavedModel. Note that this
        collection should be a subset of the assets saved as part of the first
        meta graph in the SavedModel.
      legacy_init_op: Deprecated alias of `main_op`; only used when `main_op`
        is None.
      clear_devices: Set to true if the device info on the default graph should
        be cleared.
      main_op: `Operation` to run as part of graph initialization. It is added
        to the graph's main op collection.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs.
      saver: An instance of tf.compat.v1.train.Saver that will be used to export
        the metagraph. If None, a sharded Saver that restores all variables will
        be used.

    Raises:
      AssertionError: If the variables for the SavedModel have not been saved
        yet, or if a TensorInfo in `signature_def_map` is not valid.
      KeyError: If `signature_def_map` uses a reserved signature key.
      TypeError: If the main op is not an `Operation`.
      ValueError: If the graph already contains a main op or legacy init op.
    """
    if not self._has_saved_variables:
      raise AssertionError(
          "Graph state including variables and assets has not been saved yet. "
          "Please invoke `add_meta_graph_and_variables()` first.")

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated.
    signature_def_map = signature_def_map or {}
    self._validate_signature_def_map(signature_def_map)

    # legacy_init_op is deprecated, and going away in TF 2.0.
    # Re-mapping to main_op, as treatment is identical regardless.
    main_op = main_op if main_op is not None else legacy_init_op

    # Add assets and ops
    self._add_collections(assets_collection, main_op, None)

    saver = self._maybe_create_saver(saver)

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

  @deprecated_args(None,
                   "Pass your op to the equivalent parameter main_op instead.",
                   "legacy_init_op")
  def add_meta_graph_and_variables(self,
                                   sess,
                                   tags,
                                   signature_def_map=None,
                                   assets_collection=None,
                                   legacy_init_op=None,
                                   clear_devices=False,
                                   main_op=None,
                                   strip_default_attrs=False,
                                   saver=None):
    """Adds the current meta graph to the SavedModel and saves variables.

    Saves the variables from `sess` with `saver` (or a newly created sharded
    Saver when `saver` is None) and exports the corresponding meta graph def.
    This function assumes that the variables to be saved have been initialized.
    For a given `SavedModelBuilder`, this API must be called exactly once and
    for the first meta graph to save. For subsequent meta graph defs to be
    added, the `add_meta_graph()` API must be used.

    Args:
      sess: The TensorFlow session from which to save the meta graph and
        variables.
      tags: The set of tags with which to save the meta graph.
      signature_def_map: The map of signature defs to add to the meta graph
        def.
      assets_collection: Assets to be saved with SavedModel.
      legacy_init_op: Deprecated alias of `main_op`; only used when `main_op`
        is None.
      clear_devices: Set to true if the device info on the default graph should
        be cleared.
      main_op: `Operation` to run as part of graph initialization. It is added
        to the graph's main op collection.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs.
      saver: An instance of tf.compat.v1.train.Saver that will be used to export
        the metagraph and save variables. If None, a sharded Saver that restores
        all variables will be used.

    Raises:
      AssertionError: If variables have already been saved by this builder, or
        if a TensorInfo in `signature_def_map` is not valid.
      KeyError: If `signature_def_map` uses a reserved signature key.
      TypeError: If the main op is not an `Operation`.
      ValueError: If the graph already contains a main op or legacy init op.
    """
    if self._has_saved_variables:
      raise AssertionError("Graph state including variables and assets has "
                           "already been saved. Please invoke "
                           "`add_meta_graph()` instead.")

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated.
    signature_def_map = signature_def_map or {}
    self._validate_signature_def_map(signature_def_map)

    # legacy_init_op is deprecated, and going away in TF 2.0.
    # Re-mapping to main_op, as treatment is identical regardless. Compare
    # against None (as add_meta_graph does) instead of truth-testing the op.
    main_op = main_op if main_op is not None else legacy_init_op

    # Add assets and ops
    self._add_collections(assets_collection, main_op, None)

    saved_model_utils.get_or_create_variables_dir(self._export_dir)
    variables_path = saved_model_utils.get_variables_path(self._export_dir)

    saver = self._maybe_create_saver(saver)

    # Save the variables. Also, disable writing the checkpoint state proto. The
    # file is not used during SavedModel loading. In addition, since a
    # SavedModel can be copied or moved, this avoids the checkpoint state to
    # become outdated.
    saver.save(sess, variables_path, write_meta_graph=False, write_state=False)

    # Export the meta graph def.

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

    # Mark this instance of SavedModel as having saved variables, such that
    # subsequent attempts to save variables will fail.
    self._has_saved_variables = True
628
629
def _maybe_save_assets(write_fn, assets_to_add=None):
  """Registers asset tensors with the meta graph through `write_fn`.

  For every asset tensor, the source file path is read from the tensor, a
  SavedModel-local filename is chosen with `get_asset_filename_to_add`, and
  `write_fn(filename, tensor)` is invoked. This lets the caller record the asset
  either in a graph collection or in the meta graph's `asset_file_def` field.

  Args:
    write_fn: Callable taking `(asset_filename, asset_tensor)` that records the
      asset in the meta graph.
    assets_to_add: Iterable of asset path tensors, or None when there are no
      assets.

  Returns:
    A dict mapping each chosen asset filename to the full source path it is to
    be copied from. The dict is empty when `assets_to_add` is None.

  Raises:
    ValueError: If an asset tensor holds an empty file path.
    TypeError: Propagated from `_asset_path_from_tensor` when an asset is not a
      constant path tensor.
  """
  filename_to_source_path = {}

  if assets_to_add is None:
    tf_logging.info("No assets to save.")
    return filename_to_source_path

  for asset_tensor in assets_to_add:
    source_path = _asset_path_from_tensor(asset_tensor)
    if not source_path:
      raise ValueError("Invalid asset filepath tensor %s" % asset_tensor)

    target_filename = get_asset_filename_to_add(source_path,
                                                filename_to_source_path)

    # Record the asset even when its file duplicates one seen earlier: the
    # tensor itself must still be referenced from the meta graph.
    write_fn(target_filename, asset_tensor)

    # For duplicates this overwrites the previous entry. Both entries name
    # files judged identical, so the last one winning is harmless for copying.
    filename_to_source_path[target_filename] = source_path

  tf_logging.info("Assets added to graph.")
  return filename_to_source_path
674
675
def get_asset_filename_to_add(asset_filepath, asset_filename_map):
  """Returns the filename under which an asset should be stored.

  Assets arrive as full paths but are written into the SavedModel under their
  basenames, so two different files can compete for one name. The basename is
  reused when it is still free, when it already maps to this same path, or when
  the mapped file compares equal under `file_io.filecmp`. Otherwise an
  index-suffixed name from `_get_unique_asset_filename` is returned.

  Args:
    asset_filepath: the full path to the asset that is being saved
    asset_filename_map: a dict of filenames used for saving the asset in
      the SavedModel to full paths from which the filenames were derived.

  Returns:
    The plain basename when it can be shared, otherwise a uniquified filename.
  """
  asset_filename = os.path.basename(asset_filepath)

  if asset_filename in asset_filename_map:
    existing_filepath = asset_filename_map[asset_filename]
    # Same path, or a different path whose file compares equal: share the name.
    can_share_name = (existing_filepath == asset_filepath or
                      file_io.filecmp(asset_filepath, existing_filepath))
    if not can_share_name:
      return _get_unique_asset_filename(asset_filename, asset_filename_map)

  return asset_filename
713
714
715def _get_unique_asset_filename(asset_filename, asset_filename_map):
716  i = 1
717  unique_filename = asset_filename
718  while unique_filename in asset_filename_map:
719    unique_filename = compat.as_bytes("_").join(
720        [compat.as_bytes(asset_filename), compat.as_bytes(str(i))])
721    i += 1
722  return unique_filename
723
724
def _asset_path_from_tensor(path_tensor):
  """Extracts the file path held by a constant string tensor.

  Args:
    path_tensor: a `Tensor` produced by a `Const` op of dtype string.

  Returns:
    The single entry of the constant's `string_val`, i.e. the asset path.

  Raises:
    TypeError: if `path_tensor` is not a `Tensor`, is not produced by a
      `Const` op, is not of dtype string, or its constant value does not
      hold exactly one string. The last check counts `string_val` entries
      rather than inspecting the tensor's shape.
  """
  if not isinstance(path_tensor, ops.Tensor):
    raise TypeError("Asset path tensor must be a Tensor.")
  producing_op = path_tensor.op
  if producing_op.type != "Const":
    raise TypeError("Asset path tensor must be of type constant.")
  if path_tensor.dtype != dtypes.string:
    raise TypeError("Asset path tensor must be of dtype string.")
  string_values = producing_op.get_attr("value").string_val
  if len(string_values) != 1:
    raise TypeError("Asset path tensor must be a scalar.")
  (asset_path,) = string_values
  return asset_path
747
748
749def _add_asset_to_metagraph(meta_graph_def, asset_filename, asset_tensor):
750  """Builds an asset proto and adds it to the meta graph def.
751
752  Args:
753    meta_graph_def: The meta graph def to which the asset will be added.
754    asset_filename: The filename of the asset to be added.
755    asset_tensor: The asset tensor used to populate the tensor info of the asset
756      proto.
757  """
758  asset_proto = meta_graph_def.asset_file_def.add()
759  asset_proto.filename = asset_filename
760  asset_proto.tensor_info.name = asset_tensor.name
761
762
def copy_assets_to_destination_dir(asset_filename_map, destination_dir):
  """Copies each asset from its source path into the SavedModel assets dir.

  Args:
    asset_filename_map: dict mapping the filename to use inside the
      SavedModel's assets directory to the source path to copy from.
    destination_dir: SavedModel directory; its assets directory is obtained
      via `saved_model_utils.get_or_create_assets_dir`.
  """
  assets_dir = saved_model_utils.get_or_create_assets_dir(destination_dir)

  for target_name, source_path in asset_filename_map.items():
    target_path = os.path.join(
        compat.as_bytes(assets_dir), compat.as_bytes(target_name))

    # An asset with the same name may be defined by several graphs; skipping
    # destinations that already exist means it is only copied the first time.
    if file_io.file_exists(target_path):
      continue
    file_io.copy(source_path, target_path)

  tf_logging.info("Assets written to: %s", compat.as_text(assets_dir))
782
783
def _add_asset_to_collection(asset_filename, asset_tensor):
  """Packs an `AssetFileDef` into an `Any` and adds it to the assets collection.

  Args:
    asset_filename: name under which the asset is stored in the SavedModel.
    asset_tensor: tensor whose `name` is recorded in the asset proto's
      `tensor_info`.
  """
  asset_def = meta_graph_pb2.AssetFileDef(filename=asset_filename)
  asset_def.tensor_info.name = asset_tensor.name

  # The collection stores generic `Any` protos, so the typed message is packed.
  collection_entry = Any()
  collection_entry.Pack(asset_def)
  ops.add_to_collection(constants.ASSETS_KEY, collection_entry)
799
800
801def _add_op_to_signature_def_map(signature_def_map, op, key):
802  if op is not None:
803    signature_def_map[key] = signature_def_utils.op_signature_def(op, key)
804