# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Monitors instrument the training process (deprecated).

This module and all its submodules are deprecated. See
[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
for migration instructions.

@@get_default_monitors
@@BaseMonitor
@@CaptureVariable
@@CheckpointSaver
@@EveryN
@@ExportMonitor
@@GraphDump
@@LoggingTrainable
@@NanLoss
@@PrintTensor
@@StepCounter
@@StopAtStep
@@SummarySaver
@@ValidationMonitor
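
Example (an illustrative sketch; `estimator`, `train_input_fn` and
`eval_input_fn` are assumed to be defined elsewhere, with `estimator` a
contrib `Estimator` whose `fit` accepts a `monitors` argument):

  validation_monitor = ValidationMonitor(input_fn=eval_input_fn,
                                         every_n_steps=500)
  estimator.fit(input_fn=train_input_fn,
                steps=10000,
                monitors=[validation_monitor])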
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import os
import time

import numpy as np
import six

from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary as core_summary
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.util import deprecation
from tensorflow.python.util import tf_inspect


# TODO(ptucker): Split each monitor class into a separate file.
# TODO(ptucker): Fail if epoch or step does not monotonically increase?
class BaseMonitor(object):
  """Base class for Monitors.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  Defines basic interfaces of Monitors.
  Monitors can either be run on all workers or, more commonly, restricted
  to run exclusively on the elected chief worker.
  """

  @deprecation.deprecated(
      "2016-12-05",
      "Monitors are deprecated. Please use tf.train.SessionRunHook.")
  def __init__(self):
    self._begun = False
    self._current_epoch = None
    self._current_step = None
    self._max_steps = None
    self._estimator = None

  @property
  def run_on_all_workers(self):
    return False

  def set_estimator(self, estimator):
    """A setter called automatically by the target estimator.

    If the estimator is locked, this method does nothing.

    Args:
      estimator: the estimator that this monitor monitors.

    Raises:
      ValueError: if the estimator is None.
    """
    if estimator is None:
      raise ValueError("Missing estimator.")
    # TODO(mdan): This should fail if called twice with the same estimator.
    self._estimator = estimator

  def begin(self, max_steps=None):
    """Called at the beginning of training.

    When called, the default graph is the one we are executing.

    Args:
      max_steps: `int`, the maximum global step this training will run until.

    Raises:
      ValueError: if we've already begun a run.
    """
    if self._begun:
      raise ValueError("begin called twice without end.")
    self._max_steps = max_steps
    self._begun = True

  def end(self, session=None):
    """Callback at the end of training/evaluation.

    Args:
      session: A `tf.Session` object that can be used to run ops.

    Raises:
      ValueError: if we've not begun a run.
    """
    _ = session
    if not self._begun:
      raise ValueError("end called without begin.")
    self._max_steps = None
    self._begun = False

  def epoch_begin(self, epoch):
    """Begin epoch.

    Args:
      epoch: `int`, the epoch number.

    Raises:
      ValueError: if we've already begun an epoch, or `epoch` < 0.
    """
    if self._current_epoch is not None:
      raise ValueError("epoch_begin called twice without epoch_end.")
    if epoch < 0:
      raise ValueError("Invalid epoch %s." % epoch)
    self._current_epoch = epoch

  def epoch_end(self, epoch):
    """End epoch.

    Args:
      epoch: `int`, the epoch number.

    Raises:
      ValueError: if we've not begun an epoch, or `epoch` number does not match.
    """
    if self._current_epoch != epoch:
      raise ValueError("epoch_end expected %s but got %s." %
                       (self._current_epoch, epoch))
    self._current_epoch = None

  def step_begin(self, step):
    """Callback before training step begins.

    You may use this callback to request evaluation of additional tensors
    in the graph.

    Args:
      step: `int`, the current value of the global step.

    Returns:
      List of `Tensor` objects or string tensor names to be run.

    Raises:
      ValueError: if we've already begun a step, or `step` < 0, or
          `step` > `max_steps`.
    """
    if (step < 0) or ((self._max_steps is not None) and
                      (step > self._max_steps)):
      raise ValueError("Invalid step %s." % step)
    self._current_step = step
    return []

  def step_end(self, step, output):  # pylint: disable=unused-argument
    """Callback after training step finished.

    This callback provides access to the tensors/ops evaluated at this step,
    including the additional tensors for which evaluation was requested in
    `step_begin`.

    In addition, the callback has the opportunity to stop training by returning
    `True`. This is useful for early stopping, for example.

    Note that this method is not called if the call to `Session.run()` that
    followed the last call to `step_begin()` failed.

    Args:
      step: `int`, the current value of the global step.
      output: `dict` mapping `string` values representing tensor names to
          the values that resulted from running these tensors. Values may be
          either scalars, for scalar tensors, or Numpy `array`, for non-scalar
          tensors.

    Returns:
      `bool`. True if training should stop.

    Raises:
      ValueError: if we've not begun a step, or `step` number does not match.
    """
    if self._current_step != step:
      raise ValueError("step_end expected %s but got %s." %
                       (self._current_step, step))
    self._current_step = None
    return False

  def post_step(self, step, session):  # pylint: disable=unused-argument
    """Callback after the step is finished.

    Called after step_end and receives a session to perform extra session.run
    calls. If a failure occurred during the step, this callback is still
    called.

    Args:
      step: `int`, global step of the model.
      session: `Session` object.
    """
    _ = step, session


def _extract_output(outputs, request):
  if request in outputs:
    return outputs[request]
  return outputs[request.name]


class EveryN(BaseMonitor):
  """Base class for monitors that execute callbacks every N steps.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  This class adds three new callbacks:
    - every_n_step_begin
    - every_n_step_end
    - every_n_post_step

  The callbacks are executed every n steps, or optionally every step for the
  first m steps, where m and n can both be user-specified.

  When extending this class, note that if you wish to use any of the
  `BaseMonitor` callbacks, you must call their respective super implementation:

    def step_begin(self, step):
      super(ExampleMonitor, self).step_begin(step)
      return []

  Failing to call the super implementation will cause unpredictable behavior.

  The `every_n_post_step()` callback is also called after the last step if it
  was not already called through the regular conditions. Note that
  `every_n_step_begin()` and `every_n_step_end()` do not receive that special
  treatment.
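
  For illustration, a complete subclass that logs a hypothetical "loss:0"
  tensor every 100 steps might look like this (a sketch; the tensor name is
  an assumption and depends on the model):

    class LossLogger(EveryN):

      def __init__(self):
        super(LossLogger, self).__init__(every_n_steps=100)

      def every_n_step_begin(self, step):
        super(LossLogger, self).every_n_step_begin(step)
        return ["loss:0"]  # Ask the session to also fetch this tensor.

      def every_n_step_end(self, step, outputs):
        super(LossLogger, self).every_n_step_end(step, outputs)
        logging.info("Step %d: loss = %s", step, outputs["loss:0"])
        return False  # Do not request early stopping.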

  """

  # TODO(ipolosukhin): Add also every n seconds.

  def __init__(self, every_n_steps=100, first_n_steps=1):
    """Initializes an `EveryN` monitor.

    Args:
      every_n_steps: `int`, the number of steps to allow between callbacks.
      first_n_steps: `int`, specifying the number of initial steps during
          which the callbacks will always be executed, regardless of the value
          of `every_n_steps`. Note that this value is relative to the global
          step.
    """
    super(EveryN, self).__init__()
    self._every_n_steps = every_n_steps
    self._first_n_steps = first_n_steps
    # Last step in the model.
    self._last_successful_step = None
    # Last step at which we called one of the every_n methods.
    self._last_active_step = 0
    self._every_n_step_begin_called = False

  def every_n_step_begin(self, step):  # pylint: disable=unused-argument
    """Callback before every n'th step begins.

    Args:
      step: `int`, the current value of the global step.

    Returns:
      A `list` of tensors that will be evaluated at this step.
    """
    return []

  def every_n_step_end(self, step, outputs):  # pylint: disable=unused-argument
    """Callback after every n'th step finished.

    This callback provides access to the tensors/ops evaluated at this step,
    including the additional tensors for which evaluation was requested in
    `step_begin`.

    In addition, the callback has the opportunity to stop training by returning
    `True`. This is useful for early stopping, for example.

    Args:
      step: `int`, the current value of the global step.
      outputs: `dict` mapping `string` values representing tensor names to
          the values that resulted from running these tensors. Values may be
          either scalars, for scalar tensors, or Numpy `array`, for non-scalar
          tensors.

    Returns:
      `bool`. True if training should stop.
    """
    return False

  def every_n_post_step(self, step, session):
    """Callback after a step is finished or `end()` is called.

    Args:
      step: `int`, the current value of the global step.
      session: `Session` object.
    """
    pass

  def step_begin(self, step):
    """Overrides `BaseMonitor.step_begin`.

    When overriding this method, you must call the super implementation.

    Args:
      step: `int`, the current value of the global step.

    Returns:
      A `list`, the result of every_n_step_begin, if that was called this step,
      or an empty list otherwise.

    Raises:
      ValueError: if called more than once during a step.
    """
    super(EveryN, self).step_begin(step)
    if (step <= self._first_n_steps or
        step >= (self._every_n_steps + self._last_active_step) or
        step == self._max_steps):  # Note: max_steps can be None here.
      self._every_n_step_begin_called = True
      return self.every_n_step_begin(step)
    self._every_n_step_begin_called = False
    return []

  def step_end(self, step, output):
    """Overrides `BaseMonitor.step_end`.

    When overriding this method, you must call the super implementation.

    Args:
      step: `int`, the current value of the global step.
      output: `dict` mapping `string` values representing tensor names to
          the values that resulted from running these tensors. Values may be
          either scalars, for scalar tensors, or Numpy `array`, for non-scalar
          tensors.

    Returns:
      `bool`, the result of every_n_step_end, if that was called this step,
      or `False` otherwise.
    """
    super(EveryN, self).step_end(step, output)
    if self._every_n_step_begin_called:
      return self.every_n_step_end(step, output)
    return False

  def post_step(self, step, session):
    super(EveryN, self).post_step(step, session)
    if self._every_n_step_begin_called:
      self.every_n_post_step(step, session)
      self._last_active_step = step
    self._last_successful_step = step

  def end(self, session=None):
    super(EveryN, self).end(session=session)
    if self._last_successful_step != self._last_active_step:
      self.every_n_post_step(self._last_successful_step, session)


class StopAtStep(BaseMonitor):
  """Monitor to request stop at a specified step."""

  def __init__(self, num_steps=None, last_step=None):
    """Create a StopAtStep monitor.

    This monitor requests stop after either a number of steps have been
    executed or a last step has been reached. Only one of the two options
    can be specified.

    If `num_steps` is specified, it indicates the number of steps to execute
    after `begin()` is called. If instead `last_step` is specified, it
    indicates the last step we want to execute, as passed to the
    `step_begin()` call.

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.

    Raises:
      ValueError: If one of the arguments is invalid.
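
    For example (an illustrative sketch; `estimator` and `train_input_fn`
    are assumed to be defined elsewhere):

      # Stop after 1000 more steps have been executed:
      monitor = StopAtStep(num_steps=1000)
      # Or stop once the global step reaches 10000:
      # monitor = StopAtStep(last_step=10000)
      estimator.fit(input_fn=train_input_fn, monitors=[monitor])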
    """
    super(StopAtStep, self).__init__()
    if num_steps is None and last_step is None:
      raise ValueError("One of num_steps or last_step must be specified.")
    if num_steps is not None and last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    self._num_steps = num_steps
    self._last_step = last_step

  @property
  def run_on_all_workers(self):
    return True

  def step_begin(self, step):
    super(StopAtStep, self).step_begin(step)
    if self._last_step is None:
      self._last_step = step + self._num_steps - 1
    return []

  def step_end(self, step, output):
    super(StopAtStep, self).step_end(step, output)
    return step >= self._last_step


# TODO(ptucker): Rename to LoggingTensor since it's not writing to stdout.
class PrintTensor(EveryN):
  """Prints given tensors every N steps.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  This is an `EveryN` monitor and has consistent semantics for `every_n`
  and `first_n`.

  The tensors will be printed to the log, with `INFO` severity.
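
  For example, to log a tensor under the tag "loss" every 50 steps (a sketch;
  the tensor name "loss:0" is an assumption that depends on the model):

    print_monitor = PrintTensor({"loss": "loss:0"}, every_n=50)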
  """

  def __init__(self, tensor_names, every_n=100, first_n=1):
    """Initializes a PrintTensor monitor.

    Args:
      tensor_names: `dict` of tag to tensor names or
          `iterable` of tensor names (strings).
      every_n: `int`, print every N steps. See `EveryN`.
      first_n: `int`, also print the first N steps. See `EveryN`.
    """
    super(PrintTensor, self).__init__(every_n, first_n)
    if not isinstance(tensor_names, dict):
      tensor_names = {item: item for item in tensor_names}
    self._tensor_names = tensor_names

  def every_n_step_begin(self, step):
    super(PrintTensor, self).every_n_step_begin(step)
    return list(self._tensor_names.values())

  def every_n_step_end(self, step, outputs):
    super(PrintTensor, self).every_n_step_end(step, outputs)
    stats = []
    for tag, tensor_name in six.iteritems(self._tensor_names):
      if tensor_name in outputs:
        stats.append("%s = %s" % (tag,
                                  str(_extract_output(outputs, tensor_name))))
    logging.info("Step %d: %s", step, ", ".join(stats))


class LoggingTrainable(EveryN):
  """Writes trainable variable values into the log every N steps.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  Writes the values of the tensors in trainable variables every `every_n`
  steps, and also for the first `first_n` steps.
  """

  def __init__(self, scope=None, every_n=100, first_n=1):
    """Initializes LoggingTrainable monitor.

    Args:
      scope: An optional string to match variable names using re.match.
      every_n: Print every N steps.
      first_n: Print first N steps.
    """
    super(LoggingTrainable, self).__init__(every_n, first_n)
    self._scope = scope

  def every_n_step_begin(self, step):
    super(LoggingTrainable, self).every_n_step_begin(step)
    # Get a list of trainable variables at the beginning of every N steps.
    # We cannot get this in __init__ because train_op has not been generated.
    trainables = ops.get_collection(
        ops.GraphKeys.TRAINABLE_VARIABLES, scope=self._scope)
    self._names = {}
    for var in trainables:
      self._names[var.name] = var.value().name
    return list(self._names.values())

  def every_n_step_end(self, step, outputs):
    super(LoggingTrainable, self).every_n_step_end(step, outputs)
    stats = []
    for tag, tensor_name in six.iteritems(self._names):
      if tensor_name in outputs:
        stats.append("%s = %s" % (tag,
                                  str(_extract_output(outputs, tensor_name))))
    logging.info("Logging Trainable: Step %d: %s", step, ", ".join(stats))


class SummarySaver(EveryN):
  """Saves summaries every N steps.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.
  """

  def __init__(self,
               summary_op,
               save_steps=100,
               output_dir=None,
               summary_writer=None,
               scaffold=None):
    """Initializes a `SummarySaver` monitor.

    Args:
      summary_op: `Tensor` of type `string`. A serialized `Summary` protocol
          buffer, as output by TF summary methods like `summary.scalar` or
          `summary.merge_all`.
      save_steps: `int`, save summaries every N steps. See `EveryN`.
      output_dir: `string`, the directory to save the summaries to. Only used
          if no `summary_writer` is supplied.
      summary_writer: `SummaryWriter`. If `None` and an `output_dir` was
          passed, one will be created accordingly.
      scaffold: `Scaffold` to get summary_op if it's not provided.
    """
    # TODO(ipolosukhin): Implement every N seconds.
    super(SummarySaver, self).__init__(every_n_steps=save_steps)
    self._summary_op = summary_op
    self._summary_writer = summary_writer
    if summary_writer is None and output_dir:
      self._summary_writer = core_summary.FileWriter(output_dir)
    self._scaffold = scaffold
    # TODO(mdan): Throw an error if output_dir and summary_writer are None.

  def set_estimator(self, estimator):
    super(SummarySaver, self).set_estimator(estimator)
    # TODO(mdan): This line looks redundant.
    if self._summary_writer is None:
      self._summary_writer = core_summary.FileWriter(estimator.model_dir)

  def every_n_step_begin(self, step):
    super(SummarySaver, self).every_n_step_begin(step)
    if self._summary_op is None and self._scaffold is not None:
      self._summary_op = self._scaffold.summary_op
    if self._summary_op is not None:
      return [self._summary_op]
    return []

  def every_n_step_end(self, step, outputs):
    super(SummarySaver, self).every_n_step_end(step, outputs)
    if self._summary_op is not None:
      summary_strs = _extract_output(outputs, self._summary_op)
      if self._summary_writer:
        self._summary_writer.add_summary(summary_strs, step)
    return False

  def end(self, session=None):
    super(SummarySaver, self).end(session=session)
    if self._summary_writer:
      self._summary_writer.flush()


class ValidationMonitor(EveryN):
  """Runs evaluation of a given estimator, at most every N steps.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  Note that the evaluation is done based on the saved checkpoint, which will
  usually be older than the current step.

  Can do early stopping on validation metrics if `early_stopping_rounds` is
  provided.
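
  For example, to evaluate every 500 steps and stop early once the eval loss
  has not improved for 2000 steps (a sketch; `estimator`, `train_input_fn`
  and `eval_input_fn` are assumed to be defined elsewhere):

    validation_monitor = ValidationMonitor(
        input_fn=eval_input_fn,
        every_n_steps=500,
        early_stopping_metric="loss",
        early_stopping_rounds=2000)
    estimator.fit(input_fn=train_input_fn,
                  steps=100000,
                  monitors=[validation_monitor])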
  """

  def __init__(self,
               x=None,
               y=None,
               input_fn=None,
               batch_size=None,
               eval_steps=None,
               every_n_steps=100,
               metrics=None,
               hooks=None,
               early_stopping_rounds=None,
               early_stopping_metric="loss",
               early_stopping_metric_minimize=True,
               name=None,
               check_interval_secs=5):
    """Initializes a ValidationMonitor.

    Args:
      x: See `BaseEstimator.evaluate`.
      y: See `BaseEstimator.evaluate`.
      input_fn: See `BaseEstimator.evaluate`.
      batch_size: See `BaseEstimator.evaluate`.
      eval_steps: See `BaseEstimator.evaluate`.
      every_n_steps: Check for new checkpoints to evaluate every N steps. If a
          new checkpoint is found, it is evaluated. See `EveryN`.
      metrics: See `BaseEstimator.evaluate`.
      hooks: A list of `SessionRunHook` hooks to pass to the
          `Estimator`'s `evaluate` function.
      early_stopping_rounds: `int`. If the metric indicated by
          `early_stopping_metric` does not change according to
          `early_stopping_metric_minimize` for this many steps, then training
          will be stopped.
      early_stopping_metric: `string`, name of the metric to check for early
          stopping.
      early_stopping_metric_minimize: `bool`, True if `early_stopping_metric`
          is expected to decrease (thus early stopping occurs when this metric
          stops decreasing), False if `early_stopping_metric` is expected to
          increase. Typically, `early_stopping_metric_minimize` is True for
          loss metrics like mean squared error, and False for performance
          metrics like accuracy.
      name: See `BaseEstimator.evaluate`.
      check_interval_secs: Only check for a new checkpoint if at least
          `check_interval_secs` have passed. Ignored if None. Default is 5
          secs.

    Raises:
      ValueError: If neither x nor input_fn is provided.
    """
    super(ValidationMonitor, self).__init__(
        every_n_steps=every_n_steps, first_n_steps=-1)
    # TODO(mdan): Checks like this are already done by evaluate.
    if x is None and input_fn is None:
      raise ValueError("Either x or input_fn should be provided.")
    self.x = x
    self.y = y
    self.input_fn = input_fn
    self.batch_size = batch_size
    self.eval_steps = eval_steps
    self.metrics = metrics
    self.hooks = hooks
    self.early_stopping_rounds = early_stopping_rounds
    self.early_stopping_metric = early_stopping_metric
    self.early_stopping_metric_minimize = early_stopping_metric_minimize
    self.name = name
    self._best_value_step = None
    self._best_value = None
    self._best_metrics = None
    self._early_stopped = False
    self._latest_path = None
    self._latest_path_step = None
    self._last_checkpoint_check_time = None
    self._check_interval_secs = check_interval_secs

  @property
  def early_stopped(self):
    """Returns True if this monitor caused an early stop."""
    return self._early_stopped

  @property
  def best_step(self):
    """Returns the step at which the best early stopping metric was found."""
    return self._best_value_step

  @property
  def best_value(self):
    """Returns the best early stopping metric value found so far."""
    return self._best_value

  @property
  def best_metrics(self):
    """Returns all eval metrics computed with the best early stopping metric.

    For instance, if the metrics computed in two successive evals are
      1. {'loss':40, 'auc':0.5}
      2. {'loss':50, 'auc':0.6}
    this function would return the first dict {'loss':40, 'auc':0.5} after
    both the first and second eval (if `early_stopping_metric` is 'loss' and
    `early_stopping_metric_minimize` is True).

    Returns:
      The output dict of estimator.evaluate which contains the best value of
      the early stopping metric seen so far.
    """
    return self._best_metrics

  def _evaluate_estimator(self):
    if isinstance(self._estimator, core_estimator.Estimator):
      if any((x is not None
              for x in [self.x, self.y, self.batch_size, self.metrics])):
        raise ValueError(
            "tf.estimator.Estimator does not support the following "
            "arguments: x, y, batch_size, metrics. They should be set to "
            "`None` in ValidationMonitor.")
      return self._estimator.evaluate(
          input_fn=self.input_fn,
          steps=self.eval_steps,
          hooks=self.hooks,
          name=self.name)
    else:
      return self._estimator.evaluate(
          x=self.x,
          y=self.y,
          input_fn=self.input_fn,
          batch_size=self.batch_size,
          steps=self.eval_steps,
          metrics=self.metrics,
          hooks=self.hooks,
          name=self.name)

  def every_n_step_end(self, step, outputs):
    super(ValidationMonitor, self).every_n_step_end(step, outputs)
    # TODO(mdan): The use of step below is probably misleading.
    # The code should probably use the step from the checkpoint, because
    # that's what is being evaluated.
    if self._estimator is None:
      raise ValueError("Missing call to set_estimator.")
    current_time = time.time()
    if (self._check_interval_secs is not None and
        self._last_checkpoint_check_time is not None and
        current_time - self._last_checkpoint_check_time <=
        self._check_interval_secs):
      logging.debug(
          "Skipping evaluation since less than %d seconds have passed since "
          "last check for a new checkpoint.", self._check_interval_secs)
      return False
    self._last_checkpoint_check_time = current_time
    # Check that we are not running evaluation on the same checkpoint.
    latest_path = checkpoint_management.latest_checkpoint(
        self._estimator.model_dir)
    if latest_path is None:
      logging.debug("Skipping evaluation since model has not been saved yet "
                    "at step %d.", step)
      return False
    if latest_path is not None and latest_path == self._latest_path:
      logging.debug("Skipping evaluation due to same checkpoint %s for step %d "
                    "as for step %d.", latest_path, step,
                    self._latest_path_step)
      return False
    self._latest_path = latest_path
    self._latest_path_step = step

    # Run evaluation and log it.
    validation_outputs = self._evaluate_estimator()
    stats = []
    for name in validation_outputs:
      stats.append("%s = %s" % (name, str(validation_outputs[name])))
    logging.info("Validation (step %d): %s", step, ", ".join(stats))

    # Early stopping logic.
    if self.early_stopping_rounds is not None:
      if self.early_stopping_metric not in validation_outputs:
        raise ValueError("Metric %s missing from outputs %s." %
                         (self.early_stopping_metric,
                          set(validation_outputs.keys())))
      current_value = validation_outputs[self.early_stopping_metric]
      if (self._best_value is None or (self.early_stopping_metric_minimize and
                                       (current_value < self._best_value)) or
          (not self.early_stopping_metric_minimize and
           (current_value > self._best_value))):
        self._best_value = current_value
        self._best_metrics = copy.deepcopy(validation_outputs)
        self._best_value_step = step
      stop_now = (step - self._best_value_step >= self.early_stopping_rounds)
      if stop_now:
        logging.info("Stopping. Best step: {} with {} = {}.".format(
            self._best_value_step, self.early_stopping_metric,
            self._best_value))
        self._early_stopped = True
        return True
    return False


# TODO(ptucker): This really reads any tensor, not just vars, and requires the
# ':0' suffix on var_name.
class CaptureVariable(EveryN):
  """Captures a variable's values into a collection.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  This monitor is useful for unit testing. You should exercise caution when
  using this monitor in production, since it never discards values.

  This is an `EveryN` monitor and has consistent semantics for `every_n`
  and `first_n`.
  """

  def __init__(self, var_name, every_n=100, first_n=1):
    """Initializes a CaptureVariable monitor.

    Args:
      var_name: `string`. The variable name, including suffix (typically ":0").
      every_n: `int`, capture every N steps. See `EveryN`.
      first_n: `int`, also capture the first N steps. See `EveryN`.
    """
    super(CaptureVariable, self).__init__(every_n, first_n)
    self._var_name = var_name
    self._var_values = {}

  @property
  def values(self):
    """Returns the values captured so far.

    Returns:
      `dict` mapping `int` step numbers to the values of the variable at the
      respective step.
    """
    return self._var_values

  def every_n_step_begin(self, step):
    super(CaptureVariable, self).every_n_step_begin(step)
    return [self._var_name]

  def every_n_step_end(self, step, outputs):
    super(CaptureVariable, self).every_n_step_end(step, outputs)
    self._var_values[step] = _extract_output(outputs, self._var_name)


@deprecation.deprecated(None, "Use tf.train.MonitoredTrainingSession.")
def get_default_monitors(loss_op=None,
                         summary_op=None,
                         save_summary_steps=100,
                         output_dir=None,
                         summary_writer=None):
  """Returns a default set of typically-used monitors.

  Args:
    loss_op: `Tensor`, the loss tensor. This will be printed using
        `PrintTensor` at the default interval.
    summary_op: See `SummarySaver`.
    save_summary_steps: See `SummarySaver`.
    output_dir: See `SummarySaver`.
    summary_writer: See `SummarySaver`.

  Returns:
    `list` of monitors.
  """

  monitors = []
  if loss_op is not None:
    monitors.append(PrintTensor(tensor_names={"loss": loss_op.name}))
  if summary_op is not None:
    monitors.append(
        SummarySaver(
            summary_op,
            save_steps=save_summary_steps,
            output_dir=output_dir,
            summary_writer=summary_writer))
  return monitors


class GraphDump(BaseMonitor):
  """Dumps almost all tensors in the graph at every step.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  Note that this is very expensive; prefer `PrintTensor` in production.
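
  For example, to compare the tensors produced by two runs at a given step
  (a sketch; it assumes the two monitors were attached to otherwise identical
  training jobs):

    dump1, dump2 = GraphDump(), GraphDump()
    # ... attach each monitor to a training run, train both for 100 steps ...
    matched, non_matched = dump1.compare(dump2, step=100)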
  """

  IGNORE_OPS = [
      "Const", "Assign", "Identity", "Placeholder", "RandomUniform", "Cast",
      "RestoreSlice"
  ]

  def __init__(self, ignore_ops=None):
    """Initializes GraphDump monitor.

    Args:
      ignore_ops: `list` of `string`. Names of ops to ignore.
          If None, `GraphDump.IGNORE_OPS` is used.
    """
    super(GraphDump, self).__init__()
    self._ignore_ops = ignore_ops or GraphDump.IGNORE_OPS
    self._data = {}

  def begin(self, max_steps=None):
    super(GraphDump, self).begin(max_steps=max_steps)
    self._tensors = []
    graph = ops.get_default_graph()
    graph_def = graph.as_graph_def()
    for node in graph_def.node:
      if node.op in self._ignore_ops:
        continue
      logging.info("op=%s name=%s.", node.op, node.name)
      try:
        self._tensors.append(graph.get_tensor_by_name(node.name + ":0"))
      except KeyError:
        pass

  def step_begin(self, step):
    super(GraphDump, self).step_begin(step)
    return self._tensors

  def step_end(self, step, output):
    super(GraphDump, self).step_end(step, output)
    self._data[step] = output

  @property
  def data(self):
    return self._data

  # TODO(ptucker): Handle keys that are in one but not the other.
  def compare(self, other_dump, step, atol=1e-06):
    """Compares two `GraphDump` monitors and returns differences.

    Args:
      other_dump: Another `GraphDump` monitor.
      step: `int`, step to compare on.
      atol: `float`, absolute tolerance in comparison of floating arrays.

    Returns:
      Returns tuple:
        matched: `list` of keys that matched.
        non_matched: `dict` of keys to tuples of 2 mismatched values.

    Raises:
      ValueError: if a key in `data` is missing from `other_dump` at `step`.
    """
    non_matched = {}
    matched = []
    this_output = self.data[step] if step in self.data else {}
    other_output = other_dump.data[step] if step in other_dump.data else {}
    for key in this_output:
      if not isinstance(key, six.string_types):
        continue
      if key not in other_output:
        raise ValueError("%s missing at step %s." % (key, step))
      value1 = _extract_output(this_output, key)
      value2 = _extract_output(other_output, key)
      if isinstance(value1, str):
        continue
      if isinstance(value1, np.ndarray):
        if not np.allclose(value1, value2, atol=atol):
          non_matched[key] = value1 - value2
        else:
          matched.append(key)
      else:
        if value1 != value2:
          non_matched[key] = (value1, value2)
        else:
          matched.append(key)
    return matched, non_matched


class ExportMonitor(EveryN):
  """Monitor that exports Estimator every N steps.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.
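
  For illustration only (a sketch; `my_export_input_fn` is a placeholder for
  a function returning a `(features, labels)` tuple, the feature key is an
  assumption, and core `tf.estimator.Estimator` is not supported):

    export_monitor = ExportMonitor(every_n_steps=5000,
                                   export_dir="/tmp/exports",
                                   input_fn=my_export_input_fn,
                                   input_feature_key="examples")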
  """

  @deprecation.deprecated("2017-03-25",
                          "ExportMonitor is deprecated. Please pass an "
                          "ExportStrategy to Experiment instead.")
  def __init__(self,
               every_n_steps,
               export_dir,
               input_fn=None,
               input_feature_key=None,
               exports_to_keep=5,
               signature_fn=None,
               default_batch_size=1):
    """Initializes ExportMonitor.

    Args:
      every_n_steps: Run monitor every N steps.
      export_dir: str, folder to export.
      input_fn: A function that takes no argument and returns a tuple of
          (features, labels), where features is a dict of string key to
          `Tensor` and labels is a `Tensor` that's currently not used (and so
          can be `None`).
      input_feature_key: String key into the features dict returned by
          `input_fn` that corresponds to the raw `Example` strings `Tensor`
          that the exported model will take as input. Should be `None` if and
          only if you're passing in a `signature_fn` that does not use the
          first arg (`Tensor` of `Example` strings).
      exports_to_keep: int, number of exports to keep.
      signature_fn: Function that returns a default signature and a named
          signature map, given `Tensor` of `Example` strings, `dict` of
          `Tensor`s for features and `dict` of `Tensor`s for predictions.
      default_batch_size: Default batch size of the `Example` placeholder.

    Raises:
      ValueError: If `input_fn` and `input_feature_key` are not both defined
          or are not both `None`.
    """
    super(ExportMonitor, self).__init__(every_n_steps=every_n_steps)
    self._export_dir = export_dir
    self._input_fn = input_fn
    self._input_feature_key = input_feature_key
    self._use_deprecated_input_fn = input_fn is None
    self._exports_to_keep = exports_to_keep
    self._signature_fn = signature_fn
    self._default_batch_size = default_batch_size
    self._last_export_dir = None

  @property
  def export_dir(self):
    return self._export_dir

  @property
  def exports_to_keep(self):
    return self._exports_to_keep

  @property
  def signature_fn(self):
    return self._signature_fn

  @property
  def last_export_dir(self):
    """Returns the directory containing the last completed export.

    Returns:
      The string path to the exported directory. NB: this functionality was
      added on 2016/09/25; clients that depend on the return value may need
      to handle the case where this function returns None because the
      estimator being fitted does not yet return a value during export.
    """
    return self._last_export_dir

  def every_n_step_end(self, step, outputs):
    super(ExportMonitor, self).every_n_step_end(step, outputs)
    try:
      if isinstance(self._estimator, core_estimator.Estimator):
        raise ValueError(
            "ExportMonitor does not support `tf.estimator.Estimator`. "
            "Please pass an ExportStrategy to Experiment instead.")
      self._last_export_dir = self._estimator.export(
          self.export_dir,
          exports_to_keep=self.exports_to_keep,
          signature_fn=self.signature_fn,
          input_fn=self._input_fn,
          default_batch_size=self._default_batch_size,
          input_feature_key=self._input_feature_key,
          use_deprecated_input_fn=self._use_deprecated_input_fn)
    except RuntimeError:
      # Currently we are not synchronized with saving checkpoints, which leads
      # to runtime errors when we are calling export on the same global step.
      # Exports depend on saved checkpoints for constructing the graph and
      # getting the global step from the graph instance saved in the
      # checkpoint. If the checkpoint is stale with respect to the current
      # step, the global step is taken to be the last saved checkpoint's
      # global step and the exporter doesn't export the same checkpoint again
      # with the following error.
      logging.info("Skipping exporting because the existing checkpoint has "
                   "already been exported. "
                   "Consider exporting less frequently.")

  def end(self, session=None):
    super(ExportMonitor, self).end(session=session)
    latest_path = checkpoint_management.latest_checkpoint(
        self._estimator.model_dir)
    if latest_path is None:
      logging.info("Skipping export at the end since model has not been "
                   "saved yet.")
      return
    if isinstance(self._estimator, core_estimator.Estimator):
      raise ValueError(
          "ExportMonitor does not support `tf.estimator.Estimator`. "
          "Please pass an ExportStrategy to Experiment instead.")
    try:
      self._last_export_dir = self._estimator.export(
          self.export_dir,
          exports_to_keep=self.exports_to_keep,
          signature_fn=self.signature_fn,
          input_fn=self._input_fn,
          default_batch_size=self._default_batch_size,
          input_feature_key=self._input_feature_key,
          use_deprecated_input_fn=self._use_deprecated_input_fn)
    except RuntimeError:
      logging.info("Skipping exporting for the same step.")


class CheckpointSaver(BaseMonitor):
  """Saves checkpoints every N steps or N seconds.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.
  """

  def __init__(self,
               checkpoint_dir,
               save_secs=None,
               save_steps=None,
               saver=None,
               checkpoint_basename="model.ckpt",
               scaffold=None):
    """Initialize CheckpointSaver monitor.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N secs.
      save_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, use to get saver object.

    Raises:
      ValueError: If both `save_steps` and `save_secs` are not `None`.
      ValueError: If both `save_steps` and `save_secs` are `None`.
    """
    logging.info("Create CheckpointSaver.")
    super(CheckpointSaver, self).__init__()
    self._saver = saver
    self._summary_writer = core_summary.FileWriterCache.get(checkpoint_dir)
    self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
    self._scaffold = scaffold
    self._save_secs = save_secs
    self._save_steps = save_steps
    self._last_saved_time = None
    self._last_begin_step = None
    self._last_saved_step = None

    if save_steps is None and save_secs is None:
      raise ValueError("Either save_steps or save_secs should be provided.")
    if (save_steps is not None) and (save_secs is not None):
      raise ValueError("Can not provide both save_steps and save_secs.")

  def begin(self, max_steps=None):
    super(CheckpointSaver, self).begin(max_steps)
    self._last_saved_time = None
    self._last_begin_step = None
    self._last_saved_step = None

  def step_begin(self, step):
    super(CheckpointSaver, self).step_begin(step)
    self._last_begin_step = step

  def post_step(self, step, session):
    super(CheckpointSaver, self).post_step(step, session)
    if self._last_saved_time is None:
      self._save(step, session)

    if self._save_steps is not None:
      if step >= self._last_saved_step + self._save_steps:
        self._save(step, session)

    if self._save_secs is not None:
      if time.time() >= self._last_saved_time + self._save_secs:
        self._save(step, session)

  def end(self, session=None):
    super(CheckpointSaver, self).end(session)
    self._save(self._last_begin_step, session)

  def _save(self, step, session):
    """Saves the latest checkpoint."""
    if step == self._last_saved_step:
      return
    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
    self._last_saved_time = time.time()
    self._last_saved_step = step
    if self._saver is None:
      self._scaffold.saver.save(session, self._save_path, global_step=step)
    else:
      self._saver.save(session, self._save_path, global_step=step)
    self._summary_writer.add_session_log(
        SessionLog(
            status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
        step)


class StepCounter(EveryN):
  """Steps per second monitor.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.
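
  For example, to record a "global_step/sec" summary every 100 steps (a
  sketch; when neither `output_dir` nor `summary_writer` is given, the writer
  falls back to the estimator's `model_dir` once `set_estimator` is called):

    step_counter = StepCounter(every_n_steps=100)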
  """

  def __init__(self, every_n_steps=100, output_dir=None, summary_writer=None):
    super(StepCounter, self).__init__(every_n_steps=every_n_steps)
    self._summary_tag = "global_step/sec"
    self._last_reported_step = None
    self._last_reported_time = None
    self._summary_writer = summary_writer
    if summary_writer is None and output_dir:
      self._summary_writer = core_summary.FileWriterCache.get(output_dir)

  def set_estimator(self, estimator):
    super(StepCounter, self).set_estimator(estimator)
    if self._summary_writer is None:
      self._summary_writer = core_summary.FileWriterCache.get(
          estimator.model_dir)

  def every_n_step_end(self, current_step, outputs):
    current_time = time.time()
    if self._last_reported_time is not None and self._summary_writer:
      added_steps = current_step - self._last_reported_step
      elapsed_time = current_time - self._last_reported_time
      steps_per_sec = added_steps / elapsed_time
      summary = Summary(value=[
          Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)
      ])
      self._summary_writer.add_summary(summary, current_step)
    self._last_reported_step = current_step
    self._last_reported_time = current_time


class NanLossDuringTrainingError(RuntimeError):

  def __str__(self):
    return "NaN loss during training."


class NanLoss(EveryN):
  """NaN Loss monitor.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  Monitors loss and stops training if loss is NaN.
  Can either fail with exception or just stop training.
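
  For example, to stop training on a NaN loss without raising an exception
  (a sketch; `loss_tensor` is assumed to be the model's loss `Tensor`):

    nan_monitor = NanLoss(loss_tensor, fail_on_nan_loss=False)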
  """

  def __init__(self, loss_tensor, every_n_steps=100, fail_on_nan_loss=True):
    """Initializes NanLoss monitor.

    Args:
      loss_tensor: `Tensor`, the loss tensor.
      every_n_steps: `int`, run check every this many steps.
      fail_on_nan_loss: `bool`, whether to raise exception when loss is NaN.
    """
    super(NanLoss, self).__init__(every_n_steps=every_n_steps)
    self._loss_tensor = loss_tensor
    self._fail_on_nan_loss = fail_on_nan_loss

  def every_n_step_begin(self, step):
    super(NanLoss, self).every_n_step_begin(step)
    return [self._loss_tensor]

  def every_n_step_end(self, step, outputs):
    super(NanLoss, self).every_n_step_end(step, outputs)
    if np.isnan(_extract_output(outputs, self._loss_tensor)):
      failure_message = "Model diverged with loss = NaN."
      if self._fail_on_nan_loss:
        logging.error(failure_message)
        raise NanLossDuringTrainingError
      else:
        logging.warning(failure_message)
        # We don't raise an error but we return "should stop" so we stop, but
        # without an exception.
        return True


class RunHookAdapterForMonitors(session_run_hook.SessionRunHook):
  """Wraps monitors into a SessionRunHook."""

  def __init__(self, monitors):
    self._monitors = monitors

  def begin(self):
    self._last_step = None
    self._global_step_tensor = training_util.get_global_step()
    for m in self._monitors:
      m.begin(max_steps=None)

  def before_run(self, run_context):
    if self._last_step is None:
      self._last_step = run_context.session.run(self._global_step_tensor) + 1

    request = {self._global_step_tensor: self._global_step_tensor}
    monitor_fetches = []
    for m in self._monitors:
      monitor_requests = m.step_begin(self._last_step)
      if monitor_requests:
        if not isinstance(monitor_requests, list):
          raise ValueError("Monitor.step_begin should return a list.")
        monitor_fetches.extend(monitor_requests)
    if monitor_fetches:
      request["monitors"] = dict(
          zip(monitor_fetches, [_as_graph_element(f) for f in monitor_fetches]))

    return session_run_hook.SessionRunArgs(request)

  def after_run(self, run_context, run_values):
    result = run_values.results[
        "monitors"] if "monitors" in run_values.results else {}
    for m in self._monitors:
      induce_stop = m.step_end(self._last_step, result)
      if induce_stop:
        run_context.request_stop()

    for m in self._monitors:
      m.post_step(self._last_step, run_context.session)

    self._last_step = run_values.results[self._global_step_tensor] + 1

  def end(self, session):
    self._last_step = None
    for m in self._monitors:
      if "session" in tf_inspect.getargspec(m.end).args:
        m.end(session=session)
      else:
        m.end()


def replace_monitors_with_hooks(monitors_or_hooks, estimator):
  """Wraps monitors with a hook.

  `Monitor` is deprecated in favor of `SessionRunHook`. If you're still using
  a monitor, you can wrap it with a hook using this function. It is
  recommended that you implement the hook version of your monitor instead.

  Args:
    monitors_or_hooks: A `list` that may contain both monitors and hooks.
    estimator: An `Estimator` that the monitors will be used with.

  Returns:
    Returns a list of hooks. If there is any monitor in the given list, it is
    replaced by a hook.
  """
  monitors_or_hooks = monitors_or_hooks or []
  hooks = [
      m for m in monitors_or_hooks
      if isinstance(m, session_run_hook.SessionRunHook)
  ]

  deprecated_monitors = [
      m for m in monitors_or_hooks
      if not isinstance(m, session_run_hook.SessionRunHook)
  ]

  if not estimator.config.is_chief:
    # Prune the list of monitors to the ones runnable on all workers.
    deprecated_monitors = [
        m for m in deprecated_monitors if m.run_on_all_workers
    ]

  # Setup monitors.
  for monitor in deprecated_monitors:
    monitor.set_estimator(estimator)

  if deprecated_monitors:
    hooks.append(RunHookAdapterForMonitors(deprecated_monitors))

  return hooks


def _as_graph_element(obj):
  """Retrieves Graph element."""
  graph = ops.get_default_graph()
  if not isinstance(obj, six.string_types):
    if not hasattr(obj, "graph") or obj.graph != graph:
      raise ValueError("Passed %s should have graph attribute that is equal "
                       "to current graph %s." % (obj, graph))
    return obj
  if ":" in obj:
    element = graph.as_graph_element(obj)
  else:
    element = graph.as_graph_element(obj + ":0")
    # Check that there is no :1 (i.e., the op has a single output).
    try:
      graph.as_graph_element(obj + ":1")
    except (KeyError, ValueError):
      pass
    else:
      raise ValueError("Name %s is ambiguous, "
                       "as this `Operation` has multiple outputs "
                       "(at least 2)." % obj)
  return element
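
# Example usage of `replace_monitors_with_hooks` (an illustrative sketch;
# `estimator` and `loss_tensor` are assumed to be defined elsewhere, with
# `estimator` a contrib `Estimator` whose `config` is populated):
#
#   monitors = [StopAtStep(num_steps=1000),
#               NanLoss(loss_tensor, fail_on_nan_loss=False)]
#   hooks = replace_monitors_with_hooks(monitors, estimator)
#   # `hooks` now contains a single `RunHookAdapterForMonitors` wrapping both
#   # monitors, and can be passed anywhere `SessionRunHook`s are accepted.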