# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Options for saving SavedModels."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import enum
22import six
23
24from tensorflow.python.util import compat
25from tensorflow.python.util.tf_export import tf_export
26
27
28@tf_export("saved_model.experimental.VariablePolicy")
29class VariablePolicy(enum.Enum):
30  """Enum defining options for variable handling when saving.
31
32  NONE
33    No policy applied: Distributed variables are saved as one variable, with no
34    device attached.
35
36  SAVE_VARIABLE_DEVICES
37    When saving variables, also save their device assignment.
38    This is useful if one wants to hardcode devices in saved models, but it also
39    makes them non-portable if soft device placement is disabled (more details
40    in `tf.config.set_soft_device_placement`). This is currently not
41    fully supported by `saved_model.load`, and is mainly intended to be used
42    when one will be reading the saved model at a lower API level. In the
43    example below, the graph saved by the call to `saved_model.save` will have
44    the variable devices correctly specified:
45    ```python
46    exported = tf.train.Checkpoint()
47    with tf.device('/GPU:0'):
48      exported.x_gpu = tf.Variable(1.0)
49    with tf.device('/CPU:0'):
50      exported.x_cpu = tf.Variable(1.0)
51    tf.saved_model.save(exported, export_dir,
52        options = tf.saved_model.SaveOptions(
53            experimental_variable_policy=
54              tf.saved_model.experimental.VariablePolicy.SAVE_VARIABLE_DEVICES))
55    ```
    Distributed variables are still saved as one variable under this policy.

  EXPAND_DISTRIBUTED_VARIABLES
    Distributed variables will be saved with information about their components,
    allowing for their restoration on load. Also, the saved graph will contain
    references to those variables. This is useful when one wants to use the
    model for training in environments where the original distribution strategy
    is not available.
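
    A minimal sketch of enabling this policy (illustrative; it assumes a
    `tf.distribute.MirroredStrategy` and an existing `export_dir` path):
    ```python
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
      # Variables created under the strategy scope are distributed variables.
      exported = tf.train.Checkpoint()
      exported.v = tf.Variable(1.0)
    tf.saved_model.save(exported, export_dir,
        options=tf.saved_model.SaveOptions(
            experimental_variable_policy=
              tf.saved_model.experimental.VariablePolicy
                .EXPAND_DISTRIBUTED_VARIABLES))
    ```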
64  """
65
66  NONE = None
67
68  SAVE_VARIABLE_DEVICES = "save_variable_devices"
69
70  EXPAND_DISTRIBUTED_VARIABLES = "expand_distributed_variables"
71
72  def _save_variable_devices(self):
73    """Checks whether variable devices should be saved."""
74    return self != VariablePolicy.NONE
75
76  def _expand_distributed_variables(self):
77    """Checks whether distributed variables should be expanded."""
78    return self == VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES
79
80  @staticmethod
81  def from_obj(obj):
82    """Tries to convert `obj` to a VariablePolicy instance."""
    if obj is None:
      return VariablePolicy.NONE
    if isinstance(obj, VariablePolicy):
      return obj
    key = str(obj).lower()
    for policy in VariablePolicy:
      if key == policy.value:
        return policy
    raise ValueError(f"Received invalid VariablePolicy value: {obj}.")


@tf_export("saved_model.SaveOptions")
class SaveOptions(object):
96  """Options for saving to SavedModel.
97
98  This function may be used in the `options` argument in functions that
99  save a SavedModel (`tf.saved_model.save`, `tf.keras.models.save_model`).
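
  A minimal usage sketch (the model and path are illustrative):

  >>> options = tf.saved_model.SaveOptions(save_debug_info=True)
  >>> # tf.saved_model.save(model, "/tmp/adder", options=options)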
100  """
101
102  # Define object attributes in __slots__ for improved memory and performance.
103  __slots__ = ("namespace_whitelist", "save_debug_info", "function_aliases",
104               "experimental_io_device", "experimental_variable_policy",
105               "experimental_custom_gradients")
106
107  def __init__(self,
108               namespace_whitelist=None,
109               save_debug_info=False,
110               function_aliases=None,
111               experimental_io_device=None,
112               experimental_variable_policy=None,
113               experimental_custom_gradients=True):
114    """Creates an object that stores options for SavedModel saving.
115
116    Args:
117      namespace_whitelist: List of strings containing op namespaces to whitelist
118        when saving a model. Saving an object that uses namespaced ops must
119        explicitly add all namespaces to the whitelist. The namespaced ops must
120        be registered into the framework when loading the SavedModel. If no
121        whitelist is provided, all namespaced ops will be allowed.
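        E.g. (with a hypothetical namespace "MyNamespace"):
        `tf.saved_model.SaveOptions(namespace_whitelist=["MyNamespace"])`.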
      save_debug_info: Boolean indicating whether debug information is saved. If
        True, then a debug/saved_model_debug_info.pb file will be written with
        the contents of a GraphDebugInfo binary protocol buffer containing stack
        trace information for all ops and functions that are saved.
      function_aliases: Python dict. Mapping from string to object returned by
        @tf.function. A single tf.function can generate many ConcreteFunctions.
        If a downstream tool wants to refer to all concrete functions generated
        by a single tf.function, you can use the `function_aliases` argument to
        store a map from the alias name to all concrete function names.
        E.g.

        >>> class Adder(tf.Module):
        ...   @tf.function
        ...   def double(self, x):
        ...     return x + x

        >>> model = Adder()
        >>> model.double.get_concrete_function(
        ...   tf.TensorSpec(shape=[], dtype=tf.float32, name="float_input"))
        >>> model.double.get_concrete_function(
        ...   tf.TensorSpec(shape=[], dtype=tf.string, name="string_input"))

        >>> options = tf.saved_model.SaveOptions(
        ...   function_aliases={'double': model.double})
        >>> tf.saved_model.save(model, '/tmp/adder', options=options)

      experimental_io_device: string. Applies in a distributed setting.
        TensorFlow device to use to access the filesystem. If `None` (default),
        then for each variable the filesystem is accessed from the CPU:0 device
        of the host where that variable is assigned. If specified, the
        filesystem is instead accessed from that device for all variables.

        This is useful, for example, when saving to a local directory such as
        "/tmp" while running in a distributed setting: pass a device on the
        host where the "/tmp" directory is accessible.
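
        For example (the device string is illustrative):
        `tf.saved_model.SaveOptions(experimental_io_device="/job:localhost")`
        routes all filesystem access through that device.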
      experimental_variable_policy: The policy to apply to variables when
        saving. This is either a `saved_model.experimental.VariablePolicy` enum
        instance or one of its value strings (case is not important). See that
        enum documentation for details. A value of `None` corresponds to the
        default policy.
      experimental_custom_gradients: Boolean. When True, will save traced
        gradient functions for the functions decorated by `tf.custom_gradient`.
        Defaults to `True`.
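
        An illustrative function this option applies to (a sketch, not part of
        this module):

        >>> @tf.custom_gradient
        ... def my_square(x):
        ...   def grad(dy):
        ...     return 2.0 * x * dy
        ...   return x * x, grad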
165    """
166    self.namespace_whitelist = _validate_namespace_whitelist(
167        namespace_whitelist)
168    self.save_debug_info = save_debug_info
169    self.function_aliases = function_aliases if function_aliases else dict()
170    self.experimental_custom_gradients = experimental_custom_gradients
171    self.experimental_io_device = experimental_io_device
172    self.experimental_variable_policy = (
173        VariablePolicy.from_obj(experimental_variable_policy))
174
175
176def _validate_namespace_whitelist(namespace_whitelist):
177  """Validates namespace whitelist argument."""
  if namespace_whitelist is None:
    return None
  if not isinstance(namespace_whitelist, list):
    raise TypeError("`namespace_whitelist` must be a list of strings. Got: "
                    f"{namespace_whitelist} with type "
                    f"{type(namespace_whitelist)}.")

  processed = []
  for namespace in namespace_whitelist:
    if not isinstance(namespace, six.string_types):
      raise ValueError("Whitelisted namespace must be a string. Got: "
                       f"{namespace} of type {type(namespace)}.")
    processed.append(compat.as_str(namespace))
  return processed