# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-import-not-at-top
# pylint: disable=g-classes-have-attributes
"""Callbacks: utilities called at certain points during model training.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import csv
import io
import json
import os
import re
import sys
import time

import numpy as np
import six

from tensorflow.core.framework import summary_pb2
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.distribute import distributed_file_utils
from tensorflow.python.keras.distribute import worker_training_state
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.utils.data_utils import Sequence
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import profiler_v2 as profiler
from tensorflow.python.saved_model import save_options as save_options_lib
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.saving import checkpoint_options as checkpoint_options_lib
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls

try:
  import requests
except ImportError:
  requests = None


# Note: `configure_callbacks` is only used in TF1.
def configure_callbacks(callbacks,
                        model,
                        do_validation=False,
                        batch_size=None,
                        epochs=None,
                        steps_per_epoch=None,
                        samples=None,
                        verbose=1,
                        count_mode='steps',
                        mode=ModeKeys.TRAIN):
  """Configures callbacks for use in various training loops.

  Args:
    callbacks: List of Callbacks.
    model: Model being trained.
    do_validation: Whether or not validation loop will be run.
    batch_size: Number of samples per batch.
    epochs: Number of epochs to train.
    steps_per_epoch: Number of batches to run per training epoch.
    samples: Number of training samples.
    verbose: int, 0 or 1. Keras logging verbosity to pass to ProgbarLogger.
    count_mode: One of 'steps' or 'samples'. Per-batch or per-sample count.
    mode: String. One of ModeKeys.TRAIN, ModeKeys.TEST, or ModeKeys.PREDICT.
      Which loop mode to configure callbacks for.

  Returns:
    Instance of CallbackList used to control all Callbacks.
  """
  # Check if callbacks have already been configured.
  if isinstance(callbacks, CallbackList):
    return callbacks

  if not callbacks:
    callbacks = []

  # Add additional callbacks during training.
  if mode == ModeKeys.TRAIN:
    model.history = History()
    callbacks = [BaseLogger()] + (callbacks or []) + [model.history]
    if verbose:
      callbacks.append(ProgbarLogger(count_mode))
  callback_list = CallbackList(callbacks)

  # Set callback model
  callback_model = model._get_callback_model()  # pylint: disable=protected-access
  callback_list.set_model(callback_model)

  set_callback_parameters(
      callback_list,
      model,
      do_validation=do_validation,
      batch_size=batch_size,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      samples=samples,
      verbose=verbose,
      mode=mode)

  callback_list.model.stop_training = False
  return callback_list


def set_callback_parameters(callback_list,
                            model,
                            do_validation=False,
                            batch_size=None,
                            epochs=None,
                            steps_per_epoch=None,
                            samples=None,
                            verbose=1,
                            mode=ModeKeys.TRAIN):
  """Sets callback parameters.

  Args:
    callback_list: CallbackList instance.
    model: Model being trained.
    do_validation: Whether or not validation loop will be run.
    batch_size: Number of samples per batch.
    epochs: Number of epochs to train.
    steps_per_epoch: Number of batches to run per training epoch.
    samples: Number of training samples.
    verbose: int, 0 or 1. Keras logging verbosity to pass to ProgbarLogger.
    mode: String. One of ModeKeys.TRAIN, ModeKeys.TEST, or ModeKeys.PREDICT.
      Which loop mode to configure callbacks for.
  """
  metric_names = model.metrics_names
  for cbk in callback_list:
    if isinstance(cbk, (BaseLogger, ProgbarLogger)):
      cbk.stateful_metrics = metric_names[1:]  # Exclude `loss`

  # Set callback parameters
  callback_metrics = []
  # When we have deferred build scenario with iterator input, we will compile
  # when we standardize first batch of data.
  if mode != ModeKeys.PREDICT:
    callback_metrics = copy.copy(metric_names)
    if do_validation:
      callback_metrics += ['val_' + n for n in metric_names]
  callback_params = {
      'batch_size': batch_size,
      'epochs': epochs,
      'steps': steps_per_epoch,
      'samples': samples,
      'verbose': verbose,
      'do_validation': do_validation,
      'metrics': callback_metrics,
  }
  callback_list.set_params(callback_params)


def _is_generator_like(data):
  """Checks if data is a generator, Sequence, or Iterator."""
  return (hasattr(data, '__next__') or hasattr(data, 'next') or isinstance(
      data, (Sequence, iterator_ops.Iterator, iterator_ops.IteratorBase)))


def make_logs(model, logs, outputs, mode, prefix=''):
  """Computes logs for sending to `on_batch_end` methods."""
  metric_names = model.metrics_names
  if mode in {ModeKeys.TRAIN, ModeKeys.TEST} and metric_names:
    for label, output in zip(metric_names, outputs):
      logs[prefix + label] = output
  else:
    logs['outputs'] = outputs
  return logs


@keras_export('keras.callbacks.CallbackList')
class CallbackList(object):
  """Container abstracting a list of callbacks."""

  def __init__(self,
               callbacks=None,
               add_history=False,
               add_progbar=False,
               model=None,
               **params):
    """Container for `Callback` instances.

    This object wraps a list of `Callback` instances, making it possible
    to call them all at once via a single endpoint
    (e.g. `callback_list.on_epoch_end(...)`).

    Args:
      callbacks: List of `Callback` instances.
      add_history: Whether a `History` callback should be added, if one does not
        already exist in the `callbacks` list.
      add_progbar: Whether a `ProgbarLogger` callback should be added, if one
        does not already exist in the `callbacks` list.
      model: The `Model` these callbacks are used with.
      **params: If provided, parameters will be passed to each `Callback` via
        `Callback.set_params`.
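
    Example (a minimal sketch of driving the hooks by hand; the model and the
    `logs` values below are purely illustrative):

    ```python
    model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
    history = tf.keras.callbacks.History()
    callback_list = tf.keras.callbacks.CallbackList(
        [tf.keras.callbacks.TerminateOnNaN(), history], model=model)
    callback_list.on_train_begin()
    callback_list.on_epoch_begin(0)
    callback_list.on_epoch_end(0, logs={'loss': 0.3})
    callback_list.on_train_end()
    print(history.history)  # {'loss': [0.3]}
    ```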
    """
    self.callbacks = nest.flatten(callbacks) if callbacks else []
    self._add_default_callbacks(add_history, add_progbar)

    if model:
      self.set_model(model)
    if params:
      self.set_params(params)

    # Performance optimization: determines if batch hooks need to be called.
    # pylint: disable=protected-access
    self._should_call_train_batch_hooks = any(
        cb._implements_train_batch_hooks() for cb in self.callbacks)
    self._should_call_test_batch_hooks = any(
        cb._implements_test_batch_hooks() for cb in self.callbacks)
    self._should_call_predict_batch_hooks = any(
        cb._implements_predict_batch_hooks() for cb in self.callbacks)
    # pylint: enable=protected-access

    # Performance check: Check batch hooks for slowness compared to batch time.
    # Only run check for custom callbacks (i.e. not present in this file).
    self._check_timing = any([cbk.__class__.__name__ not in globals()
                              for cbk in self.callbacks])
    self._num_batches_for_timing_check = 5
    self._hook_times = {}
    self._batch_start_time = None
    self._batch_times = []

  def _add_default_callbacks(self, add_history, add_progbar):
    """Adds `Callback`s that are always present."""
    self._progbar = None
    self._history = None

    for cb in self.callbacks:
      if isinstance(cb, ProgbarLogger):
        self._progbar = cb
      elif isinstance(cb, History):
        self._history = cb

    if self._progbar is None and add_progbar:
      self._progbar = ProgbarLogger(count_mode='steps')
      self.callbacks.insert(0, self._progbar)

    if self._history is None and add_history:
      self._history = History()
      self.callbacks.append(self._history)

  def append(self, callback):
    self.callbacks.append(callback)

  def set_params(self, params):
    self.params = params
    for callback in self.callbacks:
      callback.set_params(params)

  def set_model(self, model):
    self.model = model
    if self._history:
      model.history = self._history
    for callback in self.callbacks:
      callback.set_model(model)

  def _call_batch_hook(self, mode, hook, batch, logs=None):
    """Helper function for all batch_{begin | end} methods."""
    if not self.callbacks:
      return

    if hook == 'begin':
      self._call_batch_begin_hook(mode, batch, logs)
    elif hook == 'end':
      self._call_batch_end_hook(mode, batch, logs)
    else:
      raise ValueError('Unrecognized hook: {}'.format(hook))

  def _call_batch_begin_hook(self, mode, batch, logs):
    """Helper function for `on_*_batch_begin` methods."""
    hook_name = 'on_{mode}_batch_begin'.format(mode=mode)
    self._call_batch_hook_helper(hook_name, batch, logs)

    if self._check_timing:
      self._batch_start_time = time.time()

  def _call_batch_end_hook(self, mode, batch, logs):
    """Helper function for `on_*_batch_end` methods."""
    hook_name = 'on_{mode}_batch_end'.format(mode=mode)

    if self._check_timing and batch >= 1:
      batch_time = time.time() - self._batch_start_time
      self._batch_times.append(batch_time)

    self._call_batch_hook_helper(hook_name, batch, logs)

    if len(self._batch_times) >= self._num_batches_for_timing_check:
      end_hook_name = hook_name
      begin_hook_name = 'on_{mode}_batch_begin'.format(mode=mode)
      avg_batch_time = sum(self._batch_times) / len(self._batch_times)
      avg_end_hook_time = sum(self._hook_times[end_hook_name]) / len(
          self._hook_times[end_hook_name])
      avg_begin_hook_time = sum(self._hook_times[begin_hook_name]) / len(
          self._hook_times[begin_hook_name])

      threshold_time = 1.0 * avg_batch_time
      warning_msg = ('Callback method `{hook}` is slow compared to '
                     'the batch time (batch time: {batch_time:.4f}s vs '
                     '`{hook}` time: {hook_time:.4f}s). Check your callbacks.')
      if avg_begin_hook_time > threshold_time:
        logging.warning(warning_msg.format(
            hook=begin_hook_name,
            batch_time=avg_batch_time,
            hook_time=avg_begin_hook_time))
      if avg_end_hook_time > threshold_time:
        logging.warning(warning_msg.format(
            hook=end_hook_name,
            batch_time=avg_batch_time,
            hook_time=avg_end_hook_time))
      self._check_timing = False
      self._batch_start_time = None
      self._batch_times = []
      self._hook_times = {}

  def _call_batch_hook_helper(self, hook_name, batch, logs):
    """Helper function for `on_*_batch_*` methods."""
    logs = logs or {}
    numpy_logs = None
    if self._check_timing:
      start_time = time.time()

    for callback in self.callbacks:
      hook = getattr(callback, hook_name)
      if getattr(callback, '_supports_tf_logs', False):
        hook(batch, logs)
      else:
        if numpy_logs is None:  # Only convert once.
          numpy_logs = tf_utils.to_numpy_or_python_type(logs)
        hook(batch, numpy_logs)

    if self._check_timing:
      if hook_name not in self._hook_times:
        self._hook_times[hook_name] = []
      self._hook_times[hook_name].append(time.time() - start_time)

  def _call_begin_hook(self, mode):
    """Helper function for on_{train|test|predict}_begin methods."""
    if mode == ModeKeys.TRAIN:
      self.on_train_begin()
    elif mode == ModeKeys.TEST:
      self.on_test_begin()
    else:
      self.on_predict_begin()

  def _call_end_hook(self, mode):
    """Helper function for on_{train|test|predict}_end methods."""
    if mode == ModeKeys.TRAIN:
      self.on_train_end()
    elif mode == ModeKeys.TEST:
      self.on_test_end()
    else:
      self.on_predict_end()

  def on_batch_begin(self, batch, logs=None):
    if self._should_call_train_batch_hooks:
      self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs)

  def on_batch_end(self, batch, logs=None):
    if self._should_call_train_batch_hooks:
      self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)

  def on_epoch_begin(self, epoch, logs=None):
    """Calls the `on_epoch_begin` methods of its callbacks.

    This function should only be called during TRAIN mode.

    Args:
      epoch: Integer, index of epoch.
      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
    """
    logs = logs or {}
    numpy_logs = None
    for callback in self.callbacks:
      if getattr(callback, '_supports_tf_logs', False):
        callback.on_epoch_begin(epoch, logs)
      else:
        if numpy_logs is None:  # Only convert once.
          numpy_logs = tf_utils.to_numpy_or_python_type(logs)
        callback.on_epoch_begin(epoch, numpy_logs)

  def on_epoch_end(self, epoch, logs=None):
    """Calls the `on_epoch_end` methods of its callbacks.

    This function should only be called during TRAIN mode.

    Args:
      epoch: Integer, index of epoch.
      logs: Dict, metric results for this training epoch, and for the
        validation epoch if validation is performed. Validation result keys
        are prefixed with `val_`.
    """
    logs = logs or {}
    numpy_logs = None
    for callback in self.callbacks:
      if getattr(callback, '_supports_tf_logs', False):
        callback.on_epoch_end(epoch, logs)
      else:
        if numpy_logs is None:  # Only convert once.
          numpy_logs = tf_utils.to_numpy_or_python_type(logs)
        callback.on_epoch_end(epoch, numpy_logs)

  def on_train_batch_begin(self, batch, logs=None):
    """Calls the `on_train_batch_begin` methods of its callbacks.

    Args:
      batch: Integer, index of batch within the current epoch.
      logs: Dict, contains the return value of `model.train_step`. Typically,
        the values of the `Model`'s metrics are returned. Example:
        `{'loss': 0.2, 'accuracy': 0.7}`.
    """
    if self._should_call_train_batch_hooks:
      self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs)

  def on_train_batch_end(self, batch, logs=None):
    """Calls the `on_train_batch_end` methods of its callbacks.

    Args:
      batch: Integer, index of batch within the current epoch.
      logs: Dict. Aggregated metric results up until this batch.
    """
    if self._should_call_train_batch_hooks:
      self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)

  def on_test_batch_begin(self, batch, logs=None):
    """Calls the `on_test_batch_begin` methods of its callbacks.

    Args:
      batch: Integer, index of batch within the current epoch.
      logs: Dict, contains the return value of `model.test_step`. Typically,
        the values of the `Model`'s metrics are returned. Example:
        `{'loss': 0.2, 'accuracy': 0.7}`.
    """
    if self._should_call_test_batch_hooks:
      self._call_batch_hook(ModeKeys.TEST, 'begin', batch, logs=logs)

  def on_test_batch_end(self, batch, logs=None):
    """Calls the `on_test_batch_end` methods of its callbacks.

    Args:
      batch: Integer, index of batch within the current epoch.
      logs: Dict. Aggregated metric results up until this batch.
    """
    if self._should_call_test_batch_hooks:
      self._call_batch_hook(ModeKeys.TEST, 'end', batch, logs=logs)

  def on_predict_batch_begin(self, batch, logs=None):
    """Calls the `on_predict_batch_begin` methods of its callbacks.

    Args:
      batch: Integer, index of batch within the current epoch.
      logs: Dict, contains the return value of `model.predict_step`,
        it typically returns a dict with a key 'outputs' containing
        the model's outputs.
    """
    if self._should_call_predict_batch_hooks:
      self._call_batch_hook(ModeKeys.PREDICT, 'begin', batch, logs=logs)

  def on_predict_batch_end(self, batch, logs=None):
    """Calls the `on_predict_batch_end` methods of its callbacks.

    Args:
      batch: Integer, index of batch within the current epoch.
      logs: Dict. Aggregated metric results up until this batch.
    """
    if self._should_call_predict_batch_hooks:
      self._call_batch_hook(ModeKeys.PREDICT, 'end', batch, logs=logs)

  def on_train_begin(self, logs=None):
    """Calls the `on_train_begin` methods of its callbacks.

    Args:
      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
    """
    logs = logs or {}
    numpy_logs = None
    for callback in self.callbacks:
      if getattr(callback, '_supports_tf_logs', False):
        callback.on_train_begin(logs)
      else:
        if numpy_logs is None:  # Only convert once.
          numpy_logs = tf_utils.to_numpy_or_python_type(logs)
        callback.on_train_begin(numpy_logs)

  def on_train_end(self, logs=None):
    """Calls the `on_train_end` methods of its callbacks.

    Args:
      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
    """
    logs = logs or {}
    numpy_logs = None
    for callback in self.callbacks:
      if getattr(callback, '_supports_tf_logs', False):
        callback.on_train_end(logs)
      else:
        if numpy_logs is None:  # Only convert once.
          numpy_logs = tf_utils.to_numpy_or_python_type(logs)
        callback.on_train_end(numpy_logs)

  def on_test_begin(self, logs=None):
    """Calls the `on_test_begin` methods of its callbacks.

    Args:
      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
    """
    logs = logs or {}
    numpy_logs = None
    for callback in self.callbacks:
      if getattr(callback, '_supports_tf_logs', False):
        callback.on_test_begin(logs)
      else:
        if numpy_logs is None:  # Only convert once.
          numpy_logs = tf_utils.to_numpy_or_python_type(logs)
        callback.on_test_begin(numpy_logs)

  def on_test_end(self, logs=None):
    """Calls the `on_test_end` methods of its callbacks.

    Args:
      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
    """
    logs = logs or {}
    numpy_logs = None
    for callback in self.callbacks:
      if getattr(callback, '_supports_tf_logs', False):
        callback.on_test_end(logs)
      else:
        if numpy_logs is None:  # Only convert once.
          numpy_logs = tf_utils.to_numpy_or_python_type(logs)
        callback.on_test_end(numpy_logs)

  def on_predict_begin(self, logs=None):
    """Calls the `on_predict_begin` methods of its callbacks.

    Args:
      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
    """
    logs = logs or {}
    numpy_logs = None
    for callback in self.callbacks:
      if getattr(callback, '_supports_tf_logs', False):
        callback.on_predict_begin(logs)
      else:
        if numpy_logs is None:  # Only convert once.
          numpy_logs = tf_utils.to_numpy_or_python_type(logs)
        callback.on_predict_begin(numpy_logs)

  def on_predict_end(self, logs=None):
    """Calls the `on_predict_end` methods of its callbacks.

    Args:
      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
    """
    logs = logs or {}
    numpy_logs = None
    for callback in self.callbacks:
      if getattr(callback, '_supports_tf_logs', False):
        callback.on_predict_end(logs)
      else:
        if numpy_logs is None:  # Only convert once.
          numpy_logs = tf_utils.to_numpy_or_python_type(logs)
        callback.on_predict_end(numpy_logs)

  def __iter__(self):
    return iter(self.callbacks)


@keras_export('keras.callbacks.Callback')
class Callback(object):
  """Abstract base class used to build new callbacks.

  Attributes:
    params: Dict. Training parameters
      (eg. verbosity, batch size, number of epochs...).
    model: Instance of `keras.models.Model`.
      Reference of the model being trained.

  The `logs` dictionary that callback methods
  take as argument will contain keys for quantities relevant to
  the current batch or epoch (see method-specific docstrings).
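
  For example, a minimal sketch of a custom callback that prints the loss at
  the end of every epoch (`model`, `x` and `y` below stand in for a compiled
  model and its training data):

  ```python
  class LossPrinter(tf.keras.callbacks.Callback):

    def on_epoch_end(self, epoch, logs=None):
      logs = logs or {}
      print('Epoch {}: loss = {}'.format(epoch, logs.get('loss')))

  model.fit(x, y, epochs=2, callbacks=[LossPrinter()])
  ```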
  """

  def __init__(self):
    self.validation_data = None  # pylint: disable=g-missing-from-attributes
    self.model = None
    # Whether this Callback should only run on the chief worker in a
    # Multi-Worker setting.
    # TODO(omalleyt): Make this attr public once solution is stable.
    self._chief_worker_only = None
    self._supports_tf_logs = False

  def set_params(self, params):
    self.params = params

  def set_model(self, model):
    self.model = model

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_batch_begin(self, batch, logs=None):
    """A backwards compatibility alias for `on_train_batch_begin`."""

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_batch_end(self, batch, logs=None):
    """A backwards compatibility alias for `on_train_batch_end`."""

  @doc_controls.for_subclass_implementers
  def on_epoch_begin(self, epoch, logs=None):
    """Called at the start of an epoch.

    Subclasses should override for any actions to run. This function should only
    be called during TRAIN mode.

    Args:
      epoch: Integer, index of epoch.
      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_epoch_end(self, epoch, logs=None):
    """Called at the end of an epoch.

    Subclasses should override for any actions to run. This function should only
    be called during TRAIN mode.

    Args:
      epoch: Integer, index of epoch.
      logs: Dict, metric results for this training epoch, and for the
        validation epoch if validation is performed. Validation result keys
        are prefixed with `val_`. For training epoch, the values of the
        `Model`'s metrics are returned. Example: `{'loss': 0.2, 'accuracy':
        0.7}`.
    """

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_train_batch_begin(self, batch, logs=None):
    """Called at the beginning of a training batch in `fit` methods.

    Subclasses should override for any actions to run.

    Note that if the `steps_per_execution` argument to `compile` in
    `tf.keras.Model` is set to `N`, this method will only be called every `N`
    batches.

    Args:
      batch: Integer, index of batch within the current epoch.
      logs: Dict, contains the return value of `model.train_step`. Typically,
        the values of the `Model`'s metrics are returned. Example:
        `{'loss': 0.2, 'accuracy': 0.7}`.
    """
    # For backwards compatibility.
    self.on_batch_begin(batch, logs=logs)

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_train_batch_end(self, batch, logs=None):
    """Called at the end of a training batch in `fit` methods.

    Subclasses should override for any actions to run.

    Note that if the `steps_per_execution` argument to `compile` in
    `tf.keras.Model` is set to `N`, this method will only be called every `N`
    batches.

    Args:
      batch: Integer, index of batch within the current epoch.
      logs: Dict. Aggregated metric results up until this batch.
    """
    # For backwards compatibility.
    self.on_batch_end(batch, logs=logs)

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_test_batch_begin(self, batch, logs=None):
    """Called at the beginning of a batch in `evaluate` methods.

    Also called at the beginning of a validation batch in the `fit`
    methods, if validation data is provided.

    Subclasses should override for any actions to run.

    Note that if the `steps_per_execution` argument to `compile` in
    `tf.keras.Model` is set to `N`, this method will only be called every `N`
    batches.

    Args:
      batch: Integer, index of batch within the current epoch.
      logs: Dict, contains the return value of `model.test_step`. Typically,
        the values of the `Model`'s metrics are returned. Example:
        `{'loss': 0.2, 'accuracy': 0.7}`.
    """

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_test_batch_end(self, batch, logs=None):
    """Called at the end of a batch in `evaluate` methods.

    Also called at the end of a validation batch in the `fit`
    methods, if validation data is provided.

    Subclasses should override for any actions to run.

    Note that if the `steps_per_execution` argument to `compile` in
    `tf.keras.Model` is set to `N`, this method will only be called every `N`
    batches.

    Args:
      batch: Integer, index of batch within the current epoch.
      logs: Dict. Aggregated metric results up until this batch.
    """

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_predict_batch_begin(self, batch, logs=None):
    """Called at the beginning of a batch in `predict` methods.

    Subclasses should override for any actions to run.

    Note that if the `steps_per_execution` argument to `compile` in
    `tf.keras.Model` is set to `N`, this method will only be called every `N`
    batches.

    Args:
      batch: Integer, index of batch within the current epoch.
      logs: Dict, contains the return value of `model.predict_step`,
        it typically returns a dict with a key 'outputs' containing
        the model's outputs.
    """

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_predict_batch_end(self, batch, logs=None):
    """Called at the end of a batch in `predict` methods.

    Subclasses should override for any actions to run.

    Note that if the `steps_per_execution` argument to `compile` in
    `tf.keras.Model` is set to `N`, this method will only be called every `N`
    batches.

    Args:
      batch: Integer, index of batch within the current epoch.
      logs: Dict. Aggregated metric results up until this batch.
    """

  @doc_controls.for_subclass_implementers
  def on_train_begin(self, logs=None):
    """Called at the beginning of training.

    Subclasses should override for any actions to run.

    Args:
      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_train_end(self, logs=None):
    """Called at the end of training.

    Subclasses should override for any actions to run.

    Args:
      logs: Dict. Currently the output of the last call to `on_epoch_end()`
        is passed to this argument for this method but that may change in
        the future.
    """

  @doc_controls.for_subclass_implementers
  def on_test_begin(self, logs=None):
    """Called at the beginning of evaluation or validation.

    Subclasses should override for any actions to run.

    Args:
      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_test_end(self, logs=None):
    """Called at the end of evaluation or validation.

    Subclasses should override for any actions to run.

    Args:
      logs: Dict. Currently the output of the last call to
        `on_test_batch_end()` is passed to this argument for this method
        but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_predict_begin(self, logs=None):
    """Called at the beginning of prediction.

    Subclasses should override for any actions to run.

    Args:
      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_predict_end(self, logs=None):
    """Called at the end of prediction.

    Subclasses should override for any actions to run.

    Args:
      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
    """

  def _implements_train_batch_hooks(self):
    """Determines if this Callback should be called for each train batch."""
    return (not generic_utils.is_default(self.on_batch_begin) or
            not generic_utils.is_default(self.on_batch_end) or
            not generic_utils.is_default(self.on_train_batch_begin) or
            not generic_utils.is_default(self.on_train_batch_end))

  def _implements_test_batch_hooks(self):
    """Determines if this Callback should be called for each test batch."""
    return (not generic_utils.is_default(self.on_test_batch_begin) or
            not generic_utils.is_default(self.on_test_batch_end))

  def _implements_predict_batch_hooks(self):
    """Determines if this Callback should be called for each predict batch."""
    return (not generic_utils.is_default(self.on_predict_batch_begin) or
            not generic_utils.is_default(self.on_predict_batch_end))


@keras_export('keras.callbacks.BaseLogger')
class BaseLogger(Callback):
  """Callback that accumulates epoch averages of metrics.

  This callback is automatically applied to every Keras model.

  Args:
    stateful_metrics: Iterable of string names of metrics that
      should *not* be averaged over an epoch.
      Metrics in this list will be logged as-is in `on_epoch_end`.
      All others will be averaged in `on_epoch_end`.
  """

  def __init__(self, stateful_metrics=None):
    super(BaseLogger, self).__init__()
    self.stateful_metrics = set(stateful_metrics or [])

  def on_epoch_begin(self, epoch, logs=None):
    self.seen = 0
    self.totals = {}

  def on_batch_end(self, batch, logs=None):
    logs = logs or {}
    batch_size = logs.get('size', 0)
    # In case of distribution strategy we can potentially run multiple steps
    # at the same time, we should account for that in the `seen` calculation.
    num_steps = logs.get('num_steps', 1)
    self.seen += batch_size * num_steps

    for k, v in logs.items():
      if k in self.stateful_metrics:
        self.totals[k] = v
      else:
        if k in self.totals:
          self.totals[k] += v * batch_size
        else:
          self.totals[k] = v * batch_size

  def on_epoch_end(self, epoch, logs=None):
    if logs is not None:
      for k in self.params['metrics']:
        if k in self.totals:
          # Make value available to next callbacks.
          if k in self.stateful_metrics:
            logs[k] = self.totals[k]
          else:
            logs[k] = self.totals[k] / self.seen


@keras_export('keras.callbacks.TerminateOnNaN')
class TerminateOnNaN(Callback):
  """Callback that terminates training when a NaN loss is encountered.
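
  For instance, a sketch of attaching it to `Model.fit` (the model and data
  names here are illustrative):

  ```python
  model.fit(x, y, epochs=10,
            callbacks=[tf.keras.callbacks.TerminateOnNaN()])
  ```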
  """

  def __init__(self):
    super(TerminateOnNaN, self).__init__()
    self._supports_tf_logs = True

  def on_batch_end(self, batch, logs=None):
    logs = logs or {}
    loss = logs.get('loss')
    if loss is not None:
      loss = tf_utils.to_numpy_or_python_type(loss)
      if np.isnan(loss) or np.isinf(loss):
        print('Batch %d: Invalid loss, terminating training' % (batch))
        self.model.stop_training = True


@keras_export('keras.callbacks.ProgbarLogger')
class ProgbarLogger(Callback):
  """Callback that prints metrics to stdout.

  Args:
    count_mode: One of `"steps"` or `"samples"`.
      Whether the progress bar should
      count samples seen or steps (batches) seen.
    stateful_metrics: Iterable of string names of metrics that
      should *not* be averaged over an epoch.
      Metrics in this list will be logged as-is.
      All others will be averaged over time (e.g. loss, etc).
      If not provided, defaults to the `Model`'s metrics.

  Raises:
    ValueError: In case of invalid `count_mode`.
  """

  def __init__(self, count_mode='samples', stateful_metrics=None):
    super(ProgbarLogger, self).__init__()
    self._supports_tf_logs = True
    if count_mode == 'samples':
      self.use_steps = False
    elif count_mode == 'steps':
      self.use_steps = True
    else:
      raise ValueError('Unknown `count_mode`: ' + str(count_mode))
    # Defaults to all Model's metrics except for loss.
    self.stateful_metrics = set(stateful_metrics) if stateful_metrics else None

    self.seen = 0
    self.progbar = None
    self.target = None
    self.verbose = 1
    self.epochs = 1

    self._train_step, self._test_step, self._predict_step = None, None, None
    self._call_batch_hooks = True

    self._called_in_fit = False

  def set_params(self, params):
    self.verbose = params['verbose']
    self.epochs = params['epochs']
    if self.use_steps and 'steps' in params:
      self.target = params['steps']
    elif not self.use_steps and 'samples' in params:
      self.target = params['samples']
    else:
      self.target = None  # Will be inferred at the end of the first epoch.

    self._call_batch_hooks = self.verbose == 1
    if self.target is None:
      try:
        self._train_step = self.model._train_counter  # pylint: disable=protected-access
        self._test_step = self.model._test_counter  # pylint: disable=protected-access
        self._predict_step = self.model._predict_counter  # pylint: disable=protected-access
      except AttributeError:
        self._call_batch_hooks = True

  def on_train_begin(self, logs=None):
    # When this logger is called inside `fit`, validation is silent.
    self._called_in_fit = True

  def on_test_begin(self, logs=None):
    if not self._called_in_fit:
      self._reset_progbar()
      self._maybe_init_progbar()

  def on_predict_begin(self, logs=None):
    self._reset_progbar()
    self._maybe_init_progbar()

  def on_epoch_begin(self, epoch, logs=None):
    self._reset_progbar()
    self._maybe_init_progbar()
    if self.verbose and self.epochs > 1:
      print('Epoch %d/%d' % (epoch + 1, self.epochs))

  def on_train_batch_end(self, batch, logs=None):
    self._batch_update_progbar(batch, logs)

  def on_test_batch_end(self, batch, logs=None):
    if not self._called_in_fit:
      self._batch_update_progbar(batch, logs)

  def on_predict_batch_end(self, batch, logs=None):
    # Don't pass prediction results.
    self._batch_update_progbar(batch, None)

  def on_epoch_end(self, epoch, logs=None):
    self._finalize_progbar(logs, self._train_step)

  def on_test_end(self, logs=None):
    if not self._called_in_fit:
      self._finalize_progbar(logs, self._test_step)

  def on_predict_end(self, logs=None):
    self._finalize_progbar(logs, self._predict_step)

  def _reset_progbar(self):
    self.seen = 0
    self.progbar = None

  def _maybe_init_progbar(self):
    if self.stateful_metrics is None:
      if self.model:
        self.stateful_metrics = set(m.name for m in self.model.metrics)
      else:
        self.stateful_metrics = set()

    if self.progbar is None:
      self.progbar = Progbar(
          target=self.target,
          verbose=self.verbose,
          stateful_metrics=self.stateful_metrics,
          unit_name='step' if self.use_steps else 'sample')

  def _implements_train_batch_hooks(self):
    return self._call_batch_hooks

  def _implements_test_batch_hooks(self):
    return self._call_batch_hooks

  def _implements_predict_batch_hooks(self):
    return self._call_batch_hooks

  def _batch_update_progbar(self, batch, logs=None):
    """Updates the progbar."""
    logs = logs or {}
    self._maybe_init_progbar()
    if self.use_steps:
      self.seen = batch + 1  # One-indexed.
    else:
      # v1 path only.
      logs = copy.copy(logs)
      batch_size = logs.pop('size', 0)
      num_steps = logs.pop('num_steps', 1)
      logs.pop('batch', None)
      add_seen = num_steps * batch_size
      self.seen += add_seen

    if self.verbose == 1:
      # Only block async when verbose = 1.
      logs = tf_utils.to_numpy_or_python_type(logs)
      self.progbar.update(self.seen, list(logs.items()), finalize=False)

  def _finalize_progbar(self, logs, counter):
    logs = tf_utils.to_numpy_or_python_type(logs or {})
    if self.target is None:
      if counter is not None:
        counter = counter.numpy()
        if not self.use_steps:
          counter *= logs.get('size', 1)
      self.target = counter or self.seen
      self.progbar.target = self.target
    self.progbar.update(self.target, list(logs.items()), finalize=True)


@keras_export('keras.callbacks.History')
class History(Callback):
  """Callback that records events into a `History` object.

  This callback is automatically applied to
  every Keras model. The `History` object
  gets returned by the `fit` method of models.
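
  Example (a minimal sketch, mirroring the toy data used elsewhere in this
  module's examples):

  ```python
  model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
  model.compile(tf.keras.optimizers.SGD(), loss='mse')
  history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
                      epochs=3, verbose=0)
  print(history.history['loss'])  # One loss value per epoch.
  ```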
  """

  def __init__(self):
    super(History, self).__init__()
    self.history = {}

  def on_train_begin(self, logs=None):
    self.epoch = []

  def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}
    self.epoch.append(epoch)
    for k, v in logs.items():
      self.history.setdefault(k, []).append(v)

    # Set the history attribute on the model after the epoch ends. This will
    # make sure that the state which is set is the latest one.
    self.model.history = self


@keras_export('keras.callbacks.ModelCheckpoint')
class ModelCheckpoint(Callback):
  """Callback to save the Keras model or model weights at some frequency.

  `ModelCheckpoint` callback is used in conjunction with training using
  `model.fit()` to save a model or weights (in a checkpoint file) at some
  interval, so the model or weights can be loaded later to continue the training
  from the state saved.

  A few options this callback provides include:

  - Whether to only keep the model that has achieved the "best performance" so
    far, or whether to save the model at the end of every epoch regardless of
    performance.
  - Definition of 'best'; which quantity to monitor and whether it should be
    maximized or minimized.
  - The frequency it should save at. Currently, the callback supports saving at
    the end of every epoch, or after a fixed number of training batches.
  - Whether only weights are saved, or the whole model is saved.

  Note: If you get `WARNING:tensorflow:Can save best model only with <name>
  available, skipping` see the description of the `monitor` argument for
  details on how to get this right.

  Example:

  ```python
  model.compile(loss=..., optimizer=...,
                metrics=['accuracy'])

  EPOCHS = 10
  checkpoint_filepath = '/tmp/checkpoint'
  model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
      filepath=checkpoint_filepath,
      save_weights_only=True,
      monitor='val_accuracy',
      mode='max',
      save_best_only=True)

  # Model weights are saved at the end of every epoch, if it's the best seen
  # so far.
  model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

  # The model weights (that are considered the best) are loaded into the model.
  model.load_weights(checkpoint_filepath)
  ```

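  If you instead want to checkpoint every N training batches rather than every
  epoch, pass an integer `save_freq`. A minimal sketch (the filepath and the
  frequency below are illustrative):

  ```python
  batch_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
      filepath='/tmp/ckpt/batch_{epoch:02d}',
      save_weights_only=True,
      save_freq=100)  # Save every 100 batches.

  model.fit(epochs=EPOCHS, callbacks=[batch_checkpoint_callback])
  ```
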
  Args:
    filepath: string or `PathLike`, path to save the model file. e.g.
      filepath = os.path.join(working_dir, 'ckpt', file_name). `filepath`
      can contain named formatting options, which will be filled with the value
      of `epoch` and keys in `logs` (passed in `on_epoch_end`). For example: if
      `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`, then the model
      checkpoints will be saved with the epoch number and the validation loss
      in the filename. The directory of the filepath should not be reused by
      any other callbacks to avoid conflicts.
    monitor: The metric name to monitor. Typically the metrics are set by the
      `Model.compile` method. Note:

      * Prefix the name with `"val_"` to monitor validation metrics.
      * Use `"loss"` or `"val_loss"` to monitor the model's total loss.
      * If you specify metrics as strings, like `"accuracy"`, pass the same
        string (with or without the `"val_"` prefix).
      * If you pass `metrics.Metric` objects, `monitor` should be set to
        `metric.name`
      * If you're not sure about the metric names you can check the contents
        of the `history.history` dictionary returned by
        `history = model.fit()`
      * Multi-output models set additional prefixes on the metric names.

    verbose: verbosity mode, 0 or 1.
    save_best_only: if `save_best_only=True`, it only saves when the model
      is considered the "best" and the latest best model according to the
      quantity monitored will not be overwritten. If `filepath` doesn't
      contain formatting options like `{epoch}` then `filepath` will be
      overwritten by each new better model.
    mode: one of {'auto', 'min', 'max'}. If `save_best_only=True`, the
      decision to overwrite the current save file is made based on either
      the maximization or the minimization of the monitored quantity.
      For `val_acc`, this should be `max`, for `val_loss` this should be
      `min`, etc. In `auto` mode, the direction is automatically inferred
      from the name of the monitored quantity.
    save_weights_only: if True, then only the model's weights will be saved
      (`model.save_weights(filepath)`), else the full model is saved
      (`model.save(filepath)`).
    save_freq: `'epoch'` or integer. When using `'epoch'`, the callback saves
      the model after each epoch. When using integer, the callback saves the
      model at end of this many batches. If the `Model` is compiled with
      `steps_per_execution=N`, then the saving criteria will be
      checked every Nth batch. Note that if the saving isn't aligned to
      epochs, the monitored metric may potentially be less reliable (it
      could reflect as little as 1 batch, since the metrics get reset every
      epoch). Defaults to `'epoch'`.
    options: Optional `tf.train.CheckpointOptions` object if
      `save_weights_only` is true or optional `tf.saved_model.SaveOptions`
      object if `save_weights_only` is false.
    **kwargs: Additional arguments for backwards compatibility. Possible key
      is `period`.
  """

  def __init__(self,
               filepath,
               monitor='val_loss',
               verbose=0,
               save_best_only=False,
               save_weights_only=False,
               mode='auto',
               save_freq='epoch',
               options=None,
               **kwargs):
    super(ModelCheckpoint, self).__init__()
    self._supports_tf_logs = True
    self.monitor = monitor
    self.verbose = verbose
    self.filepath = path_to_string(filepath)
    self.save_best_only = save_best_only
    self.save_weights_only = save_weights_only
    self.save_freq = save_freq
    self.epochs_since_last_save = 0
    self._batches_seen_since_last_saving = 0
    self._last_batch_seen = 0

    if save_weights_only:
      if options is None or isinstance(
          options, checkpoint_options_lib.CheckpointOptions):
        self._options = options or checkpoint_options_lib.CheckpointOptions()
      else:
        raise TypeError('If save_weights_only is True, then `options` must be '
                        'either None or a tf.train.CheckpointOptions')
    else:
      if options is None or isinstance(options, save_options_lib.SaveOptions):
        self._options = options or save_options_lib.SaveOptions()
      else:
        raise TypeError('If save_weights_only is False, then `options` must be '
                        'either None or a tf.saved_model.SaveOptions')

    # Deprecated field `load_weights_on_restart` is for loading the checkpoint
    # file from `filepath` at the start of `model.fit()`
    # TODO(rchao): Remove the arg during next breaking release.
    if 'load_weights_on_restart' in kwargs:
      self.load_weights_on_restart = kwargs['load_weights_on_restart']
      logging.warning('`load_weights_on_restart` argument is deprecated. '
                      'Please use `model.load_weights()` for loading weights '
                      'before the start of `model.fit()`.')
    else:
      self.load_weights_on_restart = False

    # Deprecated field `period` is for the number of epochs between which
    # the model is saved.
    if 'period' in kwargs:
      self.period = kwargs['period']
      logging.warning('`period` argument is deprecated. Please use `save_freq` '
                      'to specify the frequency in number of batches seen.')
    else:
      self.period = 1

    if mode not in ['auto', 'min', 'max']:
      logging.warning('ModelCheckpoint mode %s is unknown, '
                      'fallback to auto mode.', mode)
      mode = 'auto'

    if mode == 'min':
      self.monitor_op = np.less
      self.best = np.Inf
    elif mode == 'max':
      self.monitor_op = np.greater
      self.best = -np.Inf
    else:
      if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
        self.monitor_op = np.greater
        self.best = -np.Inf
      else:
        self.monitor_op = np.less
        self.best = np.Inf

    if self.save_freq != 'epoch' and not isinstance(self.save_freq, int):
      raise ValueError('Unrecognized save_freq: {}'.format(self.save_freq))

    # Only the chief worker writes model checkpoints, but all workers
    # restore checkpoint at on_train_begin().
    self._chief_worker_only = False

  def set_model(self, model):
    self.model = model
    # Use name matching rather than `isinstance` to avoid circular dependencies.
    if (not self.save_weights_only and
        not model._is_graph_network and  # pylint: disable=protected-access
        model.__class__.__name__ != 'Sequential'):
      self.save_weights_only = True

  def on_train_begin(self, logs=None):
    if self.load_weights_on_restart:
      filepath_to_load = (
          self._get_most_recently_modified_file_matching_pattern(self.filepath))
      if (filepath_to_load is not None and
          self._checkpoint_exists(filepath_to_load)):
        try:
          # `filepath` may contain placeholders such as `{epoch:02d}`, and
          # thus it attempts to load the most recently modified file with file
          # name matching the pattern.
          self.model.load_weights(filepath_to_load)
        except (IOError, ValueError) as e:
          raise ValueError('Error loading file from {}. Reason: {}'.format(
              filepath_to_load, e))

  def _implements_train_batch_hooks(self):
    # Only call batch hooks when saving on batch
    return self.save_freq != 'epoch'

  def on_train_batch_end(self, batch, logs=None):
    if self._should_save_on_batch(batch):
      self._save_model(epoch=self._current_epoch, logs=logs)

  def on_epoch_begin(self, epoch, logs=None):
    self._current_epoch = epoch

  def on_epoch_end(self, epoch, logs=None):
    self.epochs_since_last_save += 1
    # pylint: disable=protected-access
    if self.save_freq == 'epoch':
      self._save_model(epoch=epoch, logs=logs)

  def _should_save_on_batch(self, batch):
    """Handles batch-level saving logic, supports steps_per_execution."""
    if self.save_freq == 'epoch':
      return False

    if batch <= self._last_batch_seen:  # New epoch.
      add_batches = batch + 1  # batches are zero-indexed.
    else:
      add_batches = batch - self._last_batch_seen
    self._batches_seen_since_last_saving += add_batches
    self._last_batch_seen = batch

    if self._batches_seen_since_last_saving >= self.save_freq:
      self._batches_seen_since_last_saving = 0
      return True
    return False

  def _save_model(self, epoch, logs):
    """Saves the model.

    Args:
      epoch: the epoch this iteration is in.
      logs: the `logs` dict passed in to `on_batch_end` or `on_epoch_end`.
    """
    logs = logs or {}

    if isinstance(self.save_freq,
                  int) or self.epochs_since_last_save >= self.period:
      # Block only when saving interval is reached.
      logs = tf_utils.to_numpy_or_python_type(logs)
      self.epochs_since_last_save = 0
      filepath = self._get_file_path(epoch, logs)

      try:
        if self.save_best_only:
          current = logs.get(self.monitor)
          if current is None:
            logging.warning('Can save best model only with %s available, '
                            'skipping.', self.monitor)
          else:
            if self.monitor_op(current, self.best):
              if self.verbose > 0:
                print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                      ' saving model to %s' % (epoch + 1, self.monitor,
                                               self.best, current, filepath))
              self.best = current
              if self.save_weights_only:
                self.model.save_weights(
                    filepath, overwrite=True, options=self._options)
              else:
                self.model.save(filepath, overwrite=True, options=self._options)
            else:
              if self.verbose > 0:
                print('\nEpoch %05d: %s did not improve from %0.5f' %
                      (epoch + 1, self.monitor, self.best))
        else:
          if self.verbose > 0:
            print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
          if self.save_weights_only:
            self.model.save_weights(
                filepath, overwrite=True, options=self._options)
          else:
            self.model.save(filepath, overwrite=True, options=self._options)

        self._maybe_remove_file()
      except IOError as e:
        # `e.errno` appears to be `None` so checking the content of `e.args[0]`.
        if 'is a directory' in six.ensure_str(e.args[0]).lower():
          raise IOError('Please specify a non-directory filepath for '
                        'ModelCheckpoint. Filepath used is an existing '
                        'directory: {}'.format(filepath))
        # Re-throw the error for any other causes.
        raise e

  def _get_file_path(self, epoch, logs):
    """Returns the file path for checkpoint."""
    # pylint: disable=protected-access
    try:
      # `filepath` may contain placeholders such as `{epoch:02d}` and
      # `{mape:.2f}`. A mismatch between logged metrics and the path's
      # placeholders can cause formatting to fail.
      file_path = self.filepath.format(epoch=epoch + 1, **logs)
    except KeyError as e:
      raise KeyError('Failed to format this callback filepath: "{}". '
                     'Reason: {}'.format(self.filepath, e))
    self._write_filepath = distributed_file_utils.write_filepath(
        file_path, self.model.distribute_strategy)
    return self._write_filepath

  def _maybe_remove_file(self):
    # Remove the checkpoint directory in multi-worker training where this worker
    # should not checkpoint. It is a dummy directory previously saved for sync
    # distributed training.
    distributed_file_utils.remove_temp_dir_with_filepath(
        self._write_filepath, self.model.distribute_strategy)

  def _checkpoint_exists(self, filepath):
    """Returns whether the checkpoint that `filepath` refers to exists."""
    if filepath.endswith('.h5'):
      return file_io.file_exists_v2(filepath)
    tf_saved_model_exists = file_io.file_exists_v2(filepath)
    tf_weights_only_checkpoint_exists = file_io.file_exists_v2(
        filepath + '.index')
    return tf_saved_model_exists or tf_weights_only_checkpoint_exists

  def _get_most_recently_modified_file_matching_pattern(self, pattern):
    """Returns the most recently modified filepath matching pattern.

    Pattern may contain python formatting placeholder. If
    `tf.train.latest_checkpoint()` does not return None, use that; otherwise,
    check for most recently modified one that matches the pattern.

    In the rare case where there are more than one pattern-matching file having
    the same modified time that is most recent among all, return the filepath
    that is largest (by `>` operator, lexicographically using the numeric
    equivalents). This provides a tie-breaker when multiple files are most
    recent. Note that a larger `filepath` can sometimes indicate a later time of
    modification (for instance, when epoch/batch is used as formatting option),
    but not necessarily (when accuracy or loss is used). The tie-breaker is
    put in the logic as best effort to return the most recent, and to avoid
    a nondeterministic result.

    Modified time of a file is obtained with `os.path.getmtime()`.

    This utility function is best demonstrated via an example:

    ```python
    file_pattern = 'f.batch{batch:02d}epoch{epoch:02d}.h5'
    test_dir = self.get_temp_dir()
    path_pattern = os.path.join(test_dir, file_pattern)
    file_paths = [
        os.path.join(test_dir, file_name) for file_name in
        ['f.batch03epoch02.h5', 'f.batch02epoch02.h5', 'f.batch01epoch01.h5']
    ]
    for file_path in file_paths:
      # Write something to each of the files
    self.assertEqual(
        _get_most_recently_modified_file_matching_pattern(path_pattern),
        file_paths[-1])
    ```

    Args:
      pattern: The file pattern that may optionally contain python placeholder
        such as `{epoch:02d}`.

    Returns:
      The most recently modified file's full filepath matching `pattern`. If
      `pattern` does not contain any placeholder, this returns the filepath
      that exactly matches `pattern`. Returns `None` if no match is found.
    """
    dir_name = os.path.dirname(pattern)
    base_name = os.path.basename(pattern)
    base_name_regex = '^' + re.sub(r'{.*}', r'.*', base_name) + '$'

    # If tf.train.latest_checkpoint tells us there exists a latest checkpoint,
    # use that as it is more robust than `os.path.getmtime()`.
    latest_tf_checkpoint = checkpoint_management.latest_checkpoint(dir_name)
    if latest_tf_checkpoint is not None and re.match(
        base_name_regex, os.path.basename(latest_tf_checkpoint)):
      return latest_tf_checkpoint

    latest_mod_time = 0
    file_path_with_latest_mod_time = None
    n_file_with_latest_mod_time = 0
    file_path_with_largest_file_name = None

    if file_io.file_exists_v2(dir_name):
      for file_name in os.listdir(dir_name):
        # Only consider if `file_name` matches the pattern.
        if re.match(base_name_regex, file_name):
          file_path = os.path.join(dir_name, file_name)
          mod_time = os.path.getmtime(file_path)
          if (file_path_with_largest_file_name is None or
              file_path > file_path_with_largest_file_name):
            file_path_with_largest_file_name = file_path
          if mod_time > latest_mod_time:
            latest_mod_time = mod_time
            file_path_with_latest_mod_time = file_path
            # In the case a file with later modified time is found, reset
            # the counter for the number of files with latest modified time.
            n_file_with_latest_mod_time = 1
          elif mod_time == latest_mod_time:
            # In the case a file has modified time tied with the most recent,
            # increment the counter for the number of files with latest modified
            # time by 1.
            n_file_with_latest_mod_time += 1

    if n_file_with_latest_mod_time == 1:
      # Return the sole file that has most recent modified time.
      return file_path_with_latest_mod_time
    else:
      # If there are more than one file having latest modified time, return
      # the file path with the largest file name.
      return file_path_with_largest_file_name


@keras_export('keras.callbacks.experimental.BackupAndRestore', v1=[])
class BackupAndRestore(Callback):
  """Callback to back up and restore the training state.

  `BackupAndRestore` callback is intended to recover from interruptions that
  happened in the middle of a model.fit execution by backing up the
  training states in a temporary checkpoint file (based on TF CheckpointManager)
  at the end of each epoch. If training is restarted before completion, the
  training state and model are restored to the most recently saved state at the
  beginning of a new model.fit() run.
  Note that the user is responsible for bringing jobs back up.
  This callback is important for the backup and restore mechanism for fault
  tolerance purposes, and the model to be restored from a previous checkpoint is
  expected to be the same as the one used to back up. If the user changes
  arguments passed to compile or fit, the checkpoint saved for fault tolerance
  can become invalid.

  Note:
  1. This callback is not compatible with disabling eager execution.
  2. A checkpoint is saved at the end of each epoch, when restoring we'll redo
  any partial work from an unfinished epoch in which the training got restarted
  (so the work done before an interruption doesn't affect the final model state).
  3. This works for both single worker and multi-worker mode, only
  MirroredStrategy and MultiWorkerMirroredStrategy are supported for now.

  Example:

  >>> class InterruptingCallback(tf.keras.callbacks.Callback):
  ...   def on_epoch_begin(self, epoch, logs=None):
  ...     if epoch == 4:
  ...       raise RuntimeError('Interrupting!')
  >>> callback = tf.keras.callbacks.experimental.BackupAndRestore(
  ... backup_dir="/tmp/backup")
  >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
  >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
  >>> try:
  ...   model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10,
  ...             batch_size=1, callbacks=[callback, InterruptingCallback()],
  ...             verbose=0)
  ... except:
  ...   pass
  >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10,
  ...             batch_size=1, callbacks=[callback], verbose=0)
  >>> # Only 6 more epochs are run, since first training got interrupted at
  >>> # zero-indexed epoch 4, second training will continue from 4 to 9.
  >>> len(history.history['loss'])
  6

  Args:
    backup_dir: String, path to store the checkpoint.
      e.g. backup_dir = os.path.join(working_dir, 'backup')
      This is the directory in which the system stores temporary files to
      recover the model from jobs terminated unexpectedly. The directory
      cannot be reused elsewhere to store other files, e.g. by
      BackupAndRestore callback of another training, or by another callback
      (ModelCheckpoint) of the same training.
1605 """ 1606 1607 def __init__(self, backup_dir): 1608 super(BackupAndRestore, self).__init__() 1609 self.backup_dir = backup_dir 1610 self._supports_tf_logs = True 1611 self._supported_strategies = ( 1612 mirrored_strategy.MirroredStrategy, 1613 collective_all_reduce_strategy.CollectiveAllReduceStrategy, 1614 tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV2) 1615 1616 if not context.executing_eagerly(): 1617 if ops.inside_function(): 1618 raise ValueError('This Callback\'s method contains Python state and ' 1619 'should be called outside of `tf.function`s.') 1620 else: # Legacy graph mode: 1621 raise ValueError( 1622 'BackupAndRestore only supports eager mode. In graph ' 1623 'mode, consider using ModelCheckpoint to manually save ' 1624 'and restore weights with `model.load_weights()` and by ' 1625 'providing `initial_epoch` in `model.fit()` for fault tolerance.') 1626 1627 # Only the chief worker writes model checkpoints, but all workers 1628 # restore checkpoint at on_train_begin(). 1629 self._chief_worker_only = False 1630 1631 def on_train_begin(self, logs=None): 1632 # TrainingState is used to manage the training state needed for 1633 # failure-recovery of a worker in training. 1634 # pylint: disable=protected-access 1635 1636 if self.model._distribution_strategy and not isinstance( 1637 self.model.distribute_strategy, self._supported_strategies): 1638 raise NotImplementedError( 1639 '%s is not supported yet. ' 1640 'Currently BackupAndRestore callback only supports empty strategy, ' 1641 'MirroredStrategy, MultiWorkerMirroredStrategy and TPUStrategy.' % 1642 type(self.model.distribute_strategy).__name__) 1643 self.model._training_state = ( 1644 worker_training_state.WorkerTrainingState(self.model, self.backup_dir)) 1645 self._training_state = self.model._training_state 1646 self._training_state.restore() 1647 1648 def on_train_end(self, logs=None): 1649 # pylint: disable=protected-access 1650 # On exit of training, delete the training state backup file that was saved 1651 # for the purpose of worker recovery. 1652 self._training_state.delete_backup() 1653 1654 # Clean up the training state. 1655 del self._training_state 1656 del self.model._training_state 1657 1658 def on_epoch_end(self, epoch, logs=None): 1659 # Back up the model and current epoch for possible future recovery. 1660 self._training_state.back_up(epoch) 1661 1662 1663@keras_export('keras.callbacks.EarlyStopping') 1664class EarlyStopping(Callback): 1665 """Stop training when a monitored metric has stopped improving. 1666 1667 Assuming the goal of a training is to minimize the loss. With this, the 1668 metric to be monitored would be `'loss'`, and mode would be `'min'`. A 1669 `model.fit()` training loop will check at end of every epoch whether 1670 the loss is no longer decreasing, considering the `min_delta` and 1671 `patience` if applicable. Once it's found no longer decreasing, 1672 `model.stop_training` is marked True and the training terminates. 1673 1674 The quantity to be monitored needs to be available in `logs` dict. 1675 To make it so, pass the loss or metrics at `model.compile()`. 1676 1677 Args: 1678 monitor: Quantity to be monitored. 1679 min_delta: Minimum change in the monitored quantity 1680 to qualify as an improvement, i.e. an absolute 1681 change of less than min_delta, will count as no 1682 improvement. 1683 patience: Number of epochs with no improvement 1684 after which training will be stopped. 1685 verbose: verbosity mode. 1686 mode: One of `{"auto", "min", "max"}`. 
In `min` mode, 1687 training will stop when the quantity 1688 monitored has stopped decreasing; in `"max"` 1689 mode it will stop when the quantity 1690 monitored has stopped increasing; in `"auto"` 1691 mode, the direction is automatically inferred 1692 from the name of the monitored quantity. 1693 baseline: Baseline value for the monitored quantity. 1694 Training will stop if the model doesn't show improvement over the 1695 baseline. 1696 restore_best_weights: Whether to restore model weights from 1697 the epoch with the best value of the monitored quantity. 1698 If False, the model weights obtained at the last step of 1699 training are used. An epoch will be restored regardless 1700 of the performance relative to the `baseline`. If no epoch 1701 improves on `baseline`, training will run for `patience` 1702 epochs and restore weights from the best epoch in that set. 1703 1704 Example: 1705 1706 >>> callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3) 1707 >>> # This callback will stop the training when there is no improvement in 1708 >>> # the loss for three consecutive epochs. 1709 >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)]) 1710 >>> model.compile(tf.keras.optimizers.SGD(), loss='mse') 1711 >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5), 1712 ... epochs=10, batch_size=1, callbacks=[callback], 1713 ... verbose=0) 1714 >>> len(history.history['loss']) # Only 4 epochs are run. 1715 4 1716 """ 1717 1718 def __init__(self, 1719 monitor='val_loss', 1720 min_delta=0, 1721 patience=0, 1722 verbose=0, 1723 mode='auto', 1724 baseline=None, 1725 restore_best_weights=False): 1726 super(EarlyStopping, self).__init__() 1727 1728 self.monitor = monitor 1729 self.patience = patience 1730 self.verbose = verbose 1731 self.baseline = baseline 1732 self.min_delta = abs(min_delta) 1733 self.wait = 0 1734 self.stopped_epoch = 0 1735 self.restore_best_weights = restore_best_weights 1736 self.best_weights = None 1737 1738 if mode not in ['auto', 'min', 'max']: 1739 logging.warning('EarlyStopping mode %s is unknown, ' 1740 'fallback to auto mode.', mode) 1741 mode = 'auto' 1742 1743 if mode == 'min': 1744 self.monitor_op = np.less 1745 elif mode == 'max': 1746 self.monitor_op = np.greater 1747 else: 1748 if 'acc' in self.monitor: 1749 self.monitor_op = np.greater 1750 else: 1751 self.monitor_op = np.less 1752 1753 if self.monitor_op == np.greater: 1754 self.min_delta *= 1 1755 else: 1756 self.min_delta *= -1 1757 1758 def on_train_begin(self, logs=None): 1759 # Allow instances to be re-used 1760 self.wait = 0 1761 self.stopped_epoch = 0 1762 self.best = np.Inf if self.monitor_op == np.less else -np.Inf 1763 self.best_weights = None 1764 1765 def on_epoch_end(self, epoch, logs=None): 1766 current = self.get_monitor_value(logs) 1767 if current is None: 1768 return 1769 if self.restore_best_weights and self.best_weights is None: 1770 # Restore the weights after first epoch if no progress is ever made. 1771 self.best_weights = self.model.get_weights() 1772 1773 self.wait += 1 1774 if self._is_improvement(current, self.best): 1775 self.best = current 1776 if self.restore_best_weights: 1777 self.best_weights = self.model.get_weights() 1778 # Only restart wait if we beat both the baseline and our previous best. 
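      # (If a `baseline` is set and is never beaten, `wait` keeps growing, so
      # training still stops after `patience` epochs while `best_weights`
      # keeps tracking the best epoch seen so far.)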
1779 if self.baseline is None or self._is_improvement(current, self.baseline): 1780 self.wait = 0 1781 1782 if self.wait >= self.patience: 1783 self.stopped_epoch = epoch 1784 self.model.stop_training = True 1785 if self.restore_best_weights and self.best_weights is not None: 1786 if self.verbose > 0: 1787 print('Restoring model weights from the end of the best epoch.') 1788 self.model.set_weights(self.best_weights) 1789 1790 def on_train_end(self, logs=None): 1791 if self.stopped_epoch > 0 and self.verbose > 0: 1792 print('Epoch %05d: early stopping' % (self.stopped_epoch + 1)) 1793 1794 def get_monitor_value(self, logs): 1795 logs = logs or {} 1796 monitor_value = logs.get(self.monitor) 1797 if monitor_value is None: 1798 logging.warning('Early stopping conditioned on metric `%s` ' 1799 'which is not available. Available metrics are: %s', 1800 self.monitor, ','.join(list(logs.keys()))) 1801 return monitor_value 1802 1803 def _is_improvement(self, monitor_value, reference_value): 1804 return self.monitor_op(monitor_value - self.min_delta, reference_value) 1805 1806 1807@keras_export('keras.callbacks.RemoteMonitor') 1808class RemoteMonitor(Callback): 1809 """Callback used to stream events to a server. 1810 1811 Requires the `requests` library. 1812 Events are sent to `root + '/publish/epoch/end/'` by default. Calls are 1813 HTTP POST, with a `data` argument which is a 1814 JSON-encoded dictionary of event data. 1815 If `send_as_json=True`, the content type of the request will be 1816 `"application/json"`. 1817 Otherwise the serialized JSON will be sent within a form. 1818 1819 Args: 1820 root: String; root url of the target server. 1821 path: String; path relative to `root` to which the events will be sent. 1822 field: String; JSON field under which the data will be stored. 1823 The field is used only if the payload is sent within a form 1824 (i.e. send_as_json is set to False). 1825 headers: Dictionary; optional custom HTTP headers. 1826 send_as_json: Boolean; whether the request should be 1827 sent as `"application/json"`. 1828 """ 1829 1830 def __init__(self, 1831 root='http://localhost:9000', 1832 path='/publish/epoch/end/', 1833 field='data', 1834 headers=None, 1835 send_as_json=False): 1836 super(RemoteMonitor, self).__init__() 1837 1838 self.root = root 1839 self.path = path 1840 self.field = field 1841 self.headers = headers 1842 self.send_as_json = send_as_json 1843 1844 def on_epoch_end(self, epoch, logs=None): 1845 if requests is None: 1846 raise ImportError('RemoteMonitor requires the `requests` library.') 1847 logs = logs or {} 1848 send = {} 1849 send['epoch'] = epoch 1850 for k, v in logs.items(): 1851 # np.ndarray and np.generic are not scalar types 1852 # therefore we must unwrap their scalar values and 1853 # pass to the json-serializable dict 'send' 1854 if isinstance(v, (np.ndarray, np.generic)): 1855 send[k] = v.item() 1856 else: 1857 send[k] = v 1858 try: 1859 if self.send_as_json: 1860 requests.post(self.root + self.path, json=send, headers=self.headers) 1861 else: 1862 requests.post( 1863 self.root + self.path, {self.field: json.dumps(send)}, 1864 headers=self.headers) 1865 except requests.exceptions.RequestException: 1866 logging.warning('Warning: could not reach RemoteMonitor ' 1867 'root server at ' + str(self.root)) 1868 1869 1870@keras_export('keras.callbacks.LearningRateScheduler') 1871class LearningRateScheduler(Callback): 1872 """Learning rate scheduler. 
1873 1874 At the beginning of every epoch, this callback gets the updated learning rate 1875 value from `schedule` function provided at `__init__`, with the current epoch 1876 and current learning rate, and applies the updated learning rate 1877 on the optimizer. 1878 1879 Args: 1880 schedule: a function that takes an epoch index (integer, indexed from 0) 1881 and current learning rate (float) as inputs and returns a new 1882 learning rate as output (float). 1883 verbose: int. 0: quiet, 1: update messages. 1884 1885 Example: 1886 1887 >>> # This function keeps the initial learning rate for the first ten epochs 1888 >>> # and decreases it exponentially after that. 1889 >>> def scheduler(epoch, lr): 1890 ... if epoch < 10: 1891 ... return lr 1892 ... else: 1893 ... return lr * tf.math.exp(-0.1) 1894 >>> 1895 >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)]) 1896 >>> model.compile(tf.keras.optimizers.SGD(), loss='mse') 1897 >>> round(model.optimizer.lr.numpy(), 5) 1898 0.01 1899 1900 >>> callback = tf.keras.callbacks.LearningRateScheduler(scheduler) 1901 >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5), 1902 ... epochs=15, callbacks=[callback], verbose=0) 1903 >>> round(model.optimizer.lr.numpy(), 5) 1904 0.00607 1905 1906 """ 1907 1908 def __init__(self, schedule, verbose=0): 1909 super(LearningRateScheduler, self).__init__() 1910 self.schedule = schedule 1911 self.verbose = verbose 1912 1913 def on_epoch_begin(self, epoch, logs=None): 1914 if not hasattr(self.model.optimizer, 'lr'): 1915 raise ValueError('Optimizer must have a "lr" attribute.') 1916 try: # new API 1917 lr = float(K.get_value(self.model.optimizer.lr)) 1918 lr = self.schedule(epoch, lr) 1919 except TypeError: # Support for old API for backward compatibility 1920 lr = self.schedule(epoch) 1921 if not isinstance(lr, (ops.Tensor, float, np.float32, np.float64)): 1922 raise ValueError('The output of the "schedule" function ' 1923 'should be float.') 1924 if isinstance(lr, ops.Tensor) and not lr.dtype.is_floating: 1925 raise ValueError('The dtype of Tensor should be float') 1926 K.set_value(self.model.optimizer.lr, K.get_value(lr)) 1927 if self.verbose > 0: 1928 print('\nEpoch %05d: LearningRateScheduler reducing learning ' 1929 'rate to %s.' % (epoch + 1, lr)) 1930 1931 def on_epoch_end(self, epoch, logs=None): 1932 logs = logs or {} 1933 logs['lr'] = K.get_value(self.model.optimizer.lr) 1934 1935 1936def keras_model_summary(name, data, step=None): 1937 """Writes a Keras model as JSON to as a Summary. 1938 1939 Writing the Keras model configuration allows the TensorBoard graph plugin to 1940 render a conceptual graph, as opposed to graph of ops. In case the model fails 1941 to serialize as JSON, it ignores and returns False. 1942 1943 Args: 1944 name: A name for this summary. The summary tag used for TensorBoard will be 1945 this name prefixed by any active name scopes. 1946 data: A Keras Model to write. 1947 step: Explicit `int64`-castable monotonic step value for this summary. If 1948 omitted, this defaults to `tf.summary.experimental.get_step()`, which must 1949 not be None. 1950 1951 Returns: 1952 True on success, or False if no summary was written because no default 1953 summary writer was available. 1954 1955 Raises: 1956 ValueError: if a default writer exists, but no step was provided and 1957 `tf.summary.experimental.get_step()` is None. 1958 """ 1959 summary_metadata = summary_pb2.SummaryMetadata() 1960 # Hard coding a plugin name. 
Please refer to go/tb-plugin-name-hardcode for 1961 # the rationale. 1962 summary_metadata.plugin_data.plugin_name = 'graph_keras_model' 1963 # version number = 1 1964 summary_metadata.plugin_data.content = b'1' 1965 1966 try: 1967 json_string = data.to_json() 1968 except Exception as exc: # pylint: disable=broad-except 1969 # An exception should not break a model code. 1970 logging.warn('Model failed to serialize as JSON. Ignoring... %s', exc) 1971 return False 1972 1973 with summary_ops_v2.summary_scope(name, 'graph_keras_model', 1974 [data, step]) as (tag, _): 1975 with ops.device('cpu:0'): 1976 tensor = constant_op.constant(json_string, dtype=dtypes.string) 1977 return summary_ops_v2.write( 1978 tag=tag, tensor=tensor, step=step, metadata=summary_metadata) 1979 1980 1981@keras_export('keras.callbacks.TensorBoard', v1=[]) 1982class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): 1983 # pylint: disable=line-too-long 1984 """Enable visualizations for TensorBoard. 1985 1986 TensorBoard is a visualization tool provided with TensorFlow. 1987 1988 This callback logs events for TensorBoard, including: 1989 1990 * Metrics summary plots 1991 * Training graph visualization 1992 * Activation histograms 1993 * Sampled profiling 1994 1995 If you have installed TensorFlow with pip, you should be able 1996 to launch TensorBoard from the command line: 1997 1998 ``` 1999 tensorboard --logdir=path_to_your_logs 2000 ``` 2001 2002 You can find more information about TensorBoard 2003 [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard). 2004 2005 Args: 2006 log_dir: the path of the directory where to save the log files to be 2007 parsed by TensorBoard. e.g. log_dir = os.path.join(working_dir, 'logs') 2008 This directory should not be reused by any other callbacks. 2009 histogram_freq: frequency (in epochs) at which to compute activation and 2010 weight histograms for the layers of the model. If set to 0, histograms 2011 won't be computed. Validation data (or split) must be specified for 2012 histogram visualizations. 2013 write_graph: whether to visualize the graph in TensorBoard. The log file 2014 can become quite large when write_graph is set to True. 2015 write_images: whether to write model weights to visualize as image in 2016 TensorBoard. 2017 write_steps_per_second: whether to log the training steps per second into 2018 Tensorboard. This supports both epoch and batch frequency logging. 2019 update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`, 2020 writes the losses and metrics to TensorBoard after each batch. The same 2021 applies for `'epoch'`. If using an integer, let's say `1000`, the 2022 callback will write the metrics and losses to TensorBoard every 1000 2023 batches. Note that writing too frequently to TensorBoard can slow down 2024 your training. 2025 profile_batch: Profile the batch(es) to sample compute characteristics. 2026 profile_batch must be a non-negative integer or a tuple of integers. 2027 A pair of positive integers signify a range of batches to profile. 2028 By default, it will profile the second batch. Set profile_batch=0 2029 to disable profiling. 2030 embeddings_freq: frequency (in epochs) at which embedding layers will be 2031 visualized. If set to 0, embeddings won't be visualized. 2032 embeddings_metadata: a dictionary which maps layer name to a file name in 2033 which metadata for this embedding layer is saved. 
See the 2034 [details]( 2035 https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional) 2036 about metadata files format. In case if the same metadata file is 2037 used for all embedding layers, string can be passed. 2038 2039 Examples: 2040 2041 Basic usage: 2042 2043 ```python 2044 tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs") 2045 model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) 2046 # Then run the tensorboard command to view the visualizations. 2047 ``` 2048 2049 Custom batch-level summaries in a subclassed Model: 2050 2051 ```python 2052 class MyModel(tf.keras.Model): 2053 2054 def build(self, _): 2055 self.dense = tf.keras.layers.Dense(10) 2056 2057 def call(self, x): 2058 outputs = self.dense(x) 2059 tf.summary.histogram('outputs', outputs) 2060 return outputs 2061 2062 model = MyModel() 2063 model.compile('sgd', 'mse') 2064 2065 # Make sure to set `update_freq=N` to log a batch-level summary every N batches. 2066 # In addition to any `tf.summary` contained in `Model.call`, metrics added in 2067 # `Model.compile` will be logged every N batches. 2068 tb_callback = tf.keras.callbacks.TensorBoard('./logs', update_freq=1) 2069 model.fit(x_train, y_train, callbacks=[tb_callback]) 2070 ``` 2071 2072 Custom batch-level summaries in a Functional API Model: 2073 2074 ```python 2075 def my_summary(x): 2076 tf.summary.histogram('x', x) 2077 return x 2078 2079 inputs = tf.keras.Input(10) 2080 x = tf.keras.layers.Dense(10)(inputs) 2081 outputs = tf.keras.layers.Lambda(my_summary)(x) 2082 model = tf.keras.Model(inputs, outputs) 2083 model.compile('sgd', 'mse') 2084 2085 # Make sure to set `update_freq=N` to log a batch-level summary every N batches. 2086 # In addition to any `tf.summary` contained in `Model.call`, metrics added in 2087 # `Model.compile` will be logged every N batches. 2088 tb_callback = tf.keras.callbacks.TensorBoard('./logs', update_freq=1) 2089 model.fit(x_train, y_train, callbacks=[tb_callback]) 2090 ``` 2091 2092 Profiling: 2093 2094 ```python 2095 # Profile a single batch, e.g. the 5th batch. 2096 tensorboard_callback = tf.keras.callbacks.TensorBoard( 2097 log_dir='./logs', profile_batch=5) 2098 model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) 2099 2100 # Profile a range of batches, e.g. from 10 to 20. 
2101 tensorboard_callback = tf.keras.callbacks.TensorBoard( 2102 log_dir='./logs', profile_batch=(10,20)) 2103 model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) 2104 ``` 2105 """ 2106 2107 # pylint: enable=line-too-long 2108 2109 def __init__(self, 2110 log_dir='logs', 2111 histogram_freq=0, 2112 write_graph=True, 2113 write_images=False, 2114 write_steps_per_second=False, 2115 update_freq='epoch', 2116 profile_batch=2, 2117 embeddings_freq=0, 2118 embeddings_metadata=None, 2119 **kwargs): 2120 super(TensorBoard, self).__init__() 2121 self._supports_tf_logs = True 2122 self._validate_kwargs(kwargs) 2123 2124 self.log_dir = path_to_string(log_dir) 2125 self.histogram_freq = histogram_freq 2126 self.write_graph = write_graph 2127 self.write_images = write_images 2128 self.write_steps_per_second = write_steps_per_second 2129 self.update_freq = 1 if update_freq == 'batch' else update_freq 2130 self.embeddings_freq = embeddings_freq 2131 self.embeddings_metadata = embeddings_metadata 2132 self._init_profile_batch(profile_batch) 2133 self._epoch = 0 2134 self._global_train_batch = 0 2135 self._previous_epoch_iterations = 0 2136 self._train_accumulated_time = 0 2137 self._batch_start_time = 0 2138 2139 # Lazily initialized in order to avoid creating event files when 2140 # not needed. 2141 self._writers = {} 2142 2143 # Used to restore any existing `SummaryWriter` after training ends. 2144 self._prev_summary_state = [] 2145 2146 def _validate_kwargs(self, kwargs): 2147 """Handle arguments were supported in V1.""" 2148 if kwargs.get('write_grads', False): 2149 logging.warning('`write_grads` will be ignored in TensorFlow 2.0 ' 2150 'for the `TensorBoard` Callback.') 2151 if kwargs.get('batch_size', False): 2152 logging.warning('`batch_size` is no longer needed in the ' 2153 '`TensorBoard` Callback and will be ignored ' 2154 'in TensorFlow 2.0.') 2155 if kwargs.get('embeddings_layer_names', False): 2156 logging.warning('`embeddings_layer_names` is not supported in ' 2157 'TensorFlow 2.0. Instead, all `Embedding` layers ' 2158 'will be visualized.') 2159 if kwargs.get('embeddings_data', False): 2160 logging.warning('`embeddings_data` is not supported in TensorFlow ' 2161 '2.0. Instead, all `Embedding` variables will be ' 2162 'visualized.') 2163 2164 unrecognized_kwargs = set(kwargs.keys()) - { 2165 'write_grads', 'embeddings_layer_names', 'embeddings_data', 'batch_size' 2166 } 2167 2168 # Only allow kwargs that were supported in V1. 2169 if unrecognized_kwargs: 2170 raise ValueError('Unrecognized arguments in `TensorBoard` ' 2171 'Callback: ' + str(unrecognized_kwargs)) 2172 2173 def set_model(self, model): 2174 """Sets Keras model and writes graph if specified.""" 2175 self.model = model 2176 self._log_write_dir = self._get_log_write_dir() 2177 2178 self._train_dir = os.path.join(self._log_write_dir, 'train') 2179 self._train_step = self.model._train_counter # pylint: disable=protected-access 2180 2181 self._val_dir = os.path.join(self._log_write_dir, 'validation') 2182 self._val_step = self.model._test_counter # pylint: disable=protected-access 2183 2184 self._writers = {} # Resets writers. 
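    # The train-function graph itself is written lazily, from
    # `on_train_batch_end`, once `Model.train_function` has been traced
    # (see `_write_keras_model_train_graph`).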
2185 2186 self._should_write_train_graph = False 2187 if self.write_graph: 2188 self._write_keras_model_summary() 2189 self._should_write_train_graph = True 2190 if self.embeddings_freq: 2191 self._configure_embeddings() 2192 2193 @property 2194 def _train_writer(self): 2195 if 'train' not in self._writers: 2196 self._writers['train'] = summary_ops_v2.create_file_writer_v2( 2197 self._train_dir) 2198 return self._writers['train'] 2199 2200 @property 2201 def _val_writer(self): 2202 if 'val' not in self._writers: 2203 self._writers['val'] = summary_ops_v2.create_file_writer_v2(self._val_dir) 2204 return self._writers['val'] 2205 2206 def _get_log_write_dir(self): 2207 """For multi-worker, only chief should write, others write to '/tmp'.""" 2208 return distributed_file_utils.write_dirpath(self.log_dir, 2209 self.model.distribute_strategy) 2210 2211 def _delete_tmp_write_dir(self): 2212 """Deletes tmp write directories for multi-worker.""" 2213 distributed_file_utils.remove_temp_dirpath(self.log_dir, 2214 self.model.distribute_strategy) 2215 2216 def _write_keras_model_train_graph(self): 2217 """Writes Keras model train_function graph to TensorBoard.""" 2218 with self._train_writer.as_default(): 2219 with summary_ops_v2.record_if(True): 2220 train_fn = self.model.train_function 2221 # If the train_function is a `tf.function`, we can write out a graph 2222 if hasattr(train_fn, 'function_spec'): 2223 summary_ops_v2.graph(train_fn._concrete_stateful_fn.graph) # pylint: disable=protected-access 2224 2225 def _write_keras_model_summary(self): 2226 """Writes Keras graph network summary to TensorBoard.""" 2227 with self._train_writer.as_default(): 2228 with summary_ops_v2.record_if(True): 2229 summary_writable = ( 2230 self.model._is_graph_network or # pylint: disable=protected-access 2231 self.model.__class__.__name__ == 'Sequential') # pylint: disable=protected-access 2232 if summary_writable: 2233 keras_model_summary('keras', self.model, step=0) 2234 2235 def _configure_embeddings(self): 2236 """Configure the Projector for embeddings.""" 2237 # TODO(omalleyt): Add integration tests. 2238 from google.protobuf import text_format 2239 from tensorflow.python.keras.layers import embeddings 2240 from tensorflow.python.keras.protobuf import projector_config_pb2 2241 2242 config = projector_config_pb2.ProjectorConfig() 2243 for layer in self.model.layers: 2244 if isinstance(layer, embeddings.Embedding): 2245 embedding = config.embeddings.add() 2246 # Embeddings are always the first layer, so this naming should be 2247 # consistent in any keras models checkpoints. 
2248 name = 'layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE' 2249 embedding.tensor_name = name 2250 2251 if self.embeddings_metadata is not None: 2252 if isinstance(self.embeddings_metadata, str): 2253 embedding.metadata_path = self.embeddings_metadata 2254 else: 2255 if layer.name in self.embeddings_metadata.keys(): 2256 embedding.metadata_path = self.embeddings_metadata.pop(layer.name) 2257 2258 if self.embeddings_metadata and not isinstance(self.embeddings_metadata, 2259 str): 2260 raise ValueError('Unrecognized `Embedding` layer names passed to ' 2261 '`keras.callbacks.TensorBoard` `embeddings_metadata` ' 2262 'argument: ' + str(self.embeddings_metadata.keys())) 2263 2264 config_pbtxt = text_format.MessageToString(config) 2265 path = os.path.join(self._log_write_dir, 'projector_config.pbtxt') 2266 with gfile.Open(path, 'w') as f: 2267 f.write(config_pbtxt) 2268 2269 def _push_writer(self, writer, step): 2270 """Sets the default writer for custom batch-level summaries.""" 2271 if self.update_freq == 'epoch': 2272 return 2273 2274 should_record = lambda: math_ops.equal(step % self.update_freq, 0) 2275 # TODO(b/151339474): Fix deadlock when not using .value() here. 2276 summary_context = (writer.as_default(step.value()), 2277 summary_ops_v2.record_if(should_record)) 2278 self._prev_summary_state.append(summary_context) 2279 summary_context[0].__enter__() 2280 summary_context[1].__enter__() 2281 2282 def _pop_writer(self): 2283 """Pops the current writer.""" 2284 if self.update_freq == 'epoch': 2285 return 2286 2287 # See _push_writer for the content of the previous_context, which is pair 2288 # of context. 2289 previous_context = self._prev_summary_state.pop() 2290 previous_context[1].__exit__(*sys.exc_info()) 2291 previous_context[0].__exit__(*sys.exc_info()) 2292 2293 def _close_writers(self): 2294 for writer in self._writers.values(): 2295 writer.close() 2296 2297 def _init_profile_batch(self, profile_batch): 2298 """Validate profile_batch value and set the range of batches to profile. 2299 2300 Args: 2301 profile_batch: The range of batches to profile. Should be a non-negative 2302 integer or a comma separated string of pair of positive integers. A pair 2303 of positive integers signify a range of batches to profile. 2304 2305 Returns: 2306 A pair of non-negative integers specifying the start and stop batch to 2307 profile. 2308 2309 Raises: 2310 ValueError: If profile_batch is not an integer or a comma seperated pair 2311 of positive integers. 2312 2313 """ 2314 profile_batch_error_message = ( 2315 'profile_batch must be a non-negative integer or 2-tuple of positive ' 2316 'integers. A pair of positive integers signifies a range of batches ' 2317 'to profile. Found: {}'.format(profile_batch)) 2318 2319 # Support legacy way of specifying "start,stop" or "start" as str. 2320 if isinstance(profile_batch, six.string_types): 2321 profile_batch = str(profile_batch).split(',') 2322 profile_batch = nest.map_structure(int, profile_batch) 2323 2324 if isinstance(profile_batch, int): 2325 self._start_batch = profile_batch 2326 self._stop_batch = profile_batch 2327 elif isinstance(profile_batch, (tuple, list)) and len(profile_batch) == 2: 2328 self._start_batch, self._stop_batch = profile_batch 2329 else: 2330 raise ValueError(profile_batch_error_message) 2331 2332 if self._start_batch < 0 or self._stop_batch < self._start_batch: 2333 raise ValueError(profile_batch_error_message) 2334 2335 if self._start_batch > 0: 2336 # Warm up and improve the profiling accuracy. 
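      # Starting and immediately stopping the profiler here initializes the
      # profiling libraries up front, so the first profiled batch is not
      # skewed by one-time setup cost.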
2337 profiler.start('') 2338 profiler.stop(save=False) 2339 # True when a trace is running. 2340 self._is_tracing = False 2341 2342 # Setting `profile_batch=0` disables profiling. 2343 self._should_trace = not (self._start_batch == 0 and self._stop_batch == 0) 2344 2345 def on_train_begin(self, logs=None): 2346 self._global_train_batch = 0 2347 self._previous_epoch_iterations = 0 2348 self._train_accumulated_time = 0 2349 self._push_writer(self._train_writer, self._train_step) 2350 2351 def on_train_end(self, logs=None): 2352 self._pop_writer() 2353 2354 if self._is_tracing: 2355 self._stop_trace() 2356 2357 self._close_writers() 2358 self._delete_tmp_write_dir() 2359 2360 def on_test_begin(self, logs=None): 2361 self._push_writer(self._val_writer, self._val_step) 2362 2363 def on_test_end(self, logs=None): 2364 self._pop_writer() 2365 2366 def _implements_train_batch_hooks(self): 2367 # Only call batch hooks when tracing or write_steps_per_second are enabled 2368 return self._should_trace or self.write_steps_per_second 2369 2370 def on_train_batch_begin(self, batch, logs=None): 2371 self._global_train_batch += 1 2372 if self.write_steps_per_second: 2373 self._batch_start_time = time.time() 2374 if not self._should_trace: 2375 return 2376 2377 if self._global_train_batch == self._start_batch: 2378 self._start_trace() 2379 2380 def on_train_batch_end(self, batch, logs=None): 2381 if self._should_write_train_graph: 2382 self._write_keras_model_train_graph() 2383 self._should_write_train_graph = False 2384 if self.write_steps_per_second: 2385 batch_run_time = time.time() - self._batch_start_time 2386 self._train_accumulated_time += batch_run_time 2387 summary_ops_v2.scalar('batch_steps_per_second', 1. / batch_run_time) 2388 if not self._should_trace: 2389 return 2390 2391 if self._is_tracing and self._global_train_batch >= self._stop_batch: 2392 self._stop_trace() 2393 2394 def on_epoch_begin(self, epoch, logs=None): 2395 # Keeps track of epoch for profiling. 2396 self._epoch = epoch 2397 if self.write_steps_per_second: 2398 self._previous_epoch_iterations = self.model.optimizer.iterations.numpy() 2399 self._train_accumulated_time = 0 2400 2401 def on_epoch_end(self, epoch, logs=None): 2402 """Runs metrics and histogram summaries at epoch end.""" 2403 self._log_epoch_metrics(epoch, logs) 2404 2405 if self.histogram_freq and epoch % self.histogram_freq == 0: 2406 self._log_weights(epoch) 2407 2408 if self.embeddings_freq and epoch % self.embeddings_freq == 0: 2409 self._log_embeddings(epoch) 2410 2411 def _start_trace(self): 2412 summary_ops_v2.trace_on(graph=True, profiler=False) 2413 profiler.start(logdir=self._train_dir) 2414 self._is_tracing = True 2415 2416 def _stop_trace(self, batch=None): 2417 """Logs the trace graph to TensorBoard.""" 2418 if batch is None: 2419 batch = self._stop_batch 2420 with self._train_writer.as_default(): 2421 with summary_ops_v2.record_if(True): 2422 # TODO(b/126388999): Remove step info in the summary name. 
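        # Export the graph trace collected since `trace_on()` to the train
        # writer, then stop the profiler that was started in `_start_trace`.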
2423 summary_ops_v2.trace_export(name='batch_%d' % batch, step=batch) 2424 profiler.stop() 2425 self._is_tracing = False 2426 2427 def _collect_learning_rate(self, logs): 2428 lr_schedule = getattr(self.model.optimizer, 'lr', None) 2429 if isinstance(lr_schedule, learning_rate_schedule.LearningRateSchedule): 2430 logs['learning_rate'] = lr_schedule(self.model.optimizer.iterations) 2431 return logs 2432 2433 def _compute_steps_per_second(self): 2434 current_iteration = self.model.optimizer.iterations.numpy() 2435 steps_per_second = ((current_iteration - self._previous_epoch_iterations) / 2436 (self._train_accumulated_time)) 2437 return steps_per_second 2438 2439 def _log_epoch_metrics(self, epoch, logs): 2440 """Writes epoch metrics out as scalar summaries. 2441 2442 Args: 2443 epoch: Int. The global step to use for TensorBoard. 2444 logs: Dict. Keys are scalar summary names, values are scalars. 2445 """ 2446 if not logs: 2447 return 2448 2449 train_logs = {k: v for k, v in logs.items() if not k.startswith('val_')} 2450 val_logs = {k: v for k, v in logs.items() if k.startswith('val_')} 2451 train_logs = self._collect_learning_rate(train_logs) 2452 if self.write_steps_per_second: 2453 train_logs['steps_per_second'] = self._compute_steps_per_second() 2454 2455 with summary_ops_v2.record_if(True): 2456 if train_logs: 2457 with self._train_writer.as_default(): 2458 for name, value in train_logs.items(): 2459 summary_ops_v2.scalar('epoch_' + name, value, step=epoch) 2460 if val_logs: 2461 with self._val_writer.as_default(): 2462 for name, value in val_logs.items(): 2463 name = name[4:] # Remove 'val_' prefix. 2464 summary_ops_v2.scalar('epoch_' + name, value, step=epoch) 2465 2466 def _log_weights(self, epoch): 2467 """Logs the weights of the Model to TensorBoard.""" 2468 with self._train_writer.as_default(): 2469 with summary_ops_v2.record_if(True): 2470 for layer in self.model.layers: 2471 for weight in layer.weights: 2472 weight_name = weight.name.replace(':', '_') 2473 summary_ops_v2.histogram(weight_name, weight, step=epoch) 2474 if self.write_images: 2475 self._log_weight_as_image(weight, weight_name, epoch) 2476 self._train_writer.flush() 2477 2478 def _log_weight_as_image(self, weight, weight_name, epoch): 2479 """Logs a weight as a TensorBoard image.""" 2480 w_img = array_ops.squeeze(weight) 2481 shape = K.int_shape(w_img) 2482 if len(shape) == 1: # Bias case 2483 w_img = array_ops.reshape(w_img, [1, shape[0], 1, 1]) 2484 elif len(shape) == 2: # Dense layer kernel case 2485 if shape[0] > shape[1]: 2486 w_img = array_ops.transpose(w_img) 2487 shape = K.int_shape(w_img) 2488 w_img = array_ops.reshape(w_img, [1, shape[0], shape[1], 1]) 2489 elif len(shape) == 3: # ConvNet case 2490 if K.image_data_format() == 'channels_last': 2491 # Switch to channels_first to display every kernel as a separate 2492 # image. 2493 w_img = array_ops.transpose(w_img, perm=[2, 0, 1]) 2494 shape = K.int_shape(w_img) 2495 w_img = array_ops.reshape(w_img, [shape[0], shape[1], shape[2], 1]) 2496 2497 shape = K.int_shape(w_img) 2498 # Not possible to handle 3D convnets etc. 
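    # Only write an image summary when the reshaped weight is 4-D with a
    # channel count TensorBoard can render (1: grayscale, 3: RGB, 4: RGBA).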
2499 if len(shape) == 4 and shape[-1] in [1, 3, 4]: 2500 summary_ops_v2.image(weight_name, w_img, step=epoch) 2501 2502 def _log_embeddings(self, epoch): 2503 embeddings_ckpt = os.path.join(self._log_write_dir, 'train', 2504 'keras_embedding.ckpt-{}'.format(epoch)) 2505 self.model.save_weights(embeddings_ckpt) 2506 2507 2508@keras_export('keras.callbacks.ReduceLROnPlateau') 2509class ReduceLROnPlateau(Callback): 2510 """Reduce learning rate when a metric has stopped improving. 2511 2512 Models often benefit from reducing the learning rate by a factor 2513 of 2-10 once learning stagnates. This callback monitors a 2514 quantity and if no improvement is seen for a 'patience' number 2515 of epochs, the learning rate is reduced. 2516 2517 Example: 2518 2519 ```python 2520 reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, 2521 patience=5, min_lr=0.001) 2522 model.fit(X_train, Y_train, callbacks=[reduce_lr]) 2523 ``` 2524 2525 Args: 2526 monitor: quantity to be monitored. 2527 factor: factor by which the learning rate will be reduced. 2528 `new_lr = lr * factor`. 2529 patience: number of epochs with no improvement after which learning rate 2530 will be reduced. 2531 verbose: int. 0: quiet, 1: update messages. 2532 mode: one of `{'auto', 'min', 'max'}`. In `'min'` mode, 2533 the learning rate will be reduced when the 2534 quantity monitored has stopped decreasing; in `'max'` mode it will be 2535 reduced when the quantity monitored has stopped increasing; in `'auto'` 2536 mode, the direction is automatically inferred from the name of the 2537 monitored quantity. 2538 min_delta: threshold for measuring the new optimum, to only focus on 2539 significant changes. 2540 cooldown: number of epochs to wait before resuming normal operation after 2541 lr has been reduced. 2542 min_lr: lower bound on the learning rate. 2543 """ 2544 2545 def __init__(self, 2546 monitor='val_loss', 2547 factor=0.1, 2548 patience=10, 2549 verbose=0, 2550 mode='auto', 2551 min_delta=1e-4, 2552 cooldown=0, 2553 min_lr=0, 2554 **kwargs): 2555 super(ReduceLROnPlateau, self).__init__() 2556 2557 self.monitor = monitor 2558 if factor >= 1.0: 2559 raise ValueError('ReduceLROnPlateau ' 'does not support a factor >= 1.0.') 2560 if 'epsilon' in kwargs: 2561 min_delta = kwargs.pop('epsilon') 2562 logging.warning('`epsilon` argument is deprecated and ' 2563 'will be removed, use `min_delta` instead.') 2564 self.factor = factor 2565 self.min_lr = min_lr 2566 self.min_delta = min_delta 2567 self.patience = patience 2568 self.verbose = verbose 2569 self.cooldown = cooldown 2570 self.cooldown_counter = 0 # Cooldown counter. 2571 self.wait = 0 2572 self.best = 0 2573 self.mode = mode 2574 self.monitor_op = None 2575 self._reset() 2576 2577 def _reset(self): 2578 """Resets wait counter and cooldown counter. 
2579 """ 2580 if self.mode not in ['auto', 'min', 'max']: 2581 logging.warning('Learning rate reduction mode %s is unknown, ' 2582 'fallback to auto mode.', self.mode) 2583 self.mode = 'auto' 2584 if (self.mode == 'min' or 2585 (self.mode == 'auto' and 'acc' not in self.monitor)): 2586 self.monitor_op = lambda a, b: np.less(a, b - self.min_delta) 2587 self.best = np.Inf 2588 else: 2589 self.monitor_op = lambda a, b: np.greater(a, b + self.min_delta) 2590 self.best = -np.Inf 2591 self.cooldown_counter = 0 2592 self.wait = 0 2593 2594 def on_train_begin(self, logs=None): 2595 self._reset() 2596 2597 def on_epoch_end(self, epoch, logs=None): 2598 logs = logs or {} 2599 logs['lr'] = K.get_value(self.model.optimizer.lr) 2600 current = logs.get(self.monitor) 2601 if current is None: 2602 logging.warning('Learning rate reduction is conditioned on metric `%s` ' 2603 'which is not available. Available metrics are: %s', 2604 self.monitor, ','.join(list(logs.keys()))) 2605 2606 else: 2607 if self.in_cooldown(): 2608 self.cooldown_counter -= 1 2609 self.wait = 0 2610 2611 if self.monitor_op(current, self.best): 2612 self.best = current 2613 self.wait = 0 2614 elif not self.in_cooldown(): 2615 self.wait += 1 2616 if self.wait >= self.patience: 2617 old_lr = K.get_value(self.model.optimizer.lr) 2618 if old_lr > np.float32(self.min_lr): 2619 new_lr = old_lr * self.factor 2620 new_lr = max(new_lr, self.min_lr) 2621 K.set_value(self.model.optimizer.lr, new_lr) 2622 if self.verbose > 0: 2623 print('\nEpoch %05d: ReduceLROnPlateau reducing learning ' 2624 'rate to %s.' % (epoch + 1, new_lr)) 2625 self.cooldown_counter = self.cooldown 2626 self.wait = 0 2627 2628 def in_cooldown(self): 2629 return self.cooldown_counter > 0 2630 2631 2632@keras_export('keras.callbacks.CSVLogger') 2633class CSVLogger(Callback): 2634 """Callback that streams epoch results to a CSV file. 2635 2636 Supports all values that can be represented as a string, 2637 including 1D iterables such as `np.ndarray`. 2638 2639 Example: 2640 2641 ```python 2642 csv_logger = CSVLogger('training.log') 2643 model.fit(X_train, Y_train, callbacks=[csv_logger]) 2644 ``` 2645 2646 Args: 2647 filename: Filename of the CSV file, e.g. `'run/log.csv'`. 2648 separator: String used to separate elements in the CSV file. 2649 append: Boolean. True: append if file exists (useful for continuing 2650 training). False: overwrite existing file. 
2651 """ 2652 2653 def __init__(self, filename, separator=',', append=False): 2654 self.sep = separator 2655 self.filename = path_to_string(filename) 2656 self.append = append 2657 self.writer = None 2658 self.keys = None 2659 self.append_header = True 2660 if six.PY2: 2661 self.file_flags = 'b' 2662 self._open_args = {} 2663 else: 2664 self.file_flags = '' 2665 self._open_args = {'newline': '\n'} 2666 super(CSVLogger, self).__init__() 2667 2668 def on_train_begin(self, logs=None): 2669 if self.append: 2670 if file_io.file_exists_v2(self.filename): 2671 with open(self.filename, 'r' + self.file_flags) as f: 2672 self.append_header = not bool(len(f.readline())) 2673 mode = 'a' 2674 else: 2675 mode = 'w' 2676 self.csv_file = io.open(self.filename, 2677 mode + self.file_flags, 2678 **self._open_args) 2679 2680 def on_epoch_end(self, epoch, logs=None): 2681 logs = logs or {} 2682 2683 def handle_value(k): 2684 is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0 2685 if isinstance(k, six.string_types): 2686 return k 2687 elif isinstance(k, collections.abc.Iterable) and not is_zero_dim_ndarray: 2688 return '"[%s]"' % (', '.join(map(str, k))) 2689 else: 2690 return k 2691 2692 if self.keys is None: 2693 self.keys = sorted(logs.keys()) 2694 2695 if self.model.stop_training: 2696 # We set NA so that csv parsers do not fail for this last epoch. 2697 logs = dict((k, logs[k]) if k in logs else (k, 'NA') for k in self.keys) 2698 2699 if not self.writer: 2700 2701 class CustomDialect(csv.excel): 2702 delimiter = self.sep 2703 2704 fieldnames = ['epoch'] + self.keys 2705 2706 self.writer = csv.DictWriter( 2707 self.csv_file, 2708 fieldnames=fieldnames, 2709 dialect=CustomDialect) 2710 if self.append_header: 2711 self.writer.writeheader() 2712 2713 row_dict = collections.OrderedDict({'epoch': epoch}) 2714 row_dict.update((key, handle_value(logs[key])) for key in self.keys) 2715 self.writer.writerow(row_dict) 2716 self.csv_file.flush() 2717 2718 def on_train_end(self, logs=None): 2719 self.csv_file.close() 2720 self.writer = None 2721 2722 2723@keras_export('keras.callbacks.LambdaCallback') 2724class LambdaCallback(Callback): 2725 r"""Callback for creating simple, custom callbacks on-the-fly. 2726 2727 This callback is constructed with anonymous functions that will be called 2728 at the appropriate time (during `Model.{fit | evaluate | predict}`). 2729 Note that the callbacks expects positional arguments, as: 2730 2731 - `on_epoch_begin` and `on_epoch_end` expect two positional arguments: 2732 `epoch`, `logs` 2733 - `on_batch_begin` and `on_batch_end` expect two positional arguments: 2734 `batch`, `logs` 2735 - `on_train_begin` and `on_train_end` expect one positional argument: 2736 `logs` 2737 2738 Args: 2739 on_epoch_begin: called at the beginning of every epoch. 2740 on_epoch_end: called at the end of every epoch. 2741 on_batch_begin: called at the beginning of every batch. 2742 on_batch_end: called at the end of every batch. 2743 on_train_begin: called at the beginning of model training. 2744 on_train_end: called at the end of model training. 2745 2746 Example: 2747 2748 ```python 2749 # Print the batch number at the beginning of every batch. 2750 batch_print_callback = LambdaCallback( 2751 on_batch_begin=lambda batch,logs: print(batch)) 2752 2753 # Stream the epoch loss to a file in JSON format. The file content 2754 # is not well-formed JSON but rather has a JSON object per line. 
2755 import json 2756 json_log = open('loss_log.json', mode='wt', buffering=1) 2757 json_logging_callback = LambdaCallback( 2758 on_epoch_end=lambda epoch, logs: json_log.write( 2759 json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'), 2760 on_train_end=lambda logs: json_log.close() 2761 ) 2762 2763 # Terminate some processes after having finished model training. 2764 processes = ... 2765 cleanup_callback = LambdaCallback( 2766 on_train_end=lambda logs: [ 2767 p.terminate() for p in processes if p.is_alive()]) 2768 2769 model.fit(..., 2770 callbacks=[batch_print_callback, 2771 json_logging_callback, 2772 cleanup_callback]) 2773 ``` 2774 """ 2775 2776 def __init__(self, 2777 on_epoch_begin=None, 2778 on_epoch_end=None, 2779 on_batch_begin=None, 2780 on_batch_end=None, 2781 on_train_begin=None, 2782 on_train_end=None, 2783 **kwargs): 2784 super(LambdaCallback, self).__init__() 2785 self.__dict__.update(kwargs) 2786 if on_epoch_begin is not None: 2787 self.on_epoch_begin = on_epoch_begin 2788 else: 2789 self.on_epoch_begin = lambda epoch, logs: None 2790 if on_epoch_end is not None: 2791 self.on_epoch_end = on_epoch_end 2792 else: 2793 self.on_epoch_end = lambda epoch, logs: None 2794 if on_batch_begin is not None: 2795 self.on_batch_begin = on_batch_begin 2796 else: 2797 self.on_batch_begin = lambda batch, logs: None 2798 if on_batch_end is not None: 2799 self.on_batch_end = on_batch_end 2800 else: 2801 self.on_batch_end = lambda batch, logs: None 2802 if on_train_begin is not None: 2803 self.on_train_begin = on_train_begin 2804 else: 2805 self.on_train_begin = lambda logs: None 2806 if on_train_end is not None: 2807 self.on_train_end = on_train_end 2808 else: 2809 self.on_train_end = lambda logs: None 2810