1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Experiment class collecting information for a single training run (deprecated). 16 17This module and all its submodules are deprecated. See 18[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) 19for migration instructions. 20""" 21 22from __future__ import absolute_import 23from __future__ import division 24from __future__ import print_function 25 26import contextlib 27import functools 28import math 29import os 30import time 31 32from tensorflow.contrib.framework import deprecated 33from tensorflow.contrib.framework.python.framework import experimental 34from tensorflow.contrib.learn.python.learn import evaluable 35from tensorflow.contrib.learn.python.learn import export_strategy 36from tensorflow.contrib.learn.python.learn import monitors 37from tensorflow.contrib.learn.python.learn import trainable 38from tensorflow.contrib.learn.python.learn.estimators import run_config 39from tensorflow.contrib.tpu.python.tpu import tpu_estimator 40from tensorflow.python.estimator import estimator as core_estimator 41from tensorflow.python.framework import ops 42from tensorflow.python.platform import tf_logging as logging 43from tensorflow.python.training import basic_session_run_hooks 44from tensorflow.python.training import checkpoint_management 45from tensorflow.python.training import server_lib 46from tensorflow.python.util import compat 47from tensorflow.python.util import function_utils 48 49__all__ = ["Experiment"] 50 51 52def _get_standardized_predicate_fn(predicate_fn): 53 pred_fn_args = function_utils.fn_args(predicate_fn) 54 if "checkpoint_path" not in pred_fn_args: 55 # pylint: disable=unused-argument 56 def _pred_fn_wrapper(eval_results, checkpoint_path): 57 return predicate_fn(eval_results) 58 59 return _pred_fn_wrapper 60 else: 61 return predicate_fn 62 63 64class _EvalAndExportListener(basic_session_run_hooks.CheckpointSaverListener): 65 """Listener that evaluates and exports a model after creating a checkpoint. 66 67 The `EvalAndExportListener` waits for the associated `CheckpointSaverHook` 68 to save a checkpoint. It then uses the provided `eval_fn` and `export_fn` to 69 first evaluate the model using the newly-created checkpoint, and then export 70 the model according to the `export_strategies` provided in the `Experiment`. 71 72 This listener is experimental and may be changed or removed in the future. 73 """ 74 75 def __init__(self, eval_fn, export_fn, model_dir): 76 """Initializes an `EvalAndExportListener`. 77 78 Args: 79 eval_fn: function which evaluates the model with the following signature: 80 `(name, checkpoint_path) -> eval_result` 81 export_fn: function which exports the model according to a set of export 82 strategies. Has the following signature: 83 `(eval_result, checkpoint_path) -> export_results` 84 model_dir: directory which contains estimator parameters and checkpoints. 85 """ 86 self._eval_fn = eval_fn 87 self._export_fn = export_fn 88 self._model_dir = model_dir 89 self._latest_path = None 90 self._eval_result = None 91 self._export_results = None 92 93 def after_save(self, session, global_step_value): 94 """Evaluates and exports the model after a checkpoint is created.""" 95 # Load and cache the path of the most recent checkpoint to avoid duplicate 96 # searches on GCS. 97 logging.info("Checking for checkpoint in %s", self._model_dir) 98 latest_path = checkpoint_management.latest_checkpoint(self._model_dir) 99 100 if not latest_path: 101 logging.warning("Skipping evaluation and export since model has not been " 102 "saved yet.") 103 elif latest_path == self._latest_path: 104 logging.warning("Skipping evaluation due to same latest checkpoint %s.", 105 latest_path) 106 else: 107 self._latest_path = latest_path 108 self._eval_result = self._eval_fn( 109 name="intermediate_export", checkpoint_path=latest_path) 110 self._export_results = self._export_fn( 111 self._eval_result, checkpoint_path=latest_path) 112 113 @property 114 def eval_result(self): 115 return self._eval_result 116 117 @property 118 def export_results(self): 119 return self._export_results 120 121 122class Experiment(object): 123 """Experiment is a class containing all information needed to train a model. 124 125 THIS CLASS IS DEPRECATED. See 126 [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) 127 for general migration instructions. 128 129 After an experiment is created (by passing an Estimator and inputs for 130 training and evaluation), an Experiment instance knows how to invoke training 131 and eval loops in a sensible fashion for distributed training. 132 """ 133 134 # TODO(ispir): remove delay_workers_by_global_step and make global step based 135 # waiting as only behavior. 136 @deprecated(None, "Please switch to tf.estimator.train_and_evaluate. You will" 137 " also have to convert to a tf.estimator.Estimator.") 138 def __init__(self, 139 estimator, 140 train_input_fn, 141 eval_input_fn, 142 eval_metrics=None, 143 train_steps=None, 144 eval_steps=100, 145 train_monitors=None, 146 eval_hooks=None, 147 local_eval_frequency=None, 148 eval_delay_secs=120, 149 continuous_eval_throttle_secs=60, 150 min_eval_frequency=None, 151 delay_workers_by_global_step=False, 152 export_strategies=None, 153 train_steps_per_iteration=None, 154 checkpoint_and_export=False, 155 saving_listeners=None, 156 check_interval_secs=5): 157 """Constructor for `Experiment`. 158 159 Creates an Experiment instance. None of the functions passed to this 160 constructor are executed at construction time. They are stored and used 161 when a method is executed which requires it. 162 163 Args: 164 estimator: Object implementing Estimator interface, which could be a 165 combination of `tf.contrib.learn.Trainable` and 166 `tf.contrib.learn.Evaluable` (deprecated), or 167 `tf.estimator.Estimator`. 168 train_input_fn: function, returns features and labels for training. 169 eval_input_fn: function, returns features and labels for evaluation. If 170 `eval_steps` is `None`, this should be configured only to produce for a 171 finite number of batches (generally, 1 epoch over the evaluation data). 172 eval_metrics: `dict` of string, metric function. If `None`, default set 173 is used. This should be `None` if the `estimator` is 174 `tf.estimator.Estimator`. If metrics are provided they will be 175 *appended* to the default set. 176 train_steps: Perform this many steps of training. `None`, the default, 177 means train forever. 178 eval_steps: `evaluate` runs until input is exhausted (or another exception 179 is raised), or for `eval_steps` steps, if specified. 180 train_monitors: A list of monitors to pass to the `Estimator`'s `fit` 181 function. 182 eval_hooks: A list of `SessionRunHook` hooks to pass to the 183 `Estimator`'s `evaluate` function. 184 local_eval_frequency: (applies only to local_run) Frequency of running 185 eval in steps. If `None`, runs evaluation only at the end of training. 186 eval_delay_secs: Start evaluating after waiting for this many seconds. 187 continuous_eval_throttle_secs: Do not re-evaluate unless the last 188 evaluation was started at least this many seconds ago for 189 continuous_eval(). 190 min_eval_frequency: (applies only to train_and_evaluate). the minimum 191 number of steps between evaluations. Of course, evaluation does not 192 occur if no new snapshot is available, hence, this is the minimum. 193 If 0, the evaluation will only happen after training. 194 If None, defaults to 1. To avoid checking for new checkpoints too 195 frequent, the interval is further limited to be at least 196 check_interval_secs between checks. 197 delay_workers_by_global_step: if `True` delays training workers 198 based on global step instead of time. 199 export_strategies: Iterable of `ExportStrategy`s, or a single one, or 200 `None`. 201 train_steps_per_iteration: (applies only to continuous_train_and_eval). 202 Perform this many (integer) number of train steps for each 203 training-evaluation iteration. With a small value, the model will be 204 evaluated more frequently with more checkpoints saved. If `None`, will 205 use a default value (which is smaller than `train_steps` if provided). 206 checkpoint_and_export: (applies only to train_and_evaluate). If `True`, 207 performs intermediate model checkpoints and exports during the training 208 process, rather than only once model training is complete. This 209 parameter is experimental and may be changed or removed in the future. 210 Setting this parameter leads to the following: the value of 211 `min_eval_frequency` will be ignored, and the number of steps between 212 evaluations and exports will instead be determined by the Estimator 213 configuration parameters `save_checkpoints_secs` and 214 `save_checkpoints_steps`. Also, this parameter leads to the creation of 215 a default `CheckpointSaverHook` instead of a `ValidationMonitor`, so the 216 provided `train_monitors` will need to be adjusted accordingly. 217 saving_listeners: list of `CheckpointSaverListener` objects. Used by 218 tf.estimator.Estimator for callbacks that run immediately before or 219 after checkpoint savings. 220 check_interval_secs: 221 Minimum time between subsequent checks for a new checkpoint. This 222 mostly applies if both min_eval_frequency and the time spent per 223 training step is low. 224 Raises: 225 ValueError: if `estimator` does not implement Estimator interface, 226 or if export_strategies has the wrong type. 227 """ 228 if isinstance(estimator, core_estimator.Estimator): 229 self._core_estimator_used = True 230 if eval_metrics is not None: 231 raise ValueError( 232 "`eval_metrics` must be `None` with `tf.estimator.Estimator`. " 233 "Use `eval_metric_ops` in `tf.estimator.EstimatorSpec` instead.") 234 else: 235 self._core_estimator_used = False 236 if not isinstance(estimator, evaluable.Evaluable): 237 raise ValueError( 238 "`estimator` must implement `tf.contrib.learn.Evaluable` " 239 "or `tf.estimator.Estimator`.") 240 if not isinstance(estimator, trainable.Trainable): 241 raise ValueError( 242 "`estimator` must implement `tf.contrib.learn.Trainable`" 243 "or `tf.estimator.`Estimator`.") 244 if saving_listeners is not None: 245 raise ValueError("`saving_listeners` must be `None` with " 246 "`tf.contrib.learn.Estimator`.") 247 248 if isinstance(estimator, tpu_estimator.TPUEstimator): 249 logging.warn( 250 "`Experiment` class cannot work with `tf.contrib.tpu.TPUEstimator`. " 251 "Please call `TPUEstimator` train/evaluate directly. \n" 252 "Details: `Experiment` class is designed for between-graph " 253 "distributed training, while `TPUEstimator` is working in in-graph " 254 "distributed mode. Use with care.") 255 256 super(Experiment, self).__init__() 257 # Immutable fields. 258 self._estimator = estimator 259 self._train_input_fn = train_input_fn 260 self._eval_input_fn = eval_input_fn 261 self._eval_metrics = eval_metrics 262 self._train_steps = train_steps 263 self._eval_steps = eval_steps 264 self._local_eval_frequency = local_eval_frequency 265 self._eval_delay_secs = eval_delay_secs 266 self._continuous_eval_throttle_secs = continuous_eval_throttle_secs 267 self._checkpoint_and_export = checkpoint_and_export 268 self._saving_listeners = saving_listeners 269 self._min_eval_frequency = min_eval_frequency if ( 270 min_eval_frequency is not None) else 1 271 self._check_interval_secs = check_interval_secs 272 self._delay_workers_by_global_step = delay_workers_by_global_step 273 self._train_monitors = train_monitors[:] if train_monitors else [] 274 self._eval_hooks = eval_hooks[:] if eval_hooks else [] 275 self._set_export_strategies(export_strategies) 276 277 self._train_steps_per_iteration = train_steps_per_iteration 278 if (self._train_steps_per_iteration is not None and 279 not isinstance(self._train_steps_per_iteration, int)): 280 raise ValueError("`train_steps_per_iteration` must be an integer.") 281 282 @property 283 def estimator(self): 284 return self._estimator 285 286 @property 287 def eval_metrics(self): 288 return self._eval_metrics 289 290 @property 291 def train_steps(self): 292 return self._train_steps 293 294 @property 295 def eval_steps(self): 296 return self._eval_steps 297 298 def _set_export_strategies(self, values): # pylint: disable=missing-docstring 299 export_strategies = [] 300 if values: 301 if isinstance(values, export_strategy.ExportStrategy): 302 export_strategies.append(values) 303 else: 304 for value in values: 305 if not isinstance(value, export_strategy.ExportStrategy): 306 raise ValueError("`export_strategies` must be an ExportStrategy," 307 " an iterable of ExportStrategy, or `None`," 308 " found %s." % value) 309 export_strategies.append(value) 310 self._export_strategies = tuple(export_strategies) 311 312 def extend_train_hooks(self, additional_hooks): 313 """Extends the hooks for training.""" 314 self._train_monitors.extend(additional_hooks) 315 316 def reset_export_strategies(self, new_export_strategies=None): 317 """Resets the export strategies with the `new_export_strategies`. 318 319 Args: 320 new_export_strategies: A new list of `ExportStrategy`s, or a single one, 321 or None. 322 323 Returns: 324 The old export strategies. 325 """ 326 old_export_strategies = self._export_strategies 327 self._set_export_strategies(new_export_strategies) 328 return old_export_strategies 329 330 def train(self, delay_secs=None): 331 """Fit the estimator using the training data. 332 333 Train the estimator for `self._train_steps` steps, after waiting for 334 `delay_secs` seconds. If `self._train_steps` is `None`, train forever. 335 336 Args: 337 delay_secs: Start training after this many seconds. 338 339 Returns: 340 The trained estimator. 341 """ 342 start = time.time() 343 344 # Start the server, if needed. It's important to start the server before 345 # we (optionally) sleep for the case where no device_filters are set. 346 # Otherwise, the servers will wait to connect to each other before starting 347 # to train. We might as well start as soon as we can. 348 config = self._estimator.config 349 if isinstance(config, run_config.RunConfig): 350 if (config.cluster_spec and config.master and 351 config.environment == run_config.Environment.LOCAL): 352 logging.warn("ClusterSpec and master are provided, but environment is " 353 "set to 'local'. Set environment to 'cloud' if you intend " 354 "to use the distributed runtime.") 355 if (config.environment != run_config.Environment.LOCAL and 356 config.environment != run_config.Environment.GOOGLE and 357 config.cluster_spec and config.master): 358 self._start_server() 359 elif config.cluster_spec and config.master: 360 raise ValueError( 361 "For distributed runtime, Experiment class only works with " 362 "tf.contrib.learn.RunConfig for now, but provided {}".format( 363 type(config))) 364 365 extra_hooks = [] 366 if delay_secs is None: 367 task_id = self._estimator.config.task_id or 0 368 if self._delay_workers_by_global_step: 369 # Wait 5500 global steps for the second worker. Each worker waits more 370 # then previous one but with a diminishing number of steps. 371 extra_hooks.append( 372 basic_session_run_hooks.GlobalStepWaiterHook( 373 int(8000.0 * math.log(task_id + 1)))) 374 delay_secs = 0 375 else: 376 # Wait 5 secs more for each new worker up to 60 secs. 377 delay_secs = min(60, task_id * 5) 378 379 if delay_secs > 0: 380 elapsed_secs = time.time() - start 381 remaining = delay_secs - elapsed_secs 382 logging.info("Waiting %d secs before starting training.", remaining) 383 time.sleep(delay_secs) 384 385 return self._call_train( 386 input_fn=self._train_input_fn, 387 max_steps=self._train_steps, 388 hooks=self._train_monitors + extra_hooks, 389 saving_listeners=self._saving_listeners) 390 391 def evaluate(self, delay_secs=None, name=None): 392 """Evaluate on the evaluation data. 393 394 Runs evaluation on the evaluation data and returns the result. Runs for 395 `self._eval_steps` steps, or if it's `None`, then run until input is 396 exhausted or another exception is raised. Start the evaluation after 397 `delay_secs` seconds, or if it's `None`, defaults to using 398 `self._eval_delay_secs` seconds. 399 400 Args: 401 delay_secs: Start evaluating after this many seconds. If `None`, defaults 402 to using `self._eval_delays_secs`. 403 name: Gives the name to the evauation for the case multiple evaluation is 404 run for the same experiment. 405 406 Returns: 407 The result of the `evaluate` call to the `Estimator`. 408 """ 409 if delay_secs is None: 410 delay_secs = self._eval_delay_secs 411 412 if delay_secs: 413 logging.info("Waiting %d secs before starting eval.", delay_secs) 414 time.sleep(delay_secs) 415 416 return self._call_evaluate( 417 input_fn=self._eval_input_fn, 418 steps=self._eval_steps, 419 metrics=self._eval_metrics, 420 name=(name or "one_pass"), 421 hooks=self._eval_hooks) 422 423 @deprecated( 424 "2016-10-23", 425 "local_run will be renamed to train_and_evaluate and the new default " 426 "behavior will be to run evaluation every time there is a new " 427 "checkpoint.") 428 def local_run(self): 429 with _new_attr_context(self, "_min_eval_frequency"): 430 self._min_eval_frequency = self._local_eval_frequency 431 return self.train_and_evaluate() 432 433 # TODO(xiejw): Allow continuous_eval_predicate_fn to be passed via constructor 434 # once stopping all jobs is implemented. 435 def _continuous_eval(self, 436 input_fn, 437 name, 438 delay_secs, 439 throttle_delay_secs, 440 evaluate_checkpoint_only_once=True, 441 continuous_eval_predicate_fn=None, 442 export=True): 443 """Run continuous eval. 444 445 Runs infinite eval on the evaluation data set. This function starts 446 evaluating after `delay_secs` seconds and then runs no more than one 447 evaluation (with `self._eval_steps` steps each time) per 448 `throttle_delay_secs`. If `train_steps` is not None, will return after 449 global_step reaches `train_steps`. 450 451 Args: 452 input_fn: The input to use for this eval. 453 name: A string appended to the folder name of evaluation results. 454 delay_secs: Start evaluating after this many seconds. If None, defaults to 455 self._eval_delay_secs. 456 throttle_delay_secs: Do not re-evaluate unless the last evaluation was 457 started at least this many seconds ago. If None, defaults to 458 self._continuous_eval_throttle_secs. 459 evaluate_checkpoint_only_once: Whether to skip evaluation of checkpoints 460 that have already been evaluated. Default is `True`. 461 continuous_eval_predicate_fn: A predicate function determining whether to 462 continue eval after each iteration. A `predicate_fn` has one of the 463 following signatures: 464 * (eval_results) -> boolean 465 * (eval_results, checkpoint_path) -> boolean 466 Where `eval_results` is the dictionary of metric evaluations and 467 checkpoint_path is the path to the checkpoint containing the parameters 468 on which that evaluation was based. 469 At the beginning of evaluation, the passed `eval_results` will be None 470 so it's expected that the predicate function handles that gracefully. 471 Continuous eval behavior under different conditions: 472 * When `predicate_fn` is specified: 473 + if `train_steps` is None, run until `predicate_fn` returns False. 474 + if `train_steps` is specified, run until either global step 475 reaches `train_steps` or `predicate_fn` returns False. 476 * When `predicate_fn` is not specified: 477 + if `train_steps` is None, run in an infinite loop. 478 + if `train_steps` is specified, run until global step reaches 479 `train_steps`. 480 export: Whether to export from this step. Default is 'True'. 481 482 Raises: 483 ValueError: if `continuous_eval_predicate_fn` is neither None nor 484 callable. 485 """ 486 if continuous_eval_predicate_fn is not None: 487 if not callable(continuous_eval_predicate_fn): 488 raise ValueError( 489 "`continuous_eval_predicate_fn` must be a callable, or None.") 490 predicate_fn = _get_standardized_predicate_fn( 491 continuous_eval_predicate_fn) 492 else: 493 predicate_fn = None 494 495 if delay_secs is None: 496 delay_secs = self._eval_delay_secs 497 if throttle_delay_secs is None: 498 throttle_delay_secs = self._continuous_eval_throttle_secs 499 500 if delay_secs: 501 logging.info("Waiting %f secs before starting eval.", delay_secs) 502 time.sleep(delay_secs) 503 504 previous_path = None 505 eval_result = None 506 last_warning_time = 0 507 while (not predicate_fn or predicate_fn( 508 eval_result, checkpoint_path=previous_path)): 509 # Exit if we have already reached number of steps to train. 510 if self._has_training_stopped(eval_result): 511 logging.info("Exiting continuous eval, global_step=%s >= " 512 "train_step=%s", eval_result[ops.GraphKeys.GLOBAL_STEP], 513 self._train_steps) 514 return 515 516 start = time.time() 517 518 error_msg = None 519 latest_path = checkpoint_management.latest_checkpoint( 520 self._estimator.model_dir) 521 if not latest_path: 522 error_msg = ("Estimator is not fitted yet. " 523 "Will start an evaluation when a checkpoint is ready.") 524 elif evaluate_checkpoint_only_once and latest_path == previous_path: 525 error_msg = "No new checkpoint ready for evaluation." 526 527 if error_msg: 528 # Print warning message every 10 mins. 529 eval_result = {} 530 if time.time() - last_warning_time > 600: 531 logging.warning(error_msg) 532 last_warning_time = time.time() 533 else: 534 eval_result = self._call_evaluate( 535 input_fn=input_fn, 536 steps=self._eval_steps, 537 metrics=self._eval_metrics, 538 name=name, 539 checkpoint_path=latest_path, 540 hooks=self._eval_hooks) 541 # Ensure eval result is not None for next round of evaluation. 542 if not eval_result: 543 eval_result = {} 544 545 if export: 546 self._maybe_export(eval_result, checkpoint_path=latest_path) 547 548 # Clear warning timer and update last evaluated checkpoint 549 last_warning_time = 0 550 previous_path = latest_path 551 552 duration = time.time() - start 553 if duration < throttle_delay_secs: 554 difference = throttle_delay_secs - duration 555 logging.info("Waiting %f secs before starting next eval run.", 556 difference) 557 time.sleep(difference) 558 559 def _has_training_stopped(self, eval_result): 560 """Determines whether the training has stopped.""" 561 if not eval_result: 562 return False 563 564 global_step = eval_result.get(ops.GraphKeys.GLOBAL_STEP) 565 return global_step and self._train_steps and (global_step >= 566 self._train_steps) 567 568 def continuous_eval(self, 569 delay_secs=None, 570 throttle_delay_secs=None, 571 evaluate_checkpoint_only_once=True, 572 continuous_eval_predicate_fn=None, 573 name="continuous"): 574 self._continuous_eval( 575 self._eval_input_fn, 576 name=name, 577 delay_secs=delay_secs, 578 throttle_delay_secs=throttle_delay_secs, 579 evaluate_checkpoint_only_once=evaluate_checkpoint_only_once, 580 continuous_eval_predicate_fn=continuous_eval_predicate_fn) 581 582 def continuous_eval_on_train_data(self, 583 delay_secs=None, 584 throttle_delay_secs=None, 585 continuous_eval_predicate_fn=None, 586 name="continuous_on_train_data"): 587 self._continuous_eval( 588 self._train_input_fn, 589 name=name, 590 delay_secs=delay_secs, 591 throttle_delay_secs=throttle_delay_secs, 592 continuous_eval_predicate_fn=continuous_eval_predicate_fn, 593 export=False) 594 595 def train_and_evaluate(self): 596 """Interleaves training and evaluation. 597 598 The frequency of evaluation is controlled by the constructor arg 599 `min_eval_frequency`. When this parameter is 0, evaluation happens 600 only after training has completed. Note that evaluation cannot happen 601 more frequently than checkpoints are taken. If no new snapshots are 602 available when evaluation is supposed to occur, then evaluation doesn't 603 happen for another `min_eval_frequency` steps (assuming a checkpoint is 604 available at that point). Thus, settings `min_eval_frequency` to 1 means 605 that the model will be evaluated everytime there is a new checkpoint. 606 607 This is particular useful for a "Master" task in the cloud, whose 608 responsibility it is to take checkpoints, evaluate those checkpoints, 609 and write out summaries. Participating in training as the supervisor 610 allows such a task to accomplish the first and last items, while 611 performing evaluation allows for the second. 612 613 Returns: 614 The result of the `evaluate` call to the `Estimator` as well as the 615 export results using the specified `ExportStrategy`. 616 """ 617 # The directory to which evaluation summaries are written are determined 618 # by adding a suffix to 'eval'; that suffix is the 'name' parameter to 619 # the various evaluate(...) methods. By setting it to None, we force 620 # the directory name to simply be 'eval'. 621 eval_dir_suffix = None 622 623 # We set every_n_steps to 1, but evaluation only occurs when a new 624 # snapshot is available. If, by the time we finish evaluation 625 # there is a new snapshot, then we just evaluate again. Otherwise, 626 # we keep training until one becomes available. 627 with _new_attr_context(self, "_train_monitors"): 628 self._train_monitors = self._train_monitors or [] 629 config = self._estimator.config 630 intermediate_export = self._checkpoint_and_export and ( 631 config.save_checkpoints_secs or config.save_checkpoints_steps) 632 if intermediate_export: 633 # Create a partially specified evaluate function with the desired 634 # arguments. This will be executed by the _EvalAndExportListener, 635 # which will specify the latest checkpoint path. 636 eval_fn = functools.partial( 637 self._call_evaluate, 638 input_fn=self._eval_input_fn, 639 steps=self._eval_steps, 640 metrics=self._eval_metrics, 641 hooks=self._eval_hooks) 642 643 export_listener = _EvalAndExportListener( 644 eval_fn=eval_fn, 645 export_fn=self._maybe_export, 646 model_dir=self._estimator.model_dir) 647 648 saver_hook = basic_session_run_hooks.CheckpointSaverHook( 649 checkpoint_dir=self._estimator.model_dir, 650 save_secs=config.save_checkpoints_secs, 651 save_steps=config.save_checkpoints_steps, 652 listeners=[export_listener]) 653 self._train_monitors += [saver_hook] 654 else: 655 if self._min_eval_frequency: 656 # Using low min_eval_frequency (default is 1) on a non-cached file 657 # system requires a lot of overhead to read the checkpoint state file. 658 # This is particular bad on GCS and CNS. See also b/36498507 for 659 # context. `check_interval_secs = 5` avoids polling a remote 660 # fileystem too often. 661 662 self._train_monitors += [ 663 monitors.ValidationMonitor( 664 input_fn=self._eval_input_fn, 665 eval_steps=self._eval_steps, 666 metrics=self._eval_metrics, 667 every_n_steps=self._min_eval_frequency, 668 check_interval_secs=self._check_interval_secs, 669 name=eval_dir_suffix, 670 hooks=self._eval_hooks) 671 ] 672 self.train(delay_secs=0) 673 674 # If the checkpoint_and_export flag and appropriate estimator configuration 675 # parameters are set, then model evaluations and exports are done during the 676 # training process. In particular, this will always occur at the end of 677 # training, so we return the most recent results to avoid performing a 678 # duplicate evaluation and model export. 679 if intermediate_export: 680 return export_listener.eval_result, export_listener.export_results 681 else: 682 eval_result = self._call_evaluate( 683 input_fn=self._eval_input_fn, 684 steps=self._eval_steps, 685 metrics=self._eval_metrics, 686 name=eval_dir_suffix, 687 hooks=self._eval_hooks) 688 export_results = self._maybe_export(eval_result) 689 return eval_result, export_results 690 691 @experimental 692 def continuous_train_and_eval(self, continuous_eval_predicate_fn=None): 693 """Interleaves training and evaluation. 694 695 The frequency of evaluation is controlled by the `train_steps_per_iteration` 696 (via constructor). The model will be first trained for 697 `train_steps_per_iteration`, and then be evaluated in turns. 698 699 This method is intended for single machine usage. 700 701 This differs from `train_and_evaluate` as follows: 702 703 1. The procedure will have train and evaluation in turns. The model 704 will be trained for a number of steps (usually smaller than `train_steps` 705 if provided) and then be evaluated. `train_and_evaluate` will train the 706 model for `train_steps` (no small training iterations). 707 708 2. Due to the different approach this schedule takes, it leads to two 709 differences in resource control. First, the resources (e.g., memory) used 710 by training will be released before evaluation (`train_and_evaluate` takes 711 double resources). Second, more checkpoints will be saved as a checkpoint 712 is generated at the end of each training iteration. 713 714 3. As the estimator.train starts from scratch (new graph, new states for 715 input, etc) at each iteration, it is recommended to have the 716 `train_steps_per_iteration` larger. It is also recommended to shuffle your 717 input. 718 719 Args: 720 continuous_eval_predicate_fn: A predicate function determining whether to 721 continue eval after each iteration. A `predicate_fn` has one of the 722 following signatures: 723 * (eval_results) -> boolean 724 * (eval_results, checkpoint_path) -> boolean 725 Where `eval_results` is the dictionary of metric evaluations and 726 checkpoint_path is the path to the checkpoint containing the parameters 727 on which that evaluation was based. 728 At the beginning of evaluation, the passed `eval_results` and 729 `checkpoint_path` will be None so it's expected that the predicate 730 function handles that gracefully. 731 When `predicate_fn` is not specified, continuous eval will run in an 732 infinite loop (if `train_steps` is None). or exit once global step 733 reaches `train_steps`. 734 735 Returns: 736 A tuple of the result of the `evaluate` call to the `Estimator` and the 737 export results using the specified `ExportStrategy`. 738 739 Raises: 740 ValueError: if `continuous_eval_predicate_fn` is neither None nor 741 callable. 742 """ 743 744 if continuous_eval_predicate_fn is not None: 745 if not callable(continuous_eval_predicate_fn): 746 raise ValueError( 747 "`continuous_eval_predicate_fn` must be a callable, or None.") 748 predicate_fn = _get_standardized_predicate_fn( 749 continuous_eval_predicate_fn) 750 else: 751 predicate_fn = None 752 753 export_results = None 754 latest_checkpoint = None 755 eval_result = None 756 757 # Set the default value for train_steps_per_iteration, which will be 758 # overridden by other settings. 759 train_steps_per_iteration = 1000 760 if self._train_steps_per_iteration is not None: 761 train_steps_per_iteration = self._train_steps_per_iteration 762 elif self._train_steps is not None: 763 train_steps_per_iteration = int(self._train_steps / 10) 764 765 while (not predicate_fn or predicate_fn( 766 eval_result, checkpoint_path=latest_checkpoint 767 if eval_result else None)): 768 769 if self._has_training_stopped(eval_result): 770 # Exits once max steps of training is satisfied. 771 logging.info("Stop training model as max steps reached") 772 break 773 774 logging.info("Training model for %s steps", train_steps_per_iteration) 775 self._call_train( 776 input_fn=self._train_input_fn, 777 steps=train_steps_per_iteration, 778 hooks=self._train_monitors, 779 saving_listeners=self._saving_listeners) 780 781 logging.info("Evaluating model now.") 782 latest_checkpoint = checkpoint_management.latest_checkpoint( 783 self._estimator.model_dir) 784 eval_result = self._call_evaluate( 785 input_fn=self._eval_input_fn, 786 steps=self._eval_steps, 787 metrics=self._eval_metrics, 788 name="one_pass", 789 checkpoint_path=latest_checkpoint, 790 hooks=self._eval_hooks) 791 export_results = self._maybe_export(eval_result) 792 793 return eval_result, export_results 794 795 def _maybe_export(self, eval_result, checkpoint_path=None): 796 """Export the Estimator using export_fn, if defined.""" 797 export_dir_base = os.path.join( 798 compat.as_bytes(self._estimator.model_dir), compat.as_bytes("export")) 799 800 export_results = [] 801 for strategy in self._export_strategies: 802 export_results.append( 803 strategy.export( 804 self._estimator, 805 os.path.join( 806 compat.as_bytes(export_dir_base), 807 compat.as_bytes(strategy.name)), 808 checkpoint_path=checkpoint_path, 809 eval_result=eval_result)) 810 811 return export_results 812 813 def run_std_server(self): 814 """Starts a TensorFlow server and joins the serving thread. 815 816 Typically used for parameter servers. 817 818 Raises: 819 ValueError: if not enough information is available in the estimator's 820 config to create a server. 821 """ 822 self._start_server().join() 823 824 def test(self): 825 """Tests training, evaluating and exporting the estimator for a single step. 826 827 Returns: 828 The result of the `evaluate` call to the `Estimator`. 829 """ 830 self._call_train( 831 input_fn=self._train_input_fn, 832 steps=1, 833 hooks=self._train_monitors, 834 saving_listeners=self._saving_listeners) 835 836 eval_result = self._call_evaluate( 837 input_fn=self._eval_input_fn, 838 steps=1, 839 metrics=self._eval_metrics, 840 name="one_pass") 841 _ = self._maybe_export(eval_result) 842 843 return eval_result 844 845 def _start_server(self): 846 """Creates, starts, and returns a server_lib.Server.""" 847 config = self._estimator.config 848 if (not config.cluster_spec or not config.task_type or not config.master or 849 config.task_id is None): 850 raise ValueError("Could not start server; be sure to specify " 851 "cluster_spec, task_type, master, and task in " 852 "RunConfig or set the TF_CONFIG environment variable.") 853 server = server_lib.Server( 854 config.cluster_spec, 855 job_name=config.task_type, 856 task_index=config.task_id, 857 config=config.tf_config, 858 start=False) 859 server.start() 860 return server 861 862 def _call_train( 863 self, 864 _sentinel=None, # pylint: disable=invalid-name, 865 input_fn=None, 866 steps=None, 867 hooks=None, 868 max_steps=None, 869 saving_listeners=None): 870 if _sentinel is not None: 871 raise ValueError("_call_train should be called with keyword args only") 872 873 # Estimator in core cannot work with monitors. We need to convert them 874 # to hooks. For Estimator in contrib, it is converted internally. So, it is 875 # safe to convert for both cases. 876 hooks = monitors.replace_monitors_with_hooks(hooks, self._estimator) 877 if self._core_estimator_used: 878 return self._estimator.train( 879 input_fn=input_fn, 880 steps=steps, 881 max_steps=max_steps, 882 hooks=hooks, 883 saving_listeners=saving_listeners) 884 else: 885 return self._estimator.fit( 886 input_fn=input_fn, steps=steps, max_steps=max_steps, monitors=hooks) 887 888 def _call_evaluate( 889 self, 890 _sentinel=None, # pylint: disable=invalid-name, 891 input_fn=None, 892 steps=None, 893 metrics=None, 894 name=None, 895 checkpoint_path=None, 896 hooks=None): 897 if _sentinel is not None: 898 raise ValueError("_call_evaluate should be called with keyword args only") 899 900 if self._core_estimator_used: 901 if metrics is not None: 902 raise ValueError( 903 "`eval_metrics` must be `None` with `tf.estimator.Estimator`") 904 return self._estimator.evaluate( 905 input_fn=input_fn, 906 steps=steps, 907 name=name, 908 checkpoint_path=checkpoint_path, 909 hooks=hooks) 910 else: 911 return self._estimator.evaluate( 912 input_fn=input_fn, 913 steps=steps, 914 metrics=metrics, 915 name=name, 916 checkpoint_path=checkpoint_path, 917 hooks=hooks) 918 919 920@contextlib.contextmanager 921def _new_attr_context(obj, attr): 922 """Creates a new context in which an object's attribute can be changed. 923 924 This creates a context in which an object's attribute can be changed. 925 Once the context is exited, the attribute reverts to its original value. 926 927 Args: 928 obj: An object whose attribute to restore at the end of the context. 929 attr: An attribute to remember and restore at the end of the context. 930 931 Yields: 932 Context. 933 934 Example: 935 my_obj.x = 1 936 with _new_attr_context(my_obj, "x"): 937 my_obj.x = 2 938 print(my_obj.x) 939 print(my_obj.x) 940 """ 941 saved = getattr(obj, attr) 942 try: 943 yield 944 finally: 945 setattr(obj, attr, saved) 946