1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Options for saving SavedModels.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import enum 22import six 23 24from tensorflow.python.util import compat 25from tensorflow.python.util.tf_export import tf_export 26 27 28@tf_export("saved_model.experimental.VariablePolicy") 29class VariablePolicy(enum.Enum): 30 """Enum defining options for variable handling when saving. 31 32 NONE 33 No policy applied: Distributed variables are saved as one variable, with no 34 device attached. 35 36 SAVE_VARIABLE_DEVICES 37 When saving variables, also save their device assignment. 38 This is useful if one wants to hardcode devices in saved models, but it also 39 makes them non-portable if soft device placement is disabled (more details 40 in `tf.config.set_soft_device_placement`). This is currently not 41 fully supported by `saved_model.load`, and is mainly intended to be used 42 when one will be reading the saved model at a lower API level. In the 43 example below, the graph saved by the call to `saved_model.save` will have 44 the variable devices correctly specified: 45 ```python 46 exported = tf.train.Checkpoint() 47 with tf.device('/GPU:0'): 48 exported.x_gpu = tf.Variable(1.0) 49 with tf.device('/CPU:0'): 50 exported.x_cpu = tf.Variable(1.0) 51 tf.saved_model.save(exported, export_dir, 52 options = tf.saved_model.SaveOptions( 53 experimental_variable_policy= 54 tf.saved_model.experimental.VariablePolicy.SAVE_VARIABLE_DEVICES)) 55 ``` 56 Distributed variables are still saved as one variable under this policy. 57 58 EXPAND_DISTRIBUTED_VARIABLES 59 Distributed variables will be saved with information about their components, 60 allowing for their restoration on load. Also, the saved graph will contain 61 references to those variables. This is useful when one wants to use the 62 model for training in environments where the original distribution strategy 63 is not available. 64 """ 65 66 NONE = None 67 68 SAVE_VARIABLE_DEVICES = "save_variable_devices" 69 70 EXPAND_DISTRIBUTED_VARIABLES = "expand_distributed_variables" 71 72 def _save_variable_devices(self): 73 """Checks whether variable devices should be saved.""" 74 return self != VariablePolicy.NONE 75 76 def _expand_distributed_variables(self): 77 """Checks whether distributed variables should be expanded.""" 78 return self == VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES 79 80 @staticmethod 81 def from_obj(obj): 82 """Tries to convert `obj` to a VariablePolicy instance.""" 83 if obj is None: 84 return VariablePolicy.NONE 85 if isinstance(obj, VariablePolicy): 86 return obj 87 key = str(obj).lower() 88 for policy in VariablePolicy: 89 if key == policy.value: 90 return policy 91 raise ValueError('Invalid VariablePolicy value "%s".' % obj) 92 93 94@tf_export("saved_model.SaveOptions") 95class SaveOptions(object): 96 """Options for saving to SavedModel. 97 98 This function may be used in the `options` argument in functions that 99 save a SavedModel (`tf.saved_model.save`, `tf.keras.models.save_model`). 100 """ 101 102 # Define object attributes in __slots__ for improved memory and performance. 103 __slots__ = ("namespace_whitelist", "save_debug_info", "function_aliases", 104 "experimental_io_device", "experimental_variable_policy") 105 106 def __init__(self, 107 namespace_whitelist=None, 108 save_debug_info=False, 109 function_aliases=None, 110 experimental_io_device=None, 111 experimental_variable_policy=None): 112 """Creates an object that stores options for SavedModel saving. 113 114 Args: 115 namespace_whitelist: List of strings containing op namespaces to whitelist 116 when saving a model. Saving an object that uses namespaced ops must 117 explicitly add all namespaces to the whitelist. The namespaced ops must 118 be registered into the framework when loading the SavedModel. 119 save_debug_info: Boolean indicating whether debug information is saved. If 120 True, then a debug/saved_model_debug_info.pb file will be written with 121 the contents of a GraphDebugInfo binary protocol buffer containing stack 122 trace information for all ops and functions that are saved. 123 function_aliases: Python dict. Mapping from string to object returned by 124 @tf.function. A single tf.function can generate many ConcreteFunctions. 125 If a downstream tool wants to refer to all concrete functions generated 126 by a single tf.function you can use the `function_aliases` argument to 127 store a map from the alias name to all concrete function names. 128 E.g. 129 130 >>> class Adder(tf.Module): 131 ... @tf.function 132 ... def double(self, x): 133 ... return x + x 134 135 >>> model = Adder() 136 >>> model.double.get_concrete_function( 137 ... tf.TensorSpec(shape=[], dtype=tf.float32, name="float_input")) 138 >>> model.double.get_concrete_function( 139 ... tf.TensorSpec(shape=[], dtype=tf.string, name="string_input")) 140 141 >>> options = tf.saved_model.SaveOptions( 142 ... function_aliases={'double': model.double}) 143 >>> tf.saved_model.save(model, '/tmp/adder', options=options) 144 145 experimental_io_device: string. Applies in a distributed setting. 146 Tensorflow device to use to access the filesystem. If `None` (default) 147 then for each variable the filesystem is accessed from the CPU:0 device 148 of the host where that variable is assigned. If specified, the 149 filesystem is instead accessed from that device for all variables. 150 151 This is for example useful if you want to save to a local directory, 152 such as "/tmp" when running in a distributed setting. In that case pass 153 a device for the host where the "/tmp" directory is accessible. 154 experimental_variable_policy: The policy to apply to variables when 155 saving. This is either a `saved_model.experimental.VariablePolicy` enum 156 instance or one of its value strings (case is not important). See that 157 enum documentation for details. A value of `None` corresponds to the 158 default policy. 159 """ 160 self.namespace_whitelist = _validate_namespace_whitelist( 161 namespace_whitelist) 162 self.save_debug_info = save_debug_info 163 self.function_aliases = function_aliases if function_aliases else dict() 164 self.experimental_io_device = experimental_io_device 165 self.experimental_variable_policy = ( 166 VariablePolicy.from_obj(experimental_variable_policy)) 167 168 169def _validate_namespace_whitelist(namespace_whitelist): 170 """Validates namespace whitelist argument.""" 171 if namespace_whitelist is None: 172 return [] 173 if not isinstance(namespace_whitelist, list): 174 raise TypeError("Namespace whitelist must be a list of strings.") 175 176 processed = [] 177 for namespace in namespace_whitelist: 178 if not isinstance(namespace, six.string_types): 179 raise ValueError("Whitelisted namespace must be a string. Got: {} of type" 180 " {}.".format(namespace, type(namespace))) 181 processed.append(compat.as_str(namespace)) 182 return processed 183