1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities for creating SavedModels.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import os 23import time 24 25import six 26 27from tensorflow.python.platform import gfile 28from tensorflow.python.platform import tf_logging as logging 29from tensorflow.python.saved_model import signature_constants 30from tensorflow.python.saved_model import signature_def_utils 31from tensorflow.python.saved_model import tag_constants 32from tensorflow.python.saved_model.model_utils import export_output as export_output_lib 33from tensorflow.python.saved_model.model_utils import mode_keys 34from tensorflow.python.saved_model.model_utils.mode_keys import KerasModeKeys as ModeKeys 35from tensorflow.python.util import compat 36 37 38# Mapping of the modes to appropriate MetaGraph tags in the SavedModel. 39EXPORT_TAG_MAP = mode_keys.ModeKeyMap(**{ 40 ModeKeys.PREDICT: [tag_constants.SERVING], 41 ModeKeys.TRAIN: [tag_constants.TRAINING], 42 ModeKeys.TEST: [tag_constants.EVAL]}) 43 44# For every exported mode, a SignatureDef map should be created using the 45# functions `export_outputs_for_mode` and `build_all_signature_defs`. By 46# default, this map will contain a single Signature that defines the input 47# tensors and output predictions, losses, and/or metrics (depending on the mode) 48# The default keys used in the SignatureDef map are defined below. 49SIGNATURE_KEY_MAP = mode_keys.ModeKeyMap(**{ 50 ModeKeys.PREDICT: signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, 51 ModeKeys.TRAIN: signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY, 52 ModeKeys.TEST: signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY}) 53 54# Default names used in the SignatureDef input map, which maps strings to 55# TensorInfo protos. 56SINGLE_FEATURE_DEFAULT_NAME = 'feature' 57SINGLE_RECEIVER_DEFAULT_NAME = 'input' 58SINGLE_LABEL_DEFAULT_NAME = 'label' 59 60### Below utilities are specific to SavedModel exports. 61 62 63def build_all_signature_defs(receiver_tensors, 64 export_outputs, 65 receiver_tensors_alternatives=None, 66 serving_only=True): 67 """Build `SignatureDef`s for all export outputs. 68 69 Args: 70 receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying 71 input nodes where this receiver expects to be fed by default. Typically, 72 this is a single placeholder expecting serialized `tf.Example` protos. 73 export_outputs: a dict of ExportOutput instances, each of which has 74 an as_signature_def instance method that will be called to retrieve 75 the signature_def for all export output tensors. 76 receiver_tensors_alternatives: a dict of string to additional 77 groups of receiver tensors, each of which may be a `Tensor` or a dict of 78 string to `Tensor`. These named receiver tensor alternatives generate 79 additional serving signatures, which may be used to feed inputs at 80 different points within the input receiver subgraph. A typical usage is 81 to allow feeding raw feature `Tensor`s *downstream* of the 82 tf.parse_example() op. Defaults to None. 83 serving_only: boolean; if true, resulting signature defs will only include 84 valid serving signatures. If false, all requested signatures will be 85 returned. 86 87 Returns: 88 signature_def representing all passed args. 89 90 Raises: 91 ValueError: if export_outputs is not a dict 92 """ 93 if not isinstance(receiver_tensors, dict): 94 receiver_tensors = {SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors} 95 if export_outputs is None or not isinstance(export_outputs, dict): 96 raise ValueError('export_outputs must be a dict and not' 97 '{}'.format(type(export_outputs))) 98 99 signature_def_map = {} 100 excluded_signatures = {} 101 for output_key, export_output in export_outputs.items(): 102 signature_name = '{}'.format(output_key or 'None') 103 try: 104 signature = export_output.as_signature_def(receiver_tensors) 105 signature_def_map[signature_name] = signature 106 except ValueError as e: 107 excluded_signatures[signature_name] = str(e) 108 109 if receiver_tensors_alternatives: 110 for receiver_name, receiver_tensors_alt in ( 111 six.iteritems(receiver_tensors_alternatives)): 112 if not isinstance(receiver_tensors_alt, dict): 113 receiver_tensors_alt = { 114 SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors_alt 115 } 116 for output_key, export_output in export_outputs.items(): 117 signature_name = '{}:{}'.format(receiver_name or 'None', output_key or 118 'None') 119 try: 120 signature = export_output.as_signature_def(receiver_tensors_alt) 121 signature_def_map[signature_name] = signature 122 except ValueError as e: 123 excluded_signatures[signature_name] = str(e) 124 125 _log_signature_report(signature_def_map, excluded_signatures) 126 127 # The above calls to export_output_lib.as_signature_def should return only 128 # valid signatures; if there is a validity problem, they raise a ValueError, 129 # in which case we exclude that signature from signature_def_map above. 130 # The is_valid_signature check ensures that the signatures produced are 131 # valid for serving, and acts as an additional sanity check for export 132 # signatures produced for serving. We skip this check for training and eval 133 # signatures, which are not intended for serving. 134 if serving_only: 135 signature_def_map = { 136 k: v 137 for k, v in signature_def_map.items() 138 if signature_def_utils.is_valid_signature(v) 139 } 140 return signature_def_map 141 142 143_FRIENDLY_METHOD_NAMES = { 144 signature_constants.CLASSIFY_METHOD_NAME: 'Classify', 145 signature_constants.REGRESS_METHOD_NAME: 'Regress', 146 signature_constants.PREDICT_METHOD_NAME: 'Predict', 147 signature_constants.SUPERVISED_TRAIN_METHOD_NAME: 'Train', 148 signature_constants.SUPERVISED_EVAL_METHOD_NAME: 'Eval', 149} 150 151 152def _log_signature_report(signature_def_map, excluded_signatures): 153 """Log a report of which signatures were produced.""" 154 sig_names_by_method_name = collections.defaultdict(list) 155 156 # We'll collect whatever method_names are present, but also we want to make 157 # sure to output a line for each of the three standard methods even if they 158 # have no signatures. 159 for method_name in _FRIENDLY_METHOD_NAMES: 160 sig_names_by_method_name[method_name] = [] 161 162 for signature_name, sig in signature_def_map.items(): 163 sig_names_by_method_name[sig.method_name].append(signature_name) 164 165 # TODO(b/67733540): consider printing the full signatures, not just names 166 for method_name, sig_names in sig_names_by_method_name.items(): 167 if method_name in _FRIENDLY_METHOD_NAMES: 168 method_name = _FRIENDLY_METHOD_NAMES[method_name] 169 logging.info('Signatures INCLUDED in export for {}: {}'.format( 170 method_name, sig_names if sig_names else 'None')) 171 172 if excluded_signatures: 173 logging.info('Signatures EXCLUDED from export because they cannot be ' 174 'be served via TensorFlow Serving APIs:') 175 for signature_name, message in excluded_signatures.items(): 176 logging.info('\'{}\' : {}'.format(signature_name, message)) 177 178 if not signature_def_map: 179 logging.warn('Export includes no signatures!') 180 elif (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY not in 181 signature_def_map): 182 logging.warn('Export includes no default signature!') 183 184 185# When we create a timestamped directory, there is a small chance that the 186# directory already exists because another process is also creating these 187# directories. In this case we just wait one second to get a new timestamp and 188# try again. If this fails several times in a row, then something is seriously 189# wrong. 190MAX_DIRECTORY_CREATION_ATTEMPTS = 10 191 192 193def get_timestamped_export_dir(export_dir_base): 194 """Builds a path to a new subdirectory within the base directory. 195 196 Each export is written into a new subdirectory named using the 197 current time. This guarantees monotonically increasing version 198 numbers even across multiple runs of the pipeline. 199 The timestamp used is the number of seconds since epoch UTC. 200 201 Args: 202 export_dir_base: A string containing a directory to write the exported 203 graph and checkpoints. 204 Returns: 205 The full path of the new subdirectory (which is not actually created yet). 206 207 Raises: 208 RuntimeError: if repeated attempts fail to obtain a unique timestamped 209 directory name. 210 """ 211 attempts = 0 212 while attempts < MAX_DIRECTORY_CREATION_ATTEMPTS: 213 timestamp = int(time.time()) 214 215 result_dir = os.path.join( 216 compat.as_bytes(export_dir_base), compat.as_bytes(str(timestamp))) 217 if not gfile.Exists(result_dir): 218 # Collisions are still possible (though extremely unlikely): this 219 # directory is not actually created yet, but it will be almost 220 # instantly on return from this function. 221 return result_dir 222 time.sleep(1) 223 attempts += 1 224 logging.warn('Directory {} already exists; retrying (attempt {}/{})'.format( 225 result_dir, attempts, MAX_DIRECTORY_CREATION_ATTEMPTS)) 226 raise RuntimeError('Failed to obtain a unique export directory name after ' 227 '{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS)) 228 229 230def get_temp_export_dir(timestamped_export_dir): 231 """Builds a directory name based on the argument but starting with 'temp-'. 232 233 This relies on the fact that TensorFlow Serving ignores subdirectories of 234 the base directory that can't be parsed as integers. 235 236 Args: 237 timestamped_export_dir: the name of the eventual export directory, e.g. 238 /foo/bar/<timestamp> 239 240 Returns: 241 A sister directory prefixed with 'temp-', e.g. /foo/bar/temp-<timestamp>. 242 """ 243 (dirname, basename) = os.path.split(timestamped_export_dir) 244 temp_export_dir = os.path.join( 245 compat.as_bytes(dirname), compat.as_bytes('temp-{}'.format(basename))) 246 return temp_export_dir 247 248 249def export_outputs_for_mode( 250 mode, serving_export_outputs=None, predictions=None, loss=None, 251 metrics=None): 252 """Util function for constructing a `ExportOutput` dict given a mode. 253 254 The returned dict can be directly passed to `build_all_signature_defs` helper 255 function as the `export_outputs` argument, used for generating a SignatureDef 256 map. 257 258 Args: 259 mode: A `ModeKeys` specifying the mode. 260 serving_export_outputs: Describes the output signatures to be exported to 261 `SavedModel` and used during serving. Should be a dict or None. 262 predictions: A dict of Tensors or single Tensor representing model 263 predictions. This argument is only used if serving_export_outputs is not 264 set. 265 loss: A dict of Tensors or single Tensor representing calculated loss. 266 metrics: A dict of (metric_value, update_op) tuples, or a single tuple. 267 metric_value must be a Tensor, and update_op must be a Tensor or Op 268 269 Returns: 270 Dictionary mapping the a key to an `tf.estimator.export.ExportOutput` object 271 The key is the expected SignatureDef key for the mode. 272 273 Raises: 274 ValueError: if an appropriate ExportOutput cannot be found for the mode. 275 """ 276 if mode not in SIGNATURE_KEY_MAP: 277 raise ValueError( 278 'Export output type not found for mode: {}. Expected one of: {}.\n' 279 'One likely error is that V1 Estimator Modekeys were somehow passed to ' 280 'this function. Please ensure that you are using the new ModeKeys.' 281 .format(mode, SIGNATURE_KEY_MAP.keys())) 282 signature_key = SIGNATURE_KEY_MAP[mode] 283 if mode_keys.is_predict(mode): 284 return get_export_outputs(serving_export_outputs, predictions) 285 elif mode_keys.is_train(mode): 286 return {signature_key: export_output_lib.TrainOutput( 287 loss=loss, predictions=predictions, metrics=metrics)} 288 else: 289 return {signature_key: export_output_lib.EvalOutput( 290 loss=loss, predictions=predictions, metrics=metrics)} 291 292 293def get_export_outputs(export_outputs, predictions): 294 """Validate export_outputs or create default export_outputs. 295 296 Args: 297 export_outputs: Describes the output signatures to be exported to 298 `SavedModel` and used during serving. Should be a dict or None. 299 predictions: Predictions `Tensor` or dict of `Tensor`. 300 301 Returns: 302 Valid export_outputs dict 303 304 Raises: 305 TypeError: if export_outputs is not a dict or its values are not 306 ExportOutput instances. 307 """ 308 if export_outputs is None: 309 default_output = export_output_lib.PredictOutput(predictions) 310 export_outputs = { 311 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: default_output} 312 313 if not isinstance(export_outputs, dict): 314 raise TypeError('export_outputs must be dict, given: {}'.format( 315 export_outputs)) 316 for v in six.itervalues(export_outputs): 317 if not isinstance(v, export_output_lib.ExportOutput): 318 raise TypeError( 319 'Values in export_outputs must be ExportOutput objects. ' 320 'Given: {}'.format(export_outputs)) 321 322 _maybe_add_default_serving_output(export_outputs) 323 324 return export_outputs 325 326 327def _maybe_add_default_serving_output(export_outputs): 328 """Add a default serving output to the export_outputs if not present. 329 330 Args: 331 export_outputs: Describes the output signatures to be exported to 332 `SavedModel` and used during serving. Should be a dict. 333 334 Returns: 335 export_outputs dict with default serving signature added if necessary 336 337 Raises: 338 ValueError: if multiple export_outputs were provided without a default 339 serving key. 340 """ 341 if len(export_outputs) == 1: 342 (key, value), = export_outputs.items() 343 if key != signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: 344 export_outputs[ 345 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = value 346 if len(export_outputs) > 1: 347 if (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 348 not in export_outputs): 349 raise ValueError( 350 'Multiple export_outputs were provided, but none of them is ' 351 'specified as the default. Do this by naming one of them with ' 352 'signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.') 353 354 return export_outputs 355