# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities supporting export to SavedModel (deprecated).

This module and all its submodules are deprecated. See
[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
for migration instructions.

Some contents of this file are moved to
tensorflow/python/estimator/export/export.py:

get_input_alternatives() -> obsolete
get_output_alternatives() -> obsolete, but see _get_default_export_output()
build_all_signature_defs() -> build_all_signature_defs()
get_timestamped_export_dir() -> get_timestamped_export_dir()
_get_* -> obsolete
_is_* -> obsolete

Functionality of build_standardized_signature_def() is moved to
tensorflow/python/estimator/export/export_output.py as
ExportOutput.as_signature_def().

Anything to do with ExportStrategies or garbage collection is not moved.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

from tensorflow.contrib.layers.python.layers import feature_column
from tensorflow.contrib.learn.python.learn import export_strategy
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import metric_key
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.contrib.learn.python.learn.utils import gc
from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.estimator.export import export as core_export
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.summary import summary_iterator
from tensorflow.python.training import checkpoint_management
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated


# A key for use in the input_alternatives dict indicating the default input.
# This is the input that will be expected when a serving request does not
# specify a specific signature.
# The default input alternative specifies placeholders that the input_fn
# requires to be fed (in the typical case, a single placeholder for a
# serialized tf.Example).
DEFAULT_INPUT_ALTERNATIVE_KEY = 'default_input_alternative'

# A key for use in the input_alternatives dict indicating the features input.
# The features inputs alternative specifies the feature Tensors provided as
# input to the model_fn, i.e. the outputs of the input_fn.
FEATURES_INPUT_ALTERNATIVE_KEY = 'features_input_alternative'

# A key for use in the output_alternatives dict indicating the default output.
# This is the output that will be provided when a serving request does not
# specify a specific signature.
# In a single-headed model, the single output is automatically the default.
# In a multi-headed model, the name of the desired default head should be
# provided to get_output_alternatives.
_FALLBACK_DEFAULT_OUTPUT_ALTERNATIVE_KEY = 'default_output_alternative'


@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.')
def build_standardized_signature_def(input_tensors, output_tensors,
                                     problem_type):
  """Build a SignatureDef using problem type and input and output Tensors.

  Note that this delegates the actual creation of the signatures to methods in
  tensorflow/python/saved_model/signature_def_utils.py, which may assign names
  to the input and output tensors (depending on the problem type) that are
  standardized in the context of SavedModel.

  Args:
    input_tensors: a dict of string key to `Tensor`
    output_tensors: a dict of string key to `Tensor`
    problem_type: an instance of constants.ProblemType, specifying
      classification, regression, etc.

  Returns:
    A SignatureDef using SavedModel standard keys where possible. A
    classification signature is built when `_is_classification_problem`
    holds, a regression signature when `_is_regression_problem` holds, and a
    generic predict signature otherwise.

  Raises:
    ValueError: if input_tensors or output_tensors is None or empty.
  """

  if not input_tensors:
    raise ValueError('input_tensors must be provided.')
  if not output_tensors:
    raise ValueError('output_tensors must be provided.')

  # Per-method signature_def functions will standardize the keys if possible
  if _is_classification_problem(problem_type, input_tensors, output_tensors):
    (_, examples), = input_tensors.items()
    classes = _get_classification_classes(output_tensors)
    scores = _get_classification_scores(output_tensors)
    if classes is None and scores is None:
      # _is_classification_problem guarantees exactly one output here; treat a
      # string tensor as class labels and anything else as scores.
      items = list(output_tensors.items())
      if items[0][1].dtype == dtypes.string:
        (_, classes), = items
      else:
        (_, scores), = items
    return signature_def_utils.classification_signature_def(
        examples, classes, scores)
  elif _is_regression_problem(problem_type, input_tensors, output_tensors):
    (_, examples), = input_tensors.items()
    (_, predictions), = output_tensors.items()
    return signature_def_utils.regression_signature_def(examples, predictions)
  else:
    return signature_def_utils.predict_signature_def(input_tensors,
                                                     output_tensors)


def _get_classification_scores(output_tensors):
  """Returns the SCORES output, falling back to PROBABILITIES, or None."""
  scores = output_tensors.get(prediction_key.PredictionKey.SCORES)
  if scores is None:
    scores = output_tensors.get(prediction_key.PredictionKey.PROBABILITIES)
  return scores


def _get_classification_classes(output_tensors):
  """Returns the CLASSES output if it is a string Tensor, otherwise None."""
  classes = output_tensors.get(prediction_key.PredictionKey.CLASSES)
  if classes is not None and classes.dtype != dtypes.string:
    # Servo classification can only serve string classes.
    return None
  return classes


def _is_classification_problem(problem_type, input_tensors, output_tensors):
  """Whether a classification SignatureDef can be built for these tensors."""
  classes = _get_classification_classes(output_tensors)
  scores = _get_classification_scores(output_tensors)
  return ((problem_type == constants.ProblemType.CLASSIFICATION or
           problem_type == constants.ProblemType.LOGISTIC_REGRESSION) and
          len(input_tensors) == 1 and
          (classes is not None or scores is not None or
           len(output_tensors) == 1))


def _is_regression_problem(problem_type, input_tensors, output_tensors):
  """Whether a regression SignatureDef can be built for these tensors."""
  return (problem_type == constants.ProblemType.LINEAR_REGRESSION and
          len(input_tensors) == 1 and len(output_tensors) == 1)


@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.')
def get_input_alternatives(input_ops):
  """Obtain all input alternatives using the input_fn output and heuristics.

  Args:
    input_ops: either an `input_fn_utils.InputFnOps` (features, labels,
      default_inputs) or a (features, labels) tuple.

  Returns:
    A tuple (input_alternatives, features). input_alternatives contains the
    default inputs under DEFAULT_INPUT_ALTERNATIVE_KEY when input_ops is an
    InputFnOps, and is empty otherwise.

  Raises:
    ValueError: if features is empty.
  """
  input_alternatives = {}
  if isinstance(input_ops, input_fn_utils.InputFnOps):
    features, unused_labels, default_inputs = input_ops
    input_alternatives[DEFAULT_INPUT_ALTERNATIVE_KEY] = default_inputs
  else:
    features, unused_labels = input_ops

  if not features:
    raise ValueError('Features must be defined.')

  # TODO(b/34253951): reinstate the "features" input_signature.
  # The "features" input_signature, as written, does not work with
  # SparseTensors.  It is simply commented out as a stopgap, pending discussion
  # on the bug as to the correct solution.

  # Add the "features" input_signature in any case.
  # Note defensive copy because model_fns alter the features dict.
  # input_alternatives[FEATURES_INPUT_ALTERNATIVE_KEY] = (
  #    copy.copy(features))

  return input_alternatives, features


@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.')
def get_output_alternatives(model_fn_ops, default_output_alternative_key=None):
  """Obtain all output alternatives using the model_fn output and heuristics.

  Args:
    model_fn_ops: a `ModelFnOps` object produced by a `model_fn`.  This may or
      may not have output_alternatives populated.
    default_output_alternative_key: the name of the head to serve when an
      incoming serving request does not explicitly request a specific head.
      Not needed for single-headed models.

  Returns:
    A tuple of (output_alternatives, actual_default_output_alternative_key),
    where the latter names the head that will actually be served by default.
    This may differ from the requested default_output_alternative_key when
    a) no output_alternatives are provided at all, so one must be generated, or
    b) there is exactly one head, which is used regardless of the requested
    default.

  Raises:
    ValueError: if the requested default_output_alternative_key is not available
      in output_alternatives, or if there are multiple output_alternatives and
      no default is specified.
  """
  output_alternatives = model_fn_ops.output_alternatives

  if not output_alternatives:
    if default_output_alternative_key:
      raise ValueError('Requested default_output_alternative: {}, '
                       'but available output_alternatives are: []'.format(
                           default_output_alternative_key))

    # Lacking provided output alternatives, the best we can do is to
    # interpret the model as single-headed of unknown type.
    default_problem_type = constants.ProblemType.UNSPECIFIED
    default_outputs = model_fn_ops.predictions
    if not isinstance(default_outputs, dict):
      default_outputs = {prediction_key.PredictionKey.GENERIC: default_outputs}
    actual_default_output_alternative_key = (
        _FALLBACK_DEFAULT_OUTPUT_ALTERNATIVE_KEY)
    output_alternatives = {
        actual_default_output_alternative_key: (default_problem_type,
                                                default_outputs)
    }
    return output_alternatives, actual_default_output_alternative_key

  if default_output_alternative_key:
    # If a default head is provided, use it.
    if default_output_alternative_key in output_alternatives:
      return output_alternatives, default_output_alternative_key

    raise ValueError('Requested default_output_alternative: {}, '
                     'but available output_alternatives are: {}'.format(
                         default_output_alternative_key,
                         sorted(output_alternatives.keys())))

  if len(output_alternatives) == 1:
    # If there is only one head, use it as the default regardless of its name.
    (actual_default_output_alternative_key, _), = output_alternatives.items()
    return output_alternatives, actual_default_output_alternative_key

  raise ValueError('Please specify a default_output_alternative.  '
                   'Available output_alternatives are: {}'.format(
                       sorted(output_alternatives.keys())))


@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.')
def build_all_signature_defs(input_alternatives, output_alternatives,
                             actual_default_output_alternative_key):
  """Build `SignatureDef`s from all pairs of input and output alternatives.

  Args:
    input_alternatives: dict of input alternative key to a dict of input
      Tensors; must contain DEFAULT_INPUT_ALTERNATIVE_KEY.
    output_alternatives: dict of output alternative key to a
      (problem_type, output Tensor dict) tuple.
    actual_default_output_alternative_key: key into output_alternatives that
      is served under the default serving signature.

  Returns:
    A dict mapping '<input_key>:<output_key>' (plus the default serving
    signature key) to SignatureDefs.

  Raises:
    ValueError: if no default input alternative is provided.
  """

  signature_def_map = {('%s:%s' % (input_key, output_key or 'None')):
                       build_standardized_signature_def(inputs, outputs,
                                                        problem_type)
                       for input_key, inputs in input_alternatives.items()
                       for output_key, (problem_type,
                                        outputs) in output_alternatives.items()}

  # Add the default SignatureDef
  default_inputs = input_alternatives.get(DEFAULT_INPUT_ALTERNATIVE_KEY)
  if not default_inputs:
    raise ValueError('A default input_alternative must be provided.')
    # default_inputs = input_alternatives[FEATURES_INPUT_ALTERNATIVE_KEY]
  # default outputs are guaranteed to exist above
  (default_problem_type, default_outputs) = (
      output_alternatives[actual_default_output_alternative_key])
  signature_def_map[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
      build_standardized_signature_def(default_inputs, default_outputs,
                                       default_problem_type))

  return signature_def_map


# When we create a timestamped directory, there is a small chance that the
# directory already exists because another worker is also writing exports.
# In this case we just wait one second to get a new timestamp and try again.
# If this fails several times in a row, then something is seriously wrong.
MAX_DIRECTORY_CREATION_ATTEMPTS = 10


@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.')
def get_timestamped_export_dir(export_dir_base):
  """Builds a path to a new subdirectory within the base directory.

  Each export is written into a new subdirectory named using the
  current time.  This guarantees monotonically increasing version
  numbers even across multiple runs of the pipeline.
  The timestamp used is the number of seconds since epoch UTC.

  Args:
    export_dir_base: A string containing a directory to write the exported
        graph and checkpoints.
  Returns:
    The full path of the new subdirectory (which is not actually created yet),
    as bytes.

  Raises:
    RuntimeError: if repeated attempts fail to obtain a unique timestamped
      directory name.
  """
  attempts = 0
  while attempts < MAX_DIRECTORY_CREATION_ATTEMPTS:
    export_timestamp = int(time.time())

    export_dir = os.path.join(
        compat.as_bytes(export_dir_base),
        compat.as_bytes(str(export_timestamp)))
    if not gfile.Exists(export_dir):
      # Collisions are still possible (though extremely unlikely): this
      # directory is not actually created yet, but it will be almost
      # instantly on return from this function.
      return export_dir
    time.sleep(1)
    attempts += 1
    logging.warning(
        'Export directory {} already exists; retrying (attempt {}/{})'.format(
            export_dir, attempts, MAX_DIRECTORY_CREATION_ATTEMPTS))
  raise RuntimeError('Failed to obtain a unique export directory name after '
                     '{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS))


@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.')
def get_temp_export_dir(timestamped_export_dir):
  """Builds a directory name based on the argument but starting with 'temp-'.

  This relies on the fact that TensorFlow Serving ignores subdirectories of
  the base directory that can't be parsed as integers.

  Args:
    timestamped_export_dir: the name of the eventual export directory, e.g.
      /foo/bar/<timestamp>

  Returns:
    A sister directory prefixed with 'temp-', e.g. /foo/bar/temp-<timestamp>,
    as bytes.
  """
  (dirname, basename) = os.path.split(timestamped_export_dir)
  temp_export_dir = os.path.join(
      compat.as_bytes(dirname),
      compat.as_bytes('temp-{}'.format(compat.as_text(basename))))
  return temp_export_dir


def _path_like(path, reference):
  """Returns `path` converted to the same string type (bytes/text) as `reference`.

  `os.path` functions refuse to mix bytes and text, while export base
  directories in this module may arrive as either (see the `compat.as_bytes`
  joins above).
  """
  if isinstance(reference, bytes):
    return compat.as_bytes(path)
  return compat.as_text(path)


# create a simple parser that pulls the export_version from the directory.
def _export_version_parser(path):
  """Parses a 10-digit timestamp directory name into `export_version`.

  Args:
    path: a `gc.Path` namedtuple.

  Returns:
    `path` with export_version set to the integer directory name, or None if
    the basename is not exactly 10 digits.
  """
  filename = os.path.basename(path.path)
  if not (len(filename) == 10 and filename.isdigit()):
    return None
  return path._replace(export_version=int(filename))


@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.')
def get_most_recent_export(export_dir_base):
  """Locate the most recent SavedModel export in a directory of many exports.

  This method assumes that SavedModel subdirectories are named as a timestamp
  (seconds from epoch), as produced by get_timestamped_export_dir().

  Args:
    export_dir_base: A base directory containing multiple timestamped
                     directories.

  Returns:
    A gc.Path, which is just a namedtuple of (path, export_version), or None
    if no timestamped export is found.
  """
  select_filter = gc.largest_export_versions(1)
  results = select_filter(
      gc.get_paths(export_dir_base, parser=_export_version_parser))
  return next(iter(results or []), None)


@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.')
def garbage_collect_exports(export_dir_base, exports_to_keep):
  """Deletes older exports, retaining only a given number of the most recent.

  Export subdirectories are assumed to be named with monotonically increasing
  integers; the most recent are taken to be those with the largest values.

  Args:
    export_dir_base: the base directory under which each export is in a
      versioned subdirectory.
    exports_to_keep: the number of recent exports to retain. If None, nothing
      is deleted.
  """
  if exports_to_keep is None:
    return

  keep_filter = gc.largest_export_versions(exports_to_keep)
  delete_filter = gc.negation(keep_filter)
  for p in delete_filter(
      gc.get_paths(export_dir_base, parser=_export_version_parser)):
    try:
      gfile.DeleteRecursively(p.path)
    except errors_impl.NotFoundError as e:
      # Another worker may already have removed it; not fatal.
      logging.warning('Can not delete %s recursively: %s', p.path, e)


def _delete_recursively_best_effort(paths):
  """Deletes each directory in `paths`, logging rather than raising on failure.

  Used on error paths so that cleanup failures never mask the original error.
  """
  for path in paths:
    try:
      gfile.DeleteRecursively(path)
    except errors_impl.OpError as e:
      logging.warning('Can not delete %s recursively: %s', path, e)


@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.')
def make_export_strategy(serving_input_fn,
                         default_output_alternative_key=None,
                         assets_extra=None,
                         as_text=False,
                         exports_to_keep=5,
                         strip_default_attrs=None):
  """Create an ExportStrategy for use with Experiment.

  Args:
    serving_input_fn: A function that takes no arguments and returns an
      `InputFnOps`.
    default_output_alternative_key: the name of the head to serve when an
      incoming serving request does not explicitly request a specific head.
      Must be `None` if the estimator inherits from `tf.estimator.Estimator`
      or for single-headed models.
    assets_extra: A dict specifying how to populate the assets.extra directory
      within the exported SavedModel.  Each key should give the destination
      path (including the filename) relative to the assets.extra directory.
      The corresponding value gives the full path of the source file to be
      copied.  For example, the simple case of copying a single file without
      renaming it is specified as
      `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
    as_text: whether to write the SavedModel proto in text format.
    exports_to_keep: Number of exports to keep.  Older exports will be
      garbage-collected.  Defaults to 5.  Set to None to disable garbage
      collection.
    strip_default_attrs: Boolean. If True, default attrs in the
      `GraphDef` will be stripped on write. This is recommended for better
      forward compatibility of the resulting `SavedModel`.

  Returns:
    An ExportStrategy that can be passed to the Experiment constructor.
  """

  def export_fn(estimator, export_dir_base, checkpoint_path=None,
                strip_default_attrs=False):
    """Exports the given Estimator as a SavedModel.

    Args:
      estimator: the Estimator to export.
      export_dir_base: A string containing a directory to write the exported
        graph and checkpoints.
      checkpoint_path: The checkpoint path to export.  If None (the default),
        the most recent checkpoint found within the model directory is chosen.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will
        be removed from the NodeDefs.

    Returns:
      The string path to the exported directory.

    Raises:
      ValueError: If `estimator` is a `tf.estimator.Estimator` instance
        and `default_output_alternative_key` was specified.
    """
    if isinstance(estimator, core_estimator.Estimator):
      if default_output_alternative_key is not None:
        raise ValueError(
            'default_output_alternative_key is not supported in core '
            'Estimator. Given: {}'.format(default_output_alternative_key))
      export_result = estimator.export_savedmodel(
          export_dir_base,
          serving_input_fn,
          assets_extra=assets_extra,
          as_text=as_text,
          checkpoint_path=checkpoint_path,
          strip_default_attrs=strip_default_attrs)
    else:
      export_result = estimator.export_savedmodel(
          export_dir_base,
          serving_input_fn,
          default_output_alternative_key=default_output_alternative_key,
          assets_extra=assets_extra,
          as_text=as_text,
          checkpoint_path=checkpoint_path,
          strip_default_attrs=strip_default_attrs)

    garbage_collect_exports(export_dir_base, exports_to_keep)
    return export_result

  return export_strategy.ExportStrategy('Servo', export_fn, strip_default_attrs)


@deprecated(None,
            'Use tf.estimator.export.build_parsing_serving_input_receiver_fn')
def make_parsing_export_strategy(feature_columns,
                                 default_output_alternative_key=None,
                                 assets_extra=None,
                                 as_text=False,
                                 exports_to_keep=5,
                                 target_core=False,
                                 strip_default_attrs=None):
  """Create an ExportStrategy for use with Experiment, using `FeatureColumn`s.

  Creates a SavedModel export that expects to be fed with a single string
  Tensor containing serialized tf.Examples.  At serving time, incoming
  tf.Examples will be parsed according to the provided `FeatureColumn`s.

  Args:
    feature_columns: An iterable of `FeatureColumn`s representing the features
      that must be provided at serving time (excluding labels!).
    default_output_alternative_key: the name of the head to serve when an
      incoming serving request does not explicitly request a specific head.
      Must be `None` if the estimator inherits from `tf.estimator.Estimator`
      or for single-headed models.
    assets_extra: A dict specifying how to populate the assets.extra directory
      within the exported SavedModel.  Each key should give the destination
      path (including the filename) relative to the assets.extra directory.
      The corresponding value gives the full path of the source file to be
      copied.  For example, the simple case of copying a single file without
      renaming it is specified as
      `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
    as_text: whether to write the SavedModel proto in text format.
    exports_to_keep: Number of exports to keep.  Older exports will be
      garbage-collected.  Defaults to 5.  Set to None to disable garbage
      collection.
    target_core: If True, prepare an ExportStrategy for use with
      tensorflow.python.estimator.*.  If False (default), prepare an
      ExportStrategy for use with tensorflow.contrib.learn.python.learn.*.
    strip_default_attrs: Boolean. If True, default attrs in the
      `GraphDef` will be stripped on write. This is recommended for better
      forward compatibility of the resulting `SavedModel`.

  Returns:
    An ExportStrategy that can be passed to the Experiment constructor.
  """
  feature_spec = feature_column.create_feature_spec_for_parsing(feature_columns)
  if target_core:
    serving_input_fn = (
        core_export.build_parsing_serving_input_receiver_fn(feature_spec))
  else:
    serving_input_fn = (
        input_fn_utils.build_parsing_serving_input_fn(feature_spec))
  return make_export_strategy(
      serving_input_fn,
      default_output_alternative_key=default_output_alternative_key,
      assets_extra=assets_extra,
      as_text=as_text,
      exports_to_keep=exports_to_keep,
      strip_default_attrs=strip_default_attrs)


def _default_compare_fn(curr_best_eval_result, cand_eval_result):
  """Compares two evaluation results and returns true if the 2nd one is better.

  Both evaluation results should have the values for MetricKey.LOSS, which are
  used for comparison.

  Args:
    curr_best_eval_result: current best eval metrics.
    cand_eval_result: candidate eval metrics.

  Returns:
    True if cand_eval_result is better (has strictly lower loss).

  Raises:
    ValueError: If input eval result is None or no loss is available.
  """
  default_key = metric_key.MetricKey.LOSS
  if not curr_best_eval_result or default_key not in curr_best_eval_result:
    raise ValueError(
        'curr_best_eval_result cannot be empty or no loss is found in it.')

  if not cand_eval_result or default_key not in cand_eval_result:
    raise ValueError(
        'cand_eval_result cannot be empty or no loss is found in it.')

  return curr_best_eval_result[default_key] > cand_eval_result[default_key]


class BestModelSelector(object):
  """A helper that keeps track of export selection candidates.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.
  """

  @deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.')
  def __init__(self, event_file_pattern=None, compare_fn=None):
    """Constructor of this class.

    Args:
      event_file_pattern: absolute event file name pattern.
      compare_fn: a function that returns true if the candidate is better than
        the current best model.
    """
    self._compare_fn = compare_fn or _default_compare_fn
    self._best_eval_result = self._get_best_eval_result(event_file_pattern)

  def update(self, checkpoint_path, eval_result):
    """Records a given checkpoint and exports if this is the best model.

    Args:
      checkpoint_path: the checkpoint path to export.
      eval_result: a dictionary which is usually generated in evaluation runs.
        By default, eval_results contains 'loss' field.

    Returns:
      A tuple (checkpoint_path, eval_result) when the candidate becomes the new
      best model, or ('', None) otherwise.

    Raises:
      ValueError: if checkpoint path is empty.
      ValueError: if eval_results is None object.
    """
    if not checkpoint_path:
      raise ValueError('Checkpoint path is empty.')
    if eval_result is None:
      raise ValueError('%s has empty evaluation results.' % checkpoint_path)

    if (self._best_eval_result is None or
        self._compare_fn(self._best_eval_result, eval_result)):
      self._best_eval_result = eval_result
      return checkpoint_path, eval_result
    else:
      return '', None

  def _get_best_eval_result(self, event_files):
    """Get the best eval result from event files.

    Args:
      event_files: Absolute pattern of event files.

    Returns:
      The best eval result, as a dict of summary tag to scalar value, or None
      if event_files is empty or no scalar summaries are found.
    """
    if not event_files:
      return None

    best_eval_result = None
    for event_file in gfile.Glob(event_files):
      for event in summary_iterator.summary_iterator(event_file):
        if not event.HasField('summary'):
          continue
        event_eval_result = {
            value.tag: value.simple_value
            for value in event.summary.value
            if value.HasField('simple_value')
        }
        # Summaries without scalar values (e.g. images or histograms) carry no
        # metrics to compare; an empty dict would make the default compare_fn
        # raise ValueError.
        if not event_eval_result:
          continue
        if best_eval_result is None or self._compare_fn(
            best_eval_result, event_eval_result):
          best_eval_result = event_eval_result
    return best_eval_result


@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.')
def make_best_model_export_strategy(
    serving_input_fn,
    exports_to_keep=1,
    model_dir=None,
    event_file_pattern=None,
    compare_fn=None,
    default_output_alternative_key=None,
    strip_default_attrs=None):
  """Creates a custom ExportStrategy for use with tf.contrib.learn.Experiment.

  Args:
    serving_input_fn: a function that takes no arguments and returns an
      `InputFnOps`.
    exports_to_keep: an integer indicating how many historical best models need
      to be preserved.
    model_dir: Directory where model parameters, graph etc. are saved. This will
        be used to load eval metrics from the directory when the export strategy
        is created. So the best metrics would not be lost even if the export
        strategy got preempted, which guarantees that only the best model would
        be exported regardless of preemption. If None, however, the export
        strategy would not be preemption-safe. To be preemption-safe, both
        model_dir and event_file_pattern would be needed.
    event_file_pattern: event file name pattern relative to model_dir, e.g.
        "eval_continuous/*.tfevents.*". If None, however, the export strategy
        would not be preemption-safe. To be preemption-safe, both
        model_dir and event_file_pattern would be needed.
    compare_fn: a function `compare_fn(current_best_eval_result,
        candidate_eval_result)` that returns True if the candidate evaluation
        result (a dict of metrics) is better than the current best. Defaults
        to comparing MetricKey.LOSS (lower is better).
    default_output_alternative_key: the key for default serving signature for
        multi-headed inference graphs.
    strip_default_attrs: Boolean. If True, default attrs in the
      `GraphDef` will be stripped on write. This is recommended for better
      forward compatibility of the resulting `SavedModel`.

  Returns:
    An ExportStrategy that can be passed to the Experiment constructor.
  """
  best_model_export_strategy = make_export_strategy(
      serving_input_fn,
      exports_to_keep=exports_to_keep,
      default_output_alternative_key=default_output_alternative_key,
      strip_default_attrs=strip_default_attrs)

  full_event_file_pattern = os.path.join(
      model_dir,
      event_file_pattern) if model_dir and event_file_pattern else None
  best_model_selector = BestModelSelector(full_event_file_pattern, compare_fn)

  def export_fn(estimator, export_dir_base, checkpoint_path, eval_result=None):
    """Exports the given Estimator as a SavedModel.

    Args:
      estimator: the Estimator to export.
      export_dir_base: A string containing a directory to write the exported
        graph and checkpoints.
      checkpoint_path: The checkpoint path to export.  If empty or None, the
        most recent checkpoint found within the model directory is chosen.
      eval_result: placeholder arg matching the call signature of
        ExportStrategy.

    Returns:
      The string path to the exported directory, or '' if this checkpoint is
      not better than the best one seen so far.
    """
    if not checkpoint_path:
      # TODO(b/67425018): switch to
      #    checkpoint_path = estimator.latest_checkpoint()
      #  as soon as contrib is cleaned up and we can thus be sure that
      #  estimator is a tf.estimator.Estimator and not a
      #  tf.contrib.learn.Estimator
      checkpoint_path = checkpoint_management.latest_checkpoint(
          estimator.model_dir)
    export_checkpoint_path, export_eval_result = best_model_selector.update(
        checkpoint_path, eval_result)

    if export_checkpoint_path and export_eval_result is not None:
      checkpoint_base = os.path.basename(export_checkpoint_path)
      # export_dir_base may be bytes while checkpoint paths are text.
      export_dir = os.path.join(
          export_dir_base, _path_like(checkpoint_base, export_dir_base))
      return best_model_export_strategy.export(
          estimator, export_dir, export_checkpoint_path, export_eval_result)
    else:
      return ''

  return export_strategy.ExportStrategy('best_model', export_fn)


# TODO(b/67013778): Revisit this approach when corresponding changes to
# TF Core are finalized.
@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.')
def extend_export_strategy(base_export_strategy,
                           post_export_fn,
                           post_export_name=None):
  """Extend ExportStrategy, calling post_export_fn after export.

  Args:
    base_export_strategy: An ExportStrategy that can be passed to the Experiment
      constructor.
    post_export_fn: A user-specified function to call after exporting the
      SavedModel. Takes two arguments - the path to the SavedModel exported by
      base_export_strategy and the directory where to export the SavedModel
      modified by the post_export_fn. Returns the path to the exported
      SavedModel.
    post_export_name: The directory name under the export base directory where
      SavedModels generated by the post_export_fn will be written. If None, the
      directory name of base_export_strategy is used.

  Returns:
    An ExportStrategy that can be passed to the Experiment constructor.
  """
  def export_fn(estimator, export_dir_base, checkpoint_path=None):
    """Exports the given Estimator as a SavedModel and invokes post_export_fn.

    Temporary directories created under export_dir_base are removed both on
    success and, best-effort, when any step fails.

    Args:
      estimator: the Estimator to export.
      export_dir_base: A string containing a directory to write the exported
        graphs and checkpoint.
      checkpoint_path: The checkpoint path to export. If None (the default),
        the most recent checkpoint found within the model directory is chosen.

    Returns:
      The string path to the SavedModel indicated by post_export_fn.

    Raises:
      ValueError: If post_export_fn does not return a sub-directory of the
        temporary post-export directory it was given.
      RuntimeError: If unable to create temporary or final export directory.
    """
    tmp_base_export_folder = 'temp-base-export-' + str(int(time.time()))
    tmp_base_export_dir = os.path.join(
        export_dir_base, _path_like(tmp_base_export_folder, export_dir_base))
    if gfile.Exists(tmp_base_export_dir):
      raise RuntimeError('Failed to obtain base export directory')
    gfile.MakeDirs(tmp_base_export_dir)
    # Only directories this call created are ever deleted.
    created_dirs = [tmp_base_export_dir]

    try:
      tmp_base_export = base_export_strategy.export(
          estimator, tmp_base_export_dir, checkpoint_path)

      tmp_post_export_folder = 'temp-post-export-' + str(int(time.time()))
      tmp_post_export_dir = os.path.join(
          export_dir_base, _path_like(tmp_post_export_folder, export_dir_base))
      if gfile.Exists(tmp_post_export_dir):
        raise RuntimeError('Failed to obtain temp export directory')

      gfile.MakeDirs(tmp_post_export_dir)
      created_dirs.append(tmp_post_export_dir)
      tmp_post_export = post_export_fn(tmp_base_export, tmp_post_export_dir)
      if tmp_post_export is not None:
        # Normalize bytes/text so the prefix and relpath checks can't TypeError.
        tmp_post_export = _path_like(tmp_post_export, tmp_post_export_dir)

      if (not tmp_post_export or
          not tmp_post_export.startswith(tmp_post_export_dir)):
        raise ValueError('post_export_fn must return a sub-directory of {}'
                         .format(tmp_post_export_dir))
      post_export_relpath = os.path.relpath(tmp_post_export,
                                            tmp_post_export_dir)
      post_export = os.path.join(export_dir_base, post_export_relpath)
      if gfile.Exists(post_export):
        raise RuntimeError('Failed to obtain final export directory')
      gfile.Rename(tmp_post_export, post_export)
    except Exception:
      # Don't leave temp-* directories behind on failure; cleanup problems are
      # logged so they never mask the original error, which is re-raised.
      _delete_recursively_best_effort(created_dirs)
      raise

    for tmp_dir in created_dirs:
      gfile.DeleteRecursively(tmp_dir)
    return post_export

  name = post_export_name if post_export_name else base_export_strategy.name
  return export_strategy.ExportStrategy(name, export_fn)