1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=g-import-not-at-top 16# pylint: disable=g-classes-have-attributes 17"""Callbacks: utilities called at certain points during model training.""" 18 19import collections 20import copy 21import csv 22import json 23import os 24import re 25import sys 26import time 27 28import numpy as np 29 30from tensorflow.core.framework import summary_pb2 31from tensorflow.python.data.ops import iterator_ops 32from tensorflow.python.distribute import collective_all_reduce_strategy 33from tensorflow.python.distribute import distribution_strategy_context as ds_context 34from tensorflow.python.distribute import mirrored_strategy 35from tensorflow.python.distribute import parameter_server_strategy_v2 36from tensorflow.python.distribute import tpu_strategy 37from tensorflow.python.eager import context 38from tensorflow.python.framework import constant_op 39from tensorflow.python.framework import dtypes 40from tensorflow.python.framework import errors 41from tensorflow.python.framework import ops 42from tensorflow.python.keras import backend 43from tensorflow.python.keras.distribute import distributed_file_utils 44from tensorflow.python.keras.distribute import worker_training_state 45from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule 46from 
tensorflow.python.keras.utils import generic_utils 47from tensorflow.python.keras.utils import tf_utils 48from tensorflow.python.keras.utils import version_utils 49from tensorflow.python.keras.utils.data_utils import Sequence 50from tensorflow.python.keras.utils.generic_utils import Progbar 51from tensorflow.python.keras.utils.io_utils import path_to_string 52from tensorflow.python.keras.utils.mode_keys import ModeKeys 53from tensorflow.python.lib.io import file_io 54from tensorflow.python.ops import array_ops 55from tensorflow.python.ops import math_ops 56from tensorflow.python.ops import summary_ops_v2 57from tensorflow.python.platform import gfile 58from tensorflow.python.platform import tf_logging as logging 59from tensorflow.python.profiler import profiler_v2 as profiler 60from tensorflow.python.saved_model import save_options as save_options_lib 61from tensorflow.python.training import checkpoint_management 62from tensorflow.python.training.saving import checkpoint_options as checkpoint_options_lib 63from tensorflow.python.util import nest 64from tensorflow.python.util.tf_export import keras_export 65from tensorflow.tools.docs import doc_controls 66 67try: 68 import requests 69except ImportError: 70 requests = None 71 72 73# Note: `configure_callbacks` is only used in TF1. 74def configure_callbacks(callbacks, 75 model, 76 do_validation=False, 77 batch_size=None, 78 epochs=None, 79 steps_per_epoch=None, 80 samples=None, 81 verbose=1, 82 count_mode='steps', 83 mode=ModeKeys.TRAIN): 84 """Configures callbacks for use in various training loops. 85 86 Args: 87 callbacks: List of Callbacks. 88 model: Model being trained. 89 do_validation: Whether or not validation loop will be run. 90 batch_size: Number of samples per batch. 91 epochs: Number of epoch to train. 92 steps_per_epoch: Number of batches to run per training epoch. 93 samples: Number of training samples. 94 verbose: int, 0 or 1. Keras logging verbosity to pass to ProgbarLogger. 
95 count_mode: One of 'steps' or 'samples'. Per-batch or per-sample count. 96 mode: String. One of ModeKeys.TRAIN, ModeKeys.TEST, or ModeKeys.PREDICT. 97 Which loop mode to configure callbacks for. 98 99 Returns: 100 Instance of CallbackList used to control all Callbacks. 101 """ 102 # Check if callbacks have already been configured. 103 if isinstance(callbacks, CallbackList): 104 return callbacks 105 106 if not callbacks: 107 callbacks = [] 108 109 # Add additional callbacks during training. 110 if mode == ModeKeys.TRAIN: 111 model.history = History() 112 callbacks = [BaseLogger()] + (callbacks or []) + [model.history] 113 if verbose: 114 callbacks.append(ProgbarLogger(count_mode)) 115 callback_list = CallbackList(callbacks) 116 117 # Set callback model 118 callback_model = model._get_callback_model() # pylint: disable=protected-access 119 callback_list.set_model(callback_model) 120 121 set_callback_parameters( 122 callback_list, 123 model, 124 do_validation=do_validation, 125 batch_size=batch_size, 126 epochs=epochs, 127 steps_per_epoch=steps_per_epoch, 128 samples=samples, 129 verbose=verbose, 130 mode=mode) 131 132 callback_list.model.stop_training = False 133 return callback_list 134 135 136def set_callback_parameters(callback_list, 137 model, 138 do_validation=False, 139 batch_size=None, 140 epochs=None, 141 steps_per_epoch=None, 142 samples=None, 143 verbose=1, 144 mode=ModeKeys.TRAIN): 145 """Sets callback parameters. 146 147 Args: 148 callback_list: CallbackList instance. 149 model: Model being trained. 150 do_validation: Whether or not validation loop will be run. 151 batch_size: Number of samples per batch. 152 epochs: Number of epoch to train. 153 steps_per_epoch: Number of batches to run per training epoch. 154 samples: Number of training samples. 155 verbose: int, 0 or 1. Keras logging verbosity to pass to ProgbarLogger. 156 mode: String. One of ModeKeys.TRAIN, ModeKeys.TEST, or ModeKeys.PREDICT. 157 Which loop mode to configure callbacks for. 
158 """ 159 metric_names = model.metrics_names 160 for cbk in callback_list: 161 if isinstance(cbk, (BaseLogger, ProgbarLogger)): 162 cbk.stateful_metrics = metric_names[1:] # Exclude `loss` 163 164 # Set callback parameters 165 callback_metrics = [] 166 # When we have deferred build scenario with iterator input, we will compile 167 # when we standardize first batch of data. 168 if mode != ModeKeys.PREDICT: 169 callback_metrics = copy.copy(metric_names) 170 if do_validation: 171 callback_metrics += ['val_' + n for n in metric_names] 172 callback_params = { 173 'batch_size': batch_size, 174 'epochs': epochs, 175 'steps': steps_per_epoch, 176 'samples': samples, 177 'verbose': verbose, 178 'do_validation': do_validation, 179 'metrics': callback_metrics, 180 } 181 callback_list.set_params(callback_params) 182 183 184def _is_generator_like(data): 185 """Checks if data is a generator, Sequence, or Iterator.""" 186 return (hasattr(data, '__next__') or hasattr(data, 'next') or isinstance( 187 data, (Sequence, iterator_ops.Iterator, iterator_ops.IteratorBase))) 188 189 190def make_logs(model, logs, outputs, mode, prefix=''): 191 """Computes logs for sending to `on_batch_end` methods.""" 192 metric_names = model.metrics_names 193 if mode in {ModeKeys.TRAIN, ModeKeys.TEST} and metric_names: 194 for label, output in zip(metric_names, outputs): 195 logs[prefix + label] = output 196 else: 197 logs['outputs'] = outputs 198 return logs 199 200 201@keras_export('keras.callbacks.CallbackList') 202class CallbackList: 203 """Container abstracting a list of callbacks.""" 204 205 def __init__(self, 206 callbacks=None, 207 add_history=False, 208 add_progbar=False, 209 model=None, 210 **params): 211 """Container for `Callback` instances. 212 213 This object wraps a list of `Callback` instances, making it possible 214 to call them all at once via a single endpoint 215 (e.g. `callback_list.on_epoch_end(...)`). 216 217 Args: 218 callbacks: List of `Callback` instances. 
219 add_history: Whether a `History` callback should be added, if one does not 220 already exist in the `callbacks` list. 221 add_progbar: Whether a `ProgbarLogger` callback should be added, if one 222 does not already exist in the `callbacks` list. 223 model: The `Model` these callbacks are used with. 224 **params: If provided, parameters will be passed to each `Callback` via 225 `Callback.set_params`. 226 """ 227 self.callbacks = nest.flatten(callbacks) if callbacks else [] 228 self._add_default_callbacks(add_history, add_progbar) 229 230 if model: 231 self.set_model(model) 232 if params: 233 self.set_params(params) 234 235 # Performance optimization: determines if batch hooks need to be called. 236 # pylint: disable=protected-access 237 self._supports_tf_logs = all( 238 getattr(cb, '_supports_tf_logs', False) for cb in self.callbacks) 239 self._batch_hooks_support_tf_logs = all( 240 getattr(cb, '_supports_tf_logs', False) 241 for cb in self.callbacks 242 if cb._implements_train_batch_hooks() or cb 243 ._implements_test_batch_hooks() or cb._implements_predict_batch_hooks()) 244 245 self._should_call_train_batch_hooks = any( 246 cb._implements_train_batch_hooks() for cb in self.callbacks) 247 self._should_call_test_batch_hooks = any( 248 cb._implements_test_batch_hooks() for cb in self.callbacks) 249 self._should_call_predict_batch_hooks = any( 250 cb._implements_predict_batch_hooks() for cb in self.callbacks) 251 # pylint: enable=protected-access 252 253 self._disallow_batch_hooks_in_ps_strategy() 254 255 # Performance check: Check batch hooks for slowness compared to batch time. 256 # Only run check for custom callbacks (i.e. not present in this file). 
257 self._check_timing = any( 258 cbk.__class__.__name__ not in globals() for cbk in self.callbacks) 259 self._num_batches_for_timing_check = 5 260 self._hook_times = {} 261 self._batch_start_time = None 262 self._batch_times = [] 263 264 def _add_default_callbacks(self, add_history, add_progbar): 265 """Adds `Callback`s that are always present.""" 266 self._progbar = None 267 self._history = None 268 269 for cb in self.callbacks: 270 if isinstance(cb, ProgbarLogger): 271 self._progbar = cb 272 elif isinstance(cb, History): 273 self._history = cb 274 275 if self._progbar is None and add_progbar: 276 self._progbar = ProgbarLogger(count_mode='steps') 277 self.callbacks.insert(0, self._progbar) 278 279 if self._history is None and add_history: 280 self._history = History() 281 self.callbacks.append(self._history) 282 283 def _process_logs(self, logs, is_batch_hook=False): 284 """Turns tensors into numpy arrays or Python scalars if necessary.""" 285 if logs is None: 286 return {} 287 if self._supports_tf_logs: 288 return logs 289 if is_batch_hook and self._batch_hooks_support_tf_logs: 290 return logs 291 return tf_utils.sync_to_numpy_or_python_type(logs) 292 293 def append(self, callback): 294 self.callbacks.append(callback) 295 296 def set_params(self, params): 297 self.params = params 298 for callback in self.callbacks: 299 callback.set_params(params) 300 301 def set_model(self, model): 302 self.model = model 303 if self._history: 304 model.history = self._history 305 for callback in self.callbacks: 306 callback.set_model(model) 307 308 def _call_batch_hook(self, mode, hook, batch, logs=None): 309 """Helper function for all batch_{begin | end} methods.""" 310 if not self.callbacks: 311 return 312 313 if hook == 'begin': 314 self._call_batch_begin_hook(mode, batch, logs) 315 elif hook == 'end': 316 self._call_batch_end_hook(mode, batch, logs) 317 else: 318 raise ValueError('Unrecognized hook: {}'.format(hook)) 319 320 def _call_batch_begin_hook(self, mode, batch, 
logs): 321 """Helper function for `on_*_batch_begin` methods.""" 322 hook_name = 'on_{mode}_batch_begin'.format(mode=mode) 323 self._call_batch_hook_helper(hook_name, batch, logs) 324 325 if self._check_timing: 326 self._batch_start_time = time.time() 327 328 def _call_batch_end_hook(self, mode, batch, logs): 329 """Helper function for `on_*_batch_end` methods.""" 330 hook_name = 'on_{mode}_batch_end'.format(mode=mode) 331 332 if self._check_timing and batch >= 1: 333 batch_time = time.time() - self._batch_start_time 334 self._batch_times.append(batch_time) 335 336 self._call_batch_hook_helper(hook_name, batch, logs) 337 338 if len(self._batch_times) >= self._num_batches_for_timing_check: 339 end_hook_name = hook_name 340 begin_hook_name = 'on_{mode}_batch_begin'.format(mode=mode) 341 avg_batch_time = sum(self._batch_times) / len(self._batch_times) 342 avg_end_hook_time = sum(self._hook_times[end_hook_name]) / len( 343 self._hook_times[end_hook_name]) 344 avg_begin_hook_time = sum(self._hook_times[begin_hook_name]) / len( 345 self._hook_times[begin_hook_name]) 346 347 threshold_time = 1.0 * avg_batch_time 348 warning_msg = ('Callback method `{hook}` is slow compared to ' 349 'the batch time (batch time: {batch_time:.4f}s vs ' 350 '`{hook}` time: {hook_time:.4f}s). 
Check your callbacks.') 351 if avg_begin_hook_time > threshold_time: 352 logging.warning(warning_msg.format( 353 hook=begin_hook_name, 354 batch_time=avg_batch_time, 355 hook_time=avg_begin_hook_time)) 356 if avg_end_hook_time > threshold_time: 357 logging.warning(warning_msg.format( 358 hook=end_hook_name, 359 batch_time=avg_batch_time, 360 hook_time=avg_end_hook_time)) 361 self._check_timing = False 362 self._batch_start_time = None 363 self._batch_times = [] 364 self._hook_times = {} 365 366 def _call_batch_hook_helper(self, hook_name, batch, logs): 367 """Helper function for `on_*_batch_*` methods.""" 368 if self._check_timing: 369 start_time = time.time() 370 371 logs = self._process_logs(logs, is_batch_hook=True) 372 for callback in self.callbacks: 373 hook = getattr(callback, hook_name) 374 hook(batch, logs) 375 376 if self._check_timing: 377 if hook_name not in self._hook_times: 378 self._hook_times[hook_name] = [] 379 self._hook_times[hook_name].append(time.time() - start_time) 380 381 def _call_begin_hook(self, mode): 382 """Helper function for on_{train|test|predict}_begin methods.""" 383 if mode == ModeKeys.TRAIN: 384 self.on_train_begin() 385 elif mode == ModeKeys.TEST: 386 self.on_test_begin() 387 else: 388 self.on_predict_begin() 389 390 def _call_end_hook(self, mode): 391 """Helper function for on_{train|test|predict}_end methods.""" 392 if mode == ModeKeys.TRAIN: 393 self.on_train_end() 394 elif mode == ModeKeys.TEST: 395 self.on_test_end() 396 else: 397 self.on_predict_end() 398 399 def on_batch_begin(self, batch, logs=None): 400 if self._should_call_train_batch_hooks: 401 self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs) 402 403 def on_batch_end(self, batch, logs=None): 404 if self._should_call_train_batch_hooks: 405 self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs) 406 407 def on_epoch_begin(self, epoch, logs=None): 408 """Calls the `on_epoch_begin` methods of its callbacks. 
409 410 This function should only be called during TRAIN mode. 411 412 Args: 413 epoch: Integer, index of epoch. 414 logs: Dict. Currently no data is passed to this argument for this method 415 but that may change in the future. 416 """ 417 logs = self._process_logs(logs) 418 for callback in self.callbacks: 419 callback.on_epoch_begin(epoch, logs) 420 421 def on_epoch_end(self, epoch, logs=None): 422 """Calls the `on_epoch_end` methods of its callbacks. 423 424 This function should only be called during TRAIN mode. 425 426 Args: 427 epoch: Integer, index of epoch. 428 logs: Dict, metric results for this training epoch, and for the 429 validation epoch if validation is performed. Validation result keys 430 are prefixed with `val_`. 431 """ 432 logs = self._process_logs(logs) 433 for callback in self.callbacks: 434 callback.on_epoch_end(epoch, logs) 435 436 def on_train_batch_begin(self, batch, logs=None): 437 """Calls the `on_train_batch_begin` methods of its callbacks. 438 439 Args: 440 batch: Integer, index of batch within the current epoch. 441 logs: Dict, contains the return value of `model.train_step`. Typically, 442 the values of the `Model`'s metrics are returned. Example: 443 `{'loss': 0.2, 'accuracy': 0.7}`. 444 """ 445 if self._should_call_train_batch_hooks: 446 self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs) 447 448 def on_train_batch_end(self, batch, logs=None): 449 """Calls the `on_train_batch_end` methods of its callbacks. 450 451 Args: 452 batch: Integer, index of batch within the current epoch. 453 logs: Dict. Aggregated metric results up until this batch. 454 """ 455 if self._should_call_train_batch_hooks: 456 self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs) 457 458 def on_test_batch_begin(self, batch, logs=None): 459 """Calls the `on_test_batch_begin` methods of its callbacks. 460 461 Args: 462 batch: Integer, index of batch within the current epoch. 463 logs: Dict, contains the return value of `model.test_step`. 
Typically, 464 the values of the `Model`'s metrics are returned. Example: 465 `{'loss': 0.2, 'accuracy': 0.7}`. 466 """ 467 if self._should_call_test_batch_hooks: 468 self._call_batch_hook(ModeKeys.TEST, 'begin', batch, logs=logs) 469 470 def on_test_batch_end(self, batch, logs=None): 471 """Calls the `on_test_batch_end` methods of its callbacks. 472 473 Args: 474 batch: Integer, index of batch within the current epoch. 475 logs: Dict. Aggregated metric results up until this batch. 476 """ 477 if self._should_call_test_batch_hooks: 478 self._call_batch_hook(ModeKeys.TEST, 'end', batch, logs=logs) 479 480 def on_predict_batch_begin(self, batch, logs=None): 481 """Calls the `on_predict_batch_begin` methods of its callbacks. 482 483 Args: 484 batch: Integer, index of batch within the current epoch. 485 logs: Dict, contains the return value of `model.predict_step`, 486 it typically returns a dict with a key 'outputs' containing 487 the model's outputs. 488 """ 489 if self._should_call_predict_batch_hooks: 490 self._call_batch_hook(ModeKeys.PREDICT, 'begin', batch, logs=logs) 491 492 def on_predict_batch_end(self, batch, logs=None): 493 """Calls the `on_predict_batch_end` methods of its callbacks. 494 495 Args: 496 batch: Integer, index of batch within the current epoch. 497 logs: Dict. Aggregated metric results up until this batch. 498 """ 499 if self._should_call_predict_batch_hooks: 500 self._call_batch_hook(ModeKeys.PREDICT, 'end', batch, logs=logs) 501 502 def on_train_begin(self, logs=None): 503 """Calls the `on_train_begin` methods of its callbacks. 504 505 Args: 506 logs: Dict. Currently no data is passed to this argument for this method 507 but that may change in the future. 508 """ 509 logs = self._process_logs(logs) 510 for callback in self.callbacks: 511 callback.on_train_begin(logs) 512 513 def on_train_end(self, logs=None): 514 """Calls the `on_train_end` methods of its callbacks. 515 516 Args: 517 logs: Dict. 
Currently no data is passed to this argument for this method 518 but that may change in the future. 519 """ 520 logs = self._process_logs(logs) 521 for callback in self.callbacks: 522 callback.on_train_end(logs) 523 524 def on_test_begin(self, logs=None): 525 """Calls the `on_test_begin` methods of its callbacks. 526 527 Args: 528 logs: Dict. Currently no data is passed to this argument for this method 529 but that may change in the future. 530 """ 531 logs = self._process_logs(logs) 532 for callback in self.callbacks: 533 callback.on_test_begin(logs) 534 535 def on_test_end(self, logs=None): 536 """Calls the `on_test_end` methods of its callbacks. 537 538 Args: 539 logs: Dict. Currently no data is passed to this argument for this method 540 but that may change in the future. 541 """ 542 logs = self._process_logs(logs) 543 for callback in self.callbacks: 544 callback.on_test_end(logs) 545 546 def on_predict_begin(self, logs=None): 547 """Calls the 'on_predict_begin` methods of its callbacks. 548 549 Args: 550 logs: Dict. Currently no data is passed to this argument for this method 551 but that may change in the future. 552 """ 553 logs = self._process_logs(logs) 554 for callback in self.callbacks: 555 callback.on_predict_begin(logs) 556 557 def on_predict_end(self, logs=None): 558 """Calls the `on_predict_end` methods of its callbacks. 559 560 Args: 561 logs: Dict. Currently no data is passed to this argument for this method 562 but that may change in the future. 
563 """ 564 logs = self._process_logs(logs) 565 for callback in self.callbacks: 566 callback.on_predict_end(logs) 567 568 def __iter__(self): 569 return iter(self.callbacks) 570 571 def _disallow_batch_hooks_in_ps_strategy(self): 572 """Error out if batch-level callbacks are passed with PSStrategy.""" 573 # pylint: disable=protected-access 574 strategy = ds_context.get_strategy() 575 if strategy._should_use_with_coordinator: 576 unsupported_callbacks = [] 577 for cb in self.callbacks: 578 # These Callbacks can accept RemoteValues directly. 579 if getattr(cb, '_supports_tf_logs', False): 580 continue 581 if (cb._implements_train_batch_hooks() or 582 cb._implements_test_batch_hooks() or 583 cb._implements_predict_batch_hooks()): 584 unsupported_callbacks.append(cb) 585 if unsupported_callbacks: 586 raise ValueError('Batch-level `Callback`s are not supported with ' 587 '`ParameterServerStrategy`. Found unsupported ' 588 'callbacks: {}'.format(unsupported_callbacks)) 589 # pylint: enable=protected-access 590 591 592@keras_export('keras.callbacks.Callback') 593class Callback: 594 """Abstract base class used to build new callbacks. 595 596 Callbacks can be passed to keras methods such as `fit`, `evaluate`, and 597 `predict` in order to hook into the various stages of the model training and 598 inference lifecycle. 599 600 To create a custom callback, subclass `keras.callbacks.Callback` and override 601 the method associated with the stage of interest. See 602 https://www.tensorflow.org/guide/keras/custom_callback for more information. 603 604 Example: 605 606 >>> training_finished = False 607 >>> class MyCallback(tf.keras.callbacks.Callback): 608 ... def on_train_end(self, logs=None): 609 ... global training_finished 610 ... training_finished = True 611 >>> model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))]) 612 >>> model.compile(loss='mean_squared_error') 613 >>> model.fit(tf.constant([[1.0]]), tf.constant([[1.0]]), 614 ... 
callbacks=[MyCallback()]) 615 >>> assert training_finished == True 616 617 If you want to use `Callback` objects in a custom training loop: 618 619 1. You should pack all your callbacks into a single `callbacks.CallbackList` 620 so they can all be called together. 621 2. You will need to manually call all the `on_*` methods at the apropriate 622 locations in your loop. Like this: 623 624 ``` 625 callbacks = tf.keras.callbacks.CallbackList([...]) 626 callbacks.append(...) 627 628 callbacks.on_train_begin(...) 629 for epoch in range(EPOCHS): 630 callbacks.on_epoch_begin(epoch) 631 for i, data in dataset.enumerate(): 632 callbacks.on_train_batch_begin(i) 633 batch_logs = model.train_step(data) 634 callbacks.on_train_batch_end(i, batch_logs) 635 epoch_logs = ... 636 callbacks.on_epoch_end(epoch, epoch_logs) 637 final_logs=... 638 callbacks.on_train_end(final_logs) 639 ``` 640 641 Attributes: 642 params: Dict. Training parameters 643 (eg. verbosity, batch size, number of epochs...). 644 model: Instance of `keras.models.Model`. 645 Reference of the model being trained. 646 647 The `logs` dictionary that callback methods 648 take as argument will contain keys for quantities relevant to 649 the current batch or epoch (see method-specific docstrings). 650 """ 651 652 def __init__(self): 653 self.validation_data = None # pylint: disable=g-missing-from-attributes 654 self.model = None 655 # Whether this Callback should only run on the chief worker in a 656 # Multi-Worker setting. 657 # TODO(omalleyt): Make this attr public once solution is stable. 
658 self._chief_worker_only = None 659 self._supports_tf_logs = False 660 661 def set_params(self, params): 662 self.params = params 663 664 def set_model(self, model): 665 self.model = model 666 667 @doc_controls.for_subclass_implementers 668 @generic_utils.default 669 def on_batch_begin(self, batch, logs=None): 670 """A backwards compatibility alias for `on_train_batch_begin`.""" 671 672 @doc_controls.for_subclass_implementers 673 @generic_utils.default 674 def on_batch_end(self, batch, logs=None): 675 """A backwards compatibility alias for `on_train_batch_end`.""" 676 677 @doc_controls.for_subclass_implementers 678 def on_epoch_begin(self, epoch, logs=None): 679 """Called at the start of an epoch. 680 681 Subclasses should override for any actions to run. This function should only 682 be called during TRAIN mode. 683 684 Args: 685 epoch: Integer, index of epoch. 686 logs: Dict. Currently no data is passed to this argument for this method 687 but that may change in the future. 688 """ 689 690 @doc_controls.for_subclass_implementers 691 def on_epoch_end(self, epoch, logs=None): 692 """Called at the end of an epoch. 693 694 Subclasses should override for any actions to run. This function should only 695 be called during TRAIN mode. 696 697 Args: 698 epoch: Integer, index of epoch. 699 logs: Dict, metric results for this training epoch, and for the 700 validation epoch if validation is performed. Validation result keys 701 are prefixed with `val_`. For training epoch, the values of the 702 `Model`'s metrics are returned. Example : `{'loss': 0.2, 'accuracy': 703 0.7}`. 704 """ 705 706 @doc_controls.for_subclass_implementers 707 @generic_utils.default 708 def on_train_batch_begin(self, batch, logs=None): 709 """Called at the beginning of a training batch in `fit` methods. 710 711 Subclasses should override for any actions to run. 
712 713 Note that if the `steps_per_execution` argument to `compile` in 714 `tf.keras.Model` is set to `N`, this method will only be called every `N` 715 batches. 716 717 Args: 718 batch: Integer, index of batch within the current epoch. 719 logs: Dict, contains the return value of `model.train_step`. Typically, 720 the values of the `Model`'s metrics are returned. Example: 721 `{'loss': 0.2, 'accuracy': 0.7}`. 722 """ 723 # For backwards compatibility. 724 self.on_batch_begin(batch, logs=logs) 725 726 @doc_controls.for_subclass_implementers 727 @generic_utils.default 728 def on_train_batch_end(self, batch, logs=None): 729 """Called at the end of a training batch in `fit` methods. 730 731 Subclasses should override for any actions to run. 732 733 Note that if the `steps_per_execution` argument to `compile` in 734 `tf.keras.Model` is set to `N`, this method will only be called every `N` 735 batches. 736 737 Args: 738 batch: Integer, index of batch within the current epoch. 739 logs: Dict. Aggregated metric results up until this batch. 740 """ 741 # For backwards compatibility. 742 self.on_batch_end(batch, logs=logs) 743 744 @doc_controls.for_subclass_implementers 745 @generic_utils.default 746 def on_test_batch_begin(self, batch, logs=None): 747 """Called at the beginning of a batch in `evaluate` methods. 748 749 Also called at the beginning of a validation batch in the `fit` 750 methods, if validation data is provided. 751 752 Subclasses should override for any actions to run. 753 754 Note that if the `steps_per_execution` argument to `compile` in 755 `tf.keras.Model` is set to `N`, this method will only be called every `N` 756 batches. 757 758 Args: 759 batch: Integer, index of batch within the current epoch. 760 logs: Dict, contains the return value of `model.test_step`. Typically, 761 the values of the `Model`'s metrics are returned. Example: 762 `{'loss': 0.2, 'accuracy': 0.7}`. 
763 """ 764 765 @doc_controls.for_subclass_implementers 766 @generic_utils.default 767 def on_test_batch_end(self, batch, logs=None): 768 """Called at the end of a batch in `evaluate` methods. 769 770 Also called at the end of a validation batch in the `fit` 771 methods, if validation data is provided. 772 773 Subclasses should override for any actions to run. 774 775 Note that if the `steps_per_execution` argument to `compile` in 776 `tf.keras.Model` is set to `N`, this method will only be called every `N` 777 batches. 778 779 Args: 780 batch: Integer, index of batch within the current epoch. 781 logs: Dict. Aggregated metric results up until this batch. 782 """ 783 784 @doc_controls.for_subclass_implementers 785 @generic_utils.default 786 def on_predict_batch_begin(self, batch, logs=None): 787 """Called at the beginning of a batch in `predict` methods. 788 789 Subclasses should override for any actions to run. 790 791 Note that if the `steps_per_execution` argument to `compile` in 792 `tf.keras.Model` is set to `N`, this method will only be called every `N` 793 batches. 794 795 Args: 796 batch: Integer, index of batch within the current epoch. 797 logs: Dict, contains the return value of `model.predict_step`, 798 it typically returns a dict with a key 'outputs' containing 799 the model's outputs. 800 """ 801 802 @doc_controls.for_subclass_implementers 803 @generic_utils.default 804 def on_predict_batch_end(self, batch, logs=None): 805 """Called at the end of a batch in `predict` methods. 806 807 Subclasses should override for any actions to run. 808 809 Note that if the `steps_per_execution` argument to `compile` in 810 `tf.keras.Model` is set to `N`, this method will only be called every `N` 811 batches. 812 813 Args: 814 batch: Integer, index of batch within the current epoch. 815 logs: Dict. Aggregated metric results up until this batch. 
816 """ 817 818 @doc_controls.for_subclass_implementers 819 def on_train_begin(self, logs=None): 820 """Called at the beginning of training. 821 822 Subclasses should override for any actions to run. 823 824 Args: 825 logs: Dict. Currently no data is passed to this argument for this method 826 but that may change in the future. 827 """ 828 829 @doc_controls.for_subclass_implementers 830 def on_train_end(self, logs=None): 831 """Called at the end of training. 832 833 Subclasses should override for any actions to run. 834 835 Args: 836 logs: Dict. Currently the output of the last call to `on_epoch_end()` 837 is passed to this argument for this method but that may change in 838 the future. 839 """ 840 841 @doc_controls.for_subclass_implementers 842 def on_test_begin(self, logs=None): 843 """Called at the beginning of evaluation or validation. 844 845 Subclasses should override for any actions to run. 846 847 Args: 848 logs: Dict. Currently no data is passed to this argument for this method 849 but that may change in the future. 850 """ 851 852 @doc_controls.for_subclass_implementers 853 def on_test_end(self, logs=None): 854 """Called at the end of evaluation or validation. 855 856 Subclasses should override for any actions to run. 857 858 Args: 859 logs: Dict. Currently the output of the last call to 860 `on_test_batch_end()` is passed to this argument for this method 861 but that may change in the future. 862 """ 863 864 @doc_controls.for_subclass_implementers 865 def on_predict_begin(self, logs=None): 866 """Called at the beginning of prediction. 867 868 Subclasses should override for any actions to run. 869 870 Args: 871 logs: Dict. Currently no data is passed to this argument for this method 872 but that may change in the future. 873 """ 874 875 @doc_controls.for_subclass_implementers 876 def on_predict_end(self, logs=None): 877 """Called at the end of prediction. 878 879 Subclasses should override for any actions to run. 880 881 Args: 882 logs: Dict. 
Currently no data is passed to this argument for this method 883 but that may change in the future. 884 """ 885 886 def _implements_train_batch_hooks(self): 887 """Determines if this Callback should be called for each train batch.""" 888 return (not generic_utils.is_default(self.on_batch_begin) or 889 not generic_utils.is_default(self.on_batch_end) or 890 not generic_utils.is_default(self.on_train_batch_begin) or 891 not generic_utils.is_default(self.on_train_batch_end)) 892 893 def _implements_test_batch_hooks(self): 894 """Determines if this Callback should be called for each test batch.""" 895 return (not generic_utils.is_default(self.on_test_batch_begin) or 896 not generic_utils.is_default(self.on_test_batch_end)) 897 898 def _implements_predict_batch_hooks(self): 899 """Determines if this Callback should be called for each predict batch.""" 900 return (not generic_utils.is_default(self.on_predict_batch_begin) or 901 not generic_utils.is_default(self.on_predict_batch_end)) 902 903 904@keras_export('keras.callbacks.BaseLogger') 905class BaseLogger(Callback): 906 """Callback that accumulates epoch averages of metrics. 907 908 This callback is automatically applied to every Keras model. 909 910 Args: 911 stateful_metrics: Iterable of string names of metrics that 912 should *not* be averaged over an epoch. 913 Metrics in this list will be logged as-is in `on_epoch_end`. 914 All others will be averaged in `on_epoch_end`. 915 """ 916 917 def __init__(self, stateful_metrics=None): 918 super(BaseLogger, self).__init__() 919 self.stateful_metrics = set(stateful_metrics or []) 920 921 def on_epoch_begin(self, epoch, logs=None): 922 self.seen = 0 923 self.totals = {} 924 925 def on_batch_end(self, batch, logs=None): 926 logs = logs or {} 927 batch_size = logs.get('size', 0) 928 # In case of distribution strategy we can potentially run multiple steps 929 # at the same time, we should account for that in the `seen` calculation. 
@keras_export('keras.callbacks.TerminateOnNaN')
class TerminateOnNaN(Callback):
  """Callback that terminates training when a NaN loss is encountered.

  After every batch the `'loss'` entry of the logs is inspected; if it is NaN
  or infinite, `model.stop_training` is set to True.
  """

  def __init__(self):
    super(TerminateOnNaN, self).__init__()
    self._supports_tf_logs = True

  def on_batch_end(self, batch, logs=None):
    loss = (logs or {}).get('loss')
    if loss is None:
      return
    # Convert to a NumPy / Python value before inspecting it.
    loss = tf_utils.sync_to_numpy_or_python_type(loss)
    if not (np.isnan(loss) or np.isinf(loss)):
      return
    print('Batch %d: Invalid loss, terminating training' % (batch))
    self.model.stop_training = True
@keras_export('keras.callbacks.ProgbarLogger')
class ProgbarLogger(Callback):
  """Callback that prints metrics to stdout.

  Args:
      count_mode: One of `"steps"` or `"samples"`.
          Whether the progress bar should
          count samples seen or steps (batches) seen.
      stateful_metrics: Iterable of string names of metrics that
          should *not* be averaged over an epoch.
          Metrics in this list will be logged as-is.
          All others will be averaged over time (e.g. loss, etc).
          If not provided, defaults to the `Model`'s metrics.

  Raises:
      ValueError: In case of invalid `count_mode`.
  """

  def __init__(self, count_mode='samples', stateful_metrics=None):
    super(ProgbarLogger, self).__init__()
    self._supports_tf_logs = True
    if count_mode not in ('samples', 'steps'):
      raise ValueError('Unknown `count_mode`: ' + str(count_mode))
    self.use_steps = count_mode == 'steps'
    # Defaults to all Model's metrics except for loss.
    self.stateful_metrics = set(stateful_metrics) if stateful_metrics else set()

    # Progress state; reset by `_reset_progbar` / `set_params`.
    self.seen = 0
    self.progbar = None
    self.target = None
    self.verbose = 1
    self.epochs = 1

    # Model step counters, used to infer the target when it is unknown.
    self._train_step = None
    self._test_step = None
    self._predict_step = None
    self._call_batch_hooks = True

    self._called_in_fit = False

  def set_params(self, params):
    """Reads verbosity, epoch count and progress target from `params`."""
    self.verbose = params['verbose']
    self.epochs = params['epochs']
    target_key = 'steps' if self.use_steps else 'samples'
    # A missing key leaves the target as None; it will be inferred at the end
    # of the first epoch.
    self.target = params[target_key] if target_key in params else None

    self._call_batch_hooks = self.verbose == 1
    if self.target is not None:
      return
    try:
      self._train_step = self.model._train_counter  # pylint: disable=protected-access
      self._test_step = self.model._test_counter  # pylint: disable=protected-access
      self._predict_step = self.model._predict_counter  # pylint: disable=protected-access
    except AttributeError:
      # Without model counters the batches have to be counted via the hooks.
      self._call_batch_hooks = True

  def on_train_begin(self, logs=None):
    # When this logger is called inside `fit`, validation is silent.
    self._called_in_fit = True

  def on_test_begin(self, logs=None):
    if self._called_in_fit:
      return
    self._reset_progbar()
    self._maybe_init_progbar()

  def on_predict_begin(self, logs=None):
    self._reset_progbar()
    self._maybe_init_progbar()

  def on_epoch_begin(self, epoch, logs=None):
    self._reset_progbar()
    self._maybe_init_progbar()
    if self.verbose and self.epochs > 1:
      print('Epoch %d/%d' % (epoch + 1, self.epochs))

  def on_train_batch_end(self, batch, logs=None):
    self._batch_update_progbar(batch, logs)

  def on_test_batch_end(self, batch, logs=None):
    if self._called_in_fit:
      return
    self._batch_update_progbar(batch, logs)

  def on_predict_batch_end(self, batch, logs=None):
    # Don't pass prediction results.
    self._batch_update_progbar(batch, None)

  def on_epoch_end(self, epoch, logs=None):
    self._finalize_progbar(logs, self._train_step)

  def on_test_end(self, logs=None):
    if self._called_in_fit:
      return
    self._finalize_progbar(logs, self._test_step)

  def on_predict_end(self, logs=None):
    self._finalize_progbar(logs, self._predict_step)

  def _reset_progbar(self):
    """Drops the current progress bar and the running count."""
    self.progbar = None
    self.seen = 0

  def _maybe_init_progbar(self):
    """Instantiate a `Progbar` if not yet, and update the stateful metrics."""
    # TODO(rchao): Legacy TF1 code path may use list for
    # `self.stateful_metrics`. Remove "cast to set" when TF1 support is dropped.
    self.stateful_metrics = set(self.stateful_metrics)

    if self.model:
      # Update the existing stateful metrics as `self.model.metrics` may contain
      # updated metrics after `MetricsContainer` is built in the first train
      # step.
      model_metric_names = set(m.name for m in self.model.metrics)
      self.stateful_metrics = self.stateful_metrics.union(model_metric_names)

    if self.progbar is None:
      unit_name = 'step' if self.use_steps else 'sample'
      self.progbar = Progbar(
          target=self.target,
          verbose=self.verbose,
          stateful_metrics=self.stateful_metrics,
          unit_name=unit_name)

    self.progbar._update_stateful_metrics(self.stateful_metrics)  # pylint: disable=protected-access

  def _implements_train_batch_hooks(self):
    return self._call_batch_hooks

  def _implements_test_batch_hooks(self):
    return self._call_batch_hooks

  def _implements_predict_batch_hooks(self):
    return self._call_batch_hooks

  def _batch_update_progbar(self, batch, logs=None):
    """Updates the progbar."""
    logs = logs or {}
    self._maybe_init_progbar()
    if not self.use_steps:
      # v1 path only.
      logs = copy.copy(logs)
      batch_size = logs.pop('size', 0)
      num_steps = logs.pop('num_steps', 1)
      logs.pop('batch', None)
      self.seen += num_steps * batch_size
    else:
      self.seen = batch + 1  # One-indexed.

    if self.verbose == 1:
      # Only block async when verbose = 1.
      logs = tf_utils.sync_to_numpy_or_python_type(logs)
      self.progbar.update(self.seen, list(logs.items()), finalize=False)

  def _finalize_progbar(self, logs, counter):
    """Draws the final state of the bar, inferring the target if unknown."""
    final_logs = tf_utils.sync_to_numpy_or_python_type(logs or {})
    if self.target is None:
      total = counter
      if total is not None:
        total = total.numpy()
        if not self.use_steps:
          total *= final_logs.get('size', 1)
      # Fall back to the running count when no (non-zero) counter is available.
      self.target = total or self.seen
      self.progbar.target = self.target
    self.progbar.update(self.target, list(final_logs.items()), finalize=True)
@keras_export('keras.callbacks.History')
class History(Callback):
  """Callback that records events into a `History` object.

  This callback is automatically applied to
  every Keras model. The `History` object
  gets returned by the `fit` method of models.

  Example:

  >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
  >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
  >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
  ...                     epochs=10)
  >>> print(history.params)
  {'verbose': 1, 'epochs': 10, 'steps': 1}
  >>> # check the keys of history object
  >>> print(history.history.keys())
  dict_keys(['loss'])

  """

  def __init__(self):
    super(History, self).__init__()
    # Maps each log key to the list of its per-epoch values.
    self.history = {}

  def on_train_begin(self, logs=None):
    # Epoch indices recorded during the current training run.
    self.epoch = []

  def on_epoch_end(self, epoch, logs=None):
    epoch_logs = logs or {}
    self.epoch.append(epoch)
    for name, value in epoch_logs.items():
      if name not in self.history:
        self.history[name] = []
      self.history[name].append(value)

    # Set the history attribute on the model after the epoch ends. This will
    # make sure that the state which is set is the latest one.
    self.model.history = self
@keras_export('keras.callbacks.ModelCheckpoint')
class ModelCheckpoint(Callback):
  """Callback to save the Keras model or model weights at some frequency.

  `ModelCheckpoint` callback is used in conjunction with training using
  `model.fit()` to save a model or weights (in a checkpoint file) at some
  interval, so the model or weights can be loaded later to continue the training
  from the state saved.

  A few options this callback provides include:

  - Whether to only keep the model that has achieved the "best performance" so
    far, or whether to save the model at the end of every epoch regardless of
    performance.
  - Definition of 'best'; which quantity to monitor and whether it should be
    maximized or minimized.
  - The frequency it should save at. Currently, the callback supports saving at
    the end of every epoch, or after a fixed number of training batches.
  - Whether only weights are saved, or the whole model is saved.

  Note: If you get `WARNING:tensorflow:Can save best model only with <name>
  available, skipping` see the description of the `monitor` argument for
  details on how to get this right.

  Example:

  ```python
  model.compile(loss=..., optimizer=...,
                metrics=['accuracy'])

  EPOCHS = 10
  checkpoint_filepath = '/tmp/checkpoint'
  model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
      filepath=checkpoint_filepath,
      save_weights_only=True,
      monitor='val_accuracy',
      mode='max',
      save_best_only=True)

  # Model weights are saved at the end of every epoch, if it's the best seen
  # so far.
  model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

  # The model weights (that are considered the best) are loaded into the model.
  model.load_weights(checkpoint_filepath)
  ```

  Args:
      filepath: string or `PathLike`, path to save the model file. e.g.
        filepath = os.path.join(working_dir, 'ckpt', file_name). `filepath`
        can contain named formatting options, which will be filled with the
        value of `epoch` and keys in `logs` (passed in `on_epoch_end`). For
        example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
        then the model checkpoints will be saved with the epoch number and the
        validation loss in the filename. The directory of the filepath should
        not be reused by any other callbacks to avoid conflicts.
      monitor: The metric name to monitor. Typically the metrics are set by the
        `Model.compile` method. Note:

        * Prefix the name with `"val_"` to monitor validation metrics.
        * Use `"loss"` or `"val_loss"` to monitor the model's total loss.
        * If you specify metrics as strings, like `"accuracy"`, pass the same
          string (with or without the `"val_"` prefix).
        * If you pass `metrics.Metric` objects, `monitor` should be set to
          `metric.name`
        * If you're not sure about the metric names you can check the contents
          of the `history.history` dictionary returned by
          `history = model.fit()`
        * Multi-output models set additional prefixes on the metric names.

      verbose: verbosity mode, 0 or 1.
      save_best_only: if `save_best_only=True`, it only saves when the model
        is considered the "best" and the latest best model according to the
        quantity monitored will not be overwritten. If `filepath` doesn't
        contain formatting options like `{epoch}` then `filepath` will be
        overwritten by each new better model.
      mode: one of {'auto', 'min', 'max'}. If `save_best_only=True`, the
        decision to overwrite the current save file is made based on either
        the maximization or the minimization of the monitored quantity.
        For `val_acc`, this should be `max`, for `val_loss` this should be
        `min`, etc. In `auto` mode, the mode is set to `max` if the name of
        the monitored quantity contains 'acc' or starts with 'fmeasure', and
        to `min` for the rest of the quantities.
      save_weights_only: if True, then only the model's weights will be saved
        (`model.save_weights(filepath)`), else the full model is saved
        (`model.save(filepath)`).
      save_freq: `'epoch'` or integer. When using `'epoch'`, the callback saves
        the model after each epoch. When using integer, the callback saves the
        model at end of this many batches. If the `Model` is compiled with
        `steps_per_execution=N`, then the saving criteria will be
        checked every Nth batch. Note that if the saving isn't aligned to
        epochs, the monitored metric may potentially be less reliable (it
        could reflect as little as 1 batch, since the metrics get reset every
        epoch). Defaults to `'epoch'`.
      options: Optional `tf.train.CheckpointOptions` object if
        `save_weights_only` is true or optional `tf.saved_model.SaveOptions`
        object if `save_weights_only` is false.
      **kwargs: Additional arguments for backwards compatibility. Possible keys
        are `period` and `load_weights_on_restart` (both deprecated).

  Raises:
      TypeError: If `options` is not of the type matching `save_weights_only`.
      ValueError: If `save_freq` is neither `'epoch'` nor an integer.
  """

  def __init__(self,
               filepath,
               monitor='val_loss',
               verbose=0,
               save_best_only=False,
               save_weights_only=False,
               mode='auto',
               save_freq='epoch',
               options=None,
               **kwargs):
    super(ModelCheckpoint, self).__init__()
    self._supports_tf_logs = True
    self.monitor = monitor
    self.verbose = verbose
    self.filepath = path_to_string(filepath)
    self.save_best_only = save_best_only
    self.save_weights_only = save_weights_only
    self.save_freq = save_freq
    self.epochs_since_last_save = 0
    self._batches_seen_since_last_saving = 0
    self._last_batch_seen = 0

    if save_weights_only:
      if options is None or isinstance(
          options, checkpoint_options_lib.CheckpointOptions):
        self._options = options or checkpoint_options_lib.CheckpointOptions()
      else:
        raise TypeError('If save_weights_only is True, then `options` must be '
                        'either None or a tf.train.CheckpointOptions')
    else:
      if options is None or isinstance(options, save_options_lib.SaveOptions):
        self._options = options or save_options_lib.SaveOptions()
      else:
        # The trailing space after "must be" keeps the concatenated message
        # readable.
        raise TypeError('If save_weights_only is False, then `options` must be '
                        'either None or a tf.saved_model.SaveOptions')

    # Deprecated field `load_weights_on_restart` is for loading the checkpoint
    # file from `filepath` at the start of `model.fit()`
    # TODO(rchao): Remove the arg during next breaking release.
    if 'load_weights_on_restart' in kwargs:
      self.load_weights_on_restart = kwargs['load_weights_on_restart']
      logging.warning('`load_weights_on_restart` argument is deprecated. '
                      'Please use `model.load_weights()` for loading weights '
                      'before the start of `model.fit()`.')
    else:
      self.load_weights_on_restart = False

    # Deprecated field `period` is for the number of epochs between which
    # the model is saved.
    if 'period' in kwargs:
      self.period = kwargs['period']
      logging.warning('`period` argument is deprecated. Please use `save_freq` '
                      'to specify the frequency in number of batches seen.')
    else:
      self.period = 1

    if mode not in ['auto', 'min', 'max']:
      logging.warning('ModelCheckpoint mode %s is unknown, '
                      'fallback to auto mode.', mode)
      mode = 'auto'

    # `np.inf` is used rather than the `np.Inf` alias, which NumPy 2.0 removed.
    if mode == 'min':
      self.monitor_op = np.less
      self.best = np.inf
    elif mode == 'max':
      self.monitor_op = np.greater
      self.best = -np.inf
    else:
      if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
        self.monitor_op = np.greater
        self.best = -np.inf
      else:
        self.monitor_op = np.less
        self.best = np.inf

    if self.save_freq != 'epoch' and not isinstance(self.save_freq, int):
      raise ValueError('Unrecognized save_freq: {}'.format(self.save_freq))

    # Only the chief worker writes model checkpoints, but all workers
    # restore checkpoint at on_train_begin().
    self._chief_worker_only = False

  def on_train_begin(self, logs=None):
    """Loads the latest matching checkpoint if `load_weights_on_restart`."""
    if self.load_weights_on_restart:
      filepath_to_load = (
          self._get_most_recently_modified_file_matching_pattern(self.filepath))
      if (filepath_to_load is not None and
          self._checkpoint_exists(filepath_to_load)):
        try:
          # `filepath` may contain placeholders such as `{epoch:02d}`, and
          # thus it attempts to load the most recently modified file with file
          # name matching the pattern.
          self.model.load_weights(filepath_to_load)
        except (IOError, ValueError) as e:
          raise ValueError('Error loading file from {}. Reason: {}'.format(
              filepath_to_load, e)) from e

  def _implements_train_batch_hooks(self):
    # Only call batch hooks when saving on batch
    return self.save_freq != 'epoch'

  def on_train_batch_end(self, batch, logs=None):
    if self._should_save_on_batch(batch):
      self._save_model(epoch=self._current_epoch, logs=logs)

  def on_epoch_begin(self, epoch, logs=None):
    # Remembered so that batch-level saves can format `{epoch}` in `filepath`.
    self._current_epoch = epoch

  def on_epoch_end(self, epoch, logs=None):
    self.epochs_since_last_save += 1
    # pylint: disable=protected-access
    if self.save_freq == 'epoch':
      self._save_model(epoch=epoch, logs=logs)

  def _should_save_on_batch(self, batch):
    """Handles batch-level saving logic, supports steps_per_execution."""
    if self.save_freq == 'epoch':
      return False

    if batch <= self._last_batch_seen:  # New epoch.
      add_batches = batch + 1  # batches are zero-indexed.
    else:
      add_batches = batch - self._last_batch_seen
    self._batches_seen_since_last_saving += add_batches
    self._last_batch_seen = batch

    if self._batches_seen_since_last_saving >= self.save_freq:
      self._batches_seen_since_last_saving = 0
      return True
    return False

  def _save_model(self, epoch, logs):
    """Saves the model.

    Args:
        epoch: the epoch this iteration is in.
        logs: the `logs` dict passed in to `on_batch_end` or `on_epoch_end`.

    Raises:
        IOError: If saving fails; when the target path is an existing
          directory the error carries an explanatory message.
    """
    logs = logs or {}

    if isinstance(self.save_freq,
                  int) or self.epochs_since_last_save >= self.period:
      # Block only when saving interval is reached.
      logs = tf_utils.sync_to_numpy_or_python_type(logs)
      self.epochs_since_last_save = 0
      filepath = self._get_file_path(epoch, logs)

      try:
        if self.save_best_only:
          current = logs.get(self.monitor)
          if current is None:
            logging.warning('Can save best model only with %s available, '
                            'skipping.', self.monitor)
          else:
            if self.monitor_op(current, self.best):
              if self.verbose > 0:
                print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                      ' saving model to %s' % (epoch + 1, self.monitor,
                                               self.best, current, filepath))
              self.best = current
              if self.save_weights_only:
                self.model.save_weights(
                    filepath, overwrite=True, options=self._options)
              else:
                self.model.save(filepath, overwrite=True, options=self._options)
            else:
              if self.verbose > 0:
                print('\nEpoch %05d: %s did not improve from %0.5f' %
                      (epoch + 1, self.monitor, self.best))
        else:
          if self.verbose > 0:
            print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
          if self.save_weights_only:
            self.model.save_weights(
                filepath, overwrite=True, options=self._options)
          else:
            self.model.save(filepath, overwrite=True, options=self._options)

        self._maybe_remove_file()
      except IsADirectoryError as e:  # h5py 3.x
        raise IOError('Please specify a non-directory filepath for '
                      'ModelCheckpoint. Filepath used is an existing '
                      'directory: {}'.format(filepath)) from e
      except IOError as e:  # h5py 2.x
        # `e.errno` appears to be `None` so checking the content of `e.args[0]`.
        # Guard against an empty `args` so the original error is not masked by
        # an IndexError.
        if e.args and 'is a directory' in str(e.args[0]).lower():
          raise IOError('Please specify a non-directory filepath for '
                        'ModelCheckpoint. Filepath used is an existing '
                        'directory: {}'.format(filepath)) from e
        # Re-throw the error for any other causes, keeping its traceback.
        raise

  def _get_file_path(self, epoch, logs):
    """Returns the file path for checkpoint.

    Args:
        epoch: zero-indexed epoch; formatted into the path as `epoch + 1`.
        logs: dict whose keys may be referenced by placeholders in `filepath`.

    Raises:
        KeyError: If `filepath` references a key missing from `logs`.
    """
    # pylint: disable=protected-access
    try:
      # `filepath` may contain placeholders such as `{epoch:02d}` and
      # `{mape:.2f}`. A mismatch between logged metrics and the path's
      # placeholders can cause formatting to fail.
      file_path = self.filepath.format(epoch=epoch + 1, **logs)
    except KeyError as e:
      raise KeyError('Failed to format this callback filepath: "{}". '
                     'Reason: {}'.format(self.filepath, e)) from e
    self._write_filepath = distributed_file_utils.write_filepath(
        file_path, self.model.distribute_strategy)
    return self._write_filepath

  def _maybe_remove_file(self):
    # Remove the checkpoint directory in multi-worker training where this worker
    # should not checkpoint. It is a dummy directory previously saved for sync
    # distributed training.
    distributed_file_utils.remove_temp_dir_with_filepath(
        self._write_filepath, self.model.distribute_strategy)

  def _checkpoint_exists(self, filepath):
    """Returns whether the checkpoint `filepath` refers to exists."""
    if filepath.endswith('.h5'):
      return file_io.file_exists_v2(filepath)
    tf_saved_model_exists = file_io.file_exists_v2(filepath)
    # Weights-only TF checkpoints are detected through their `.index` file.
    tf_weights_only_checkpoint_exists = file_io.file_exists_v2(
        filepath + '.index')
    return tf_saved_model_exists or tf_weights_only_checkpoint_exists

  def _get_most_recently_modified_file_matching_pattern(self, pattern):
    """Returns the most recently modified filepath matching pattern.

    Pattern may contain python formatting placeholder. If
    `tf.train.latest_checkpoint()` does not return None, use that; otherwise,
    check for most recently modified one that matches the pattern.

    In the rare case where there are more than one pattern-matching file having
    the same modified time that is most recent among all, return the filepath
    that is largest (by `>` operator, lexicographically using the numeric
    equivalents). This provides a tie-breaker when multiple files are most
    recent. Note that a larger `filepath` can sometimes indicate a later time of
    modification (for instance, when epoch/batch is used as formatting option),
    but not necessarily (when accuracy or loss is used). The tie-breaker is
    put in the logic as best effort to return the most recent, and to avoid
    a nondeterministic result.

    Modified time of a file is obtained with `os.path.getmtime()`.

    This utility function is best demonstrated via an example:

    ```python
    file_pattern = 'f.batch{batch:02d}epoch{epoch:02d}.h5'
    test_dir = self.get_temp_dir()
    path_pattern = os.path.join(test_dir, file_pattern)
    file_paths = [
        os.path.join(test_dir, file_name) for file_name in
        ['f.batch03epoch02.h5', 'f.batch02epoch02.h5', 'f.batch01epoch01.h5']
    ]
    for file_path in file_paths:
      # Write something to each of the files
    self.assertEqual(
        _get_most_recently_modified_file_matching_pattern(path_pattern),
        file_paths[-1])
    ```

    Args:
        pattern: The file pattern that may optionally contain python placeholder
            such as `{epoch:02d}`.

    Returns:
        The most recently modified file's full filepath matching `pattern`. If
        `pattern` does not contain any placeholder, this returns the filepath
        that exactly matches `pattern`. Returns `None` if no match is found.
    """
    dir_name = os.path.dirname(pattern)
    base_name = os.path.basename(pattern)
    # Everything from the first `{` to the last `}` becomes one wildcard.
    base_name_regex = '^' + re.sub(r'{.*}', r'.*', base_name) + '$'

    # If tf.train.latest_checkpoint tells us there exists a latest checkpoint,
    # use that as it is more robust than `os.path.getmtime()`.
    latest_tf_checkpoint = checkpoint_management.latest_checkpoint(dir_name)
    if latest_tf_checkpoint is not None and re.match(
        base_name_regex, os.path.basename(latest_tf_checkpoint)):
      return latest_tf_checkpoint

    latest_mod_time = 0
    file_path_with_latest_mod_time = None
    n_file_with_latest_mod_time = 0
    file_path_with_largest_file_name = None

    if file_io.file_exists_v2(dir_name):
      for file_name in os.listdir(dir_name):
        # Only consider if `file_name` matches the pattern.
        if re.match(base_name_regex, file_name):
          file_path = os.path.join(dir_name, file_name)
          mod_time = os.path.getmtime(file_path)
          if (file_path_with_largest_file_name is None or
              file_path > file_path_with_largest_file_name):
            file_path_with_largest_file_name = file_path
          if mod_time > latest_mod_time:
            latest_mod_time = mod_time
            file_path_with_latest_mod_time = file_path
            # In the case a file with later modified time is found, reset
            # the counter for the number of files with latest modified time.
            n_file_with_latest_mod_time = 1
          elif mod_time == latest_mod_time:
            # In the case a file has modified time tied with the most recent,
            # increment the counter for the number of files with latest modified
            # time by 1.
            n_file_with_latest_mod_time += 1

    if n_file_with_latest_mod_time == 1:
      # Return the sole file that has most recent modified time.
      return file_path_with_latest_mod_time
    else:
      # If there are more than one file having latest modified time, return
      # the file path with the largest file name.
      return file_path_with_largest_file_name
@keras_export('keras.callbacks.experimental.BackupAndRestore', v1=[])
class BackupAndRestore(Callback):
  """Callback to back up and restore the training state.

  `BackupAndRestore` callback is intended to recover from interruptions that
  happened in the middle of a model.fit execution by backing up the
  training states in a temporary checkpoint file (based on TF CheckpointManager)
  at the end of each epoch. If training restarted before completion, the
  training state and model are restored to the most recently saved state at the
  beginning of a new model.fit() run.
  Note that user is responsible to bring jobs back up.
  This callback is important for the backup and restore mechanism for fault
  tolerance purpose. And the model to be restored from a previous checkpoint is
  expected to be the same as the one used to back up. If user changes arguments
  passed to compile or fit, the checkpoint saved for fault tolerance can become
  invalid.

  Note:
  1. This callback is not compatible with disabling eager execution.
  2. A checkpoint is saved at the end of each epoch, when restoring we'll redo
  any partial work from an unfinished epoch in which the training got restarted
  (so the work done before an interruption doesn't affect the final model
  state).
  3. This works for both single worker and multi-worker mode. When a
  distribution strategy is used, only MirroredStrategy,
  MultiWorkerMirroredStrategy, TPUStrategy and ParameterServerStrategy are
  accepted for now.

  Example:

  >>> class InterruptingCallback(tf.keras.callbacks.Callback):
  ...   def on_epoch_begin(self, epoch, logs=None):
  ...     if epoch == 4:
  ...       raise RuntimeError('Interrupting!')
  >>> callback = tf.keras.callbacks.experimental.BackupAndRestore(
  ... backup_dir="/tmp/backup")
  >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
  >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
  >>> try:
  ...   model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10,
  ...             batch_size=1, callbacks=[callback, InterruptingCallback()],
  ...             verbose=0)
  ... except:
  ...   pass
  >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10,
  ...             batch_size=1, callbacks=[callback], verbose=0)
  >>> # Only 6 more epochs are run, since first training got interrupted at
  >>> # zero-indexed epoch 4, second training will continue from 4 to 9.
  >>> len(history.history['loss'])
  6

  Args:
      backup_dir: String, path to store the checkpoint.
        e.g. backup_dir = os.path.join(working_dir, 'backup')
        This is the directory in which the system stores temporary files to
        recover the model from jobs terminated unexpectedly. The directory
        cannot be reused elsewhere to store other files, e.g. by
        BackupAndRestore callback of another training, or by another callback
        (ModelCheckpoint) of the same training.

  Raises:
      ValueError: If constructed while eager execution is not active.
  """

  def __init__(self, backup_dir):
    super(BackupAndRestore, self).__init__()
    self.backup_dir = backup_dir
    self._supports_tf_logs = True
    # Strategy types accepted by the check in `on_train_begin`.
    self._supported_strategies = (
        mirrored_strategy.MirroredStrategy,
        collective_all_reduce_strategy.CollectiveAllReduceStrategy,
        tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV2,
        parameter_server_strategy_v2.ParameterServerStrategyV2)

    if not context.executing_eagerly():
      if ops.inside_function():
        raise ValueError('This Callback\'s method contains Python state and '
                         'should be called outside of `tf.function`s.')
      else:  # Legacy graph mode:
        raise ValueError(
            'BackupAndRestore only supports eager mode. In graph '
            'mode, consider using ModelCheckpoint to manually save '
            'and restore weights with `model.load_weights()` and by '
            'providing `initial_epoch` in `model.fit()` for fault tolerance.')

    # Only the chief worker writes model checkpoints, but all workers
    # restore checkpoint at on_train_begin().
    self._chief_worker_only = False

  def on_train_begin(self, logs=None):
    """Validates the strategy and restores any backed-up training state.

    Raises:
        NotImplementedError: If the model uses a distribution strategy that
          is not one of `self._supported_strategies`.
    """
    # TrainingState is used to manage the training state needed for
    # failure-recovery of a worker in training.
    # pylint: disable=protected-access

    if self.model._distribution_strategy and not isinstance(
        self.model.distribute_strategy, self._supported_strategies):
      # The message lists every family in `_supported_strategies`.
      raise NotImplementedError(
          '%s is not supported yet. '
          'Currently BackupAndRestore callback only supports empty strategy, '
          'MirroredStrategy, MultiWorkerMirroredStrategy, TPUStrategy and '
          'ParameterServerStrategy.' %
          type(self.model.distribute_strategy).__name__)
    self.model._training_state = (
        worker_training_state.WorkerTrainingState(self.model, self.backup_dir))
    self._training_state = self.model._training_state
    self._training_state.restore()

  def on_train_end(self, logs=None):
    # pylint: disable=protected-access
    # On exit of training, delete the training state backup file that was saved
    # for the purpose of worker recovery.
    self._training_state.delete_backup()

    # Clean up the training state.
    del self._training_state
    del self.model._training_state

  def on_epoch_end(self, epoch, logs=None):
    # Back up the model and current epoch for possible future recovery.
    self._training_state.back_up(epoch)
@keras_export('keras.callbacks.EarlyStopping')
class EarlyStopping(Callback):
  """Stop training when a monitored metric has stopped improving.

  Assuming the goal of a training is to minimize the loss. With this, the
  metric to be monitored would be `'loss'`, and mode would be `'min'`. A
  `model.fit()` training loop will check at end of every epoch whether
  the loss is no longer decreasing, considering the `min_delta` and
  `patience` if applicable. Once it's found no longer decreasing,
  `model.stop_training` is marked True and the training terminates.

  The quantity to be monitored needs to be available in `logs` dict.
  To make it so, pass the loss or metrics at `model.compile()`.

  Args:
    monitor: Quantity to be monitored.
    min_delta: Minimum change in the monitored quantity
        to qualify as an improvement, i.e. an absolute
        change of less than min_delta, will count as no
        improvement.
    patience: Number of epochs with no improvement
        after which training will be stopped.
    verbose: verbosity mode.
    mode: One of `{"auto", "min", "max"}`. In `min` mode,
        training will stop when the quantity
        monitored has stopped decreasing; in `"max"`
        mode it will stop when the quantity
        monitored has stopped increasing; in `"auto"`
        mode, the direction is automatically inferred
        from the name of the monitored quantity.
    baseline: Baseline value for the monitored quantity.
        Training will stop if the model doesn't show improvement over the
        baseline.
    restore_best_weights: Whether to restore model weights from
        the epoch with the best value of the monitored quantity.
        If False, the model weights obtained at the last step of
        training are used. An epoch will be restored regardless
        of the performance relative to the `baseline`. If no epoch
        improves on `baseline`, training will run for `patience`
        epochs and restore weights from the best epoch in that set.

  Example:

  >>> callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
  >>> # This callback will stop the training when there is no improvement in
  >>> # the loss for three consecutive epochs.
  >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
  >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
  >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
  ...                     epochs=10, batch_size=1, callbacks=[callback],
  ...                     verbose=0)
  >>> len(history.history['loss'])  # Only 4 epochs are run.
  4
  """

  def __init__(self,
               monitor='val_loss',
               min_delta=0,
               patience=0,
               verbose=0,
               mode='auto',
               baseline=None,
               restore_best_weights=False):
    super(EarlyStopping, self).__init__()

    self.monitor = monitor
    self.patience = patience
    self.verbose = verbose
    self.baseline = baseline
    self.min_delta = abs(min_delta)
    self.wait = 0
    self.stopped_epoch = 0
    self.restore_best_weights = restore_best_weights
    self.best_weights = None
    # Set when this callback itself requested the stop, so that a stop at
    # zero-indexed epoch 0 can still be reported in `on_train_end`.
    self._stopped_early = False

    if mode not in ['auto', 'min', 'max']:
      logging.warning('EarlyStopping mode %s is unknown, '
                      'fallback to auto mode.', mode)
      mode = 'auto'

    if mode == 'min':
      self.monitor_op = np.less
    elif mode == 'max':
      self.monitor_op = np.greater
    else:
      # In auto mode, metrics whose name contains 'acc' are maximized.
      if 'acc' in self.monitor:
        self.monitor_op = np.greater
      else:
        self.monitor_op = np.less

    # `_is_improvement` subtracts `min_delta`, so it is negated when a smaller
    # value is better.
    if self.monitor_op != np.greater:
      self.min_delta *= -1

  def on_train_begin(self, logs=None):
    # Allow instances to be re-used
    self.wait = 0
    self.stopped_epoch = 0
    self._stopped_early = False
    # `np.inf` is used rather than the `np.Inf` alias, which NumPy 2.0 removed.
    self.best = np.inf if self.monitor_op == np.less else -np.inf
    self.best_weights = None

  def on_epoch_end(self, epoch, logs=None):
    """Updates the best value / wait counter and stops training if needed."""
    current = self.get_monitor_value(logs)
    if current is None:
      return
    if self.restore_best_weights and self.best_weights is None:
      # Restore the weights after first epoch if no progress is ever made.
      self.best_weights = self.model.get_weights()

    self.wait += 1
    if self._is_improvement(current, self.best):
      self.best = current
      if self.restore_best_weights:
        self.best_weights = self.model.get_weights()
      # Only restart wait if we beat both the baseline and our previous best.
      if self.baseline is None or self._is_improvement(current, self.baseline):
        self.wait = 0

    # NOTE(review): with `patience=0` this condition holds on every epoch,
    # including improving ones, so training stops after the first epoch that
    # reports the metric -- confirm this is intended.
    if self.wait >= self.patience:
      self.stopped_epoch = epoch
      self._stopped_early = True
      self.model.stop_training = True
      if self.restore_best_weights and self.best_weights is not None:
        if self.verbose > 0:
          print('Restoring model weights from the end of the best epoch.')
        self.model.set_weights(self.best_weights)

  def on_train_end(self, logs=None):
    # `stopped_epoch > 0` alone would miss a stop at zero-indexed epoch 0; it
    # is kept for subclasses that set `stopped_epoch` themselves.
    stopped = self._stopped_early or self.stopped_epoch > 0
    if stopped and self.verbose > 0:
      print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))

  def get_monitor_value(self, logs):
    """Returns `logs[self.monitor]`, or None (with a warning) if missing."""
    logs = logs or {}
    monitor_value = logs.get(self.monitor)
    if monitor_value is None:
      logging.warning('Early stopping conditioned on metric `%s` '
                      'which is not available. Available metrics are: %s',
                      self.monitor, ','.join(logs))
    return monitor_value

  def _is_improvement(self, monitor_value, reference_value):
    """Returns whether `monitor_value` beats `reference_value` by `min_delta`."""
    return self.monitor_op(monitor_value - self.min_delta, reference_value)
@keras_export('keras.callbacks.LearningRateScheduler')
class LearningRateScheduler(Callback):
  """Callback that applies a user-supplied learning rate schedule.

  When an epoch begins, the `schedule` function given to `__init__` is called
  with the epoch index and the optimizer's current learning rate; the value it
  returns is then written back to the optimizer.

  Args:
    schedule: a function that takes an epoch index (integer, indexed from 0)
        and current learning rate (float) as inputs and returns a new
        learning rate as output (float).
    verbose: int. 0: quiet, 1: update messages.

  Example:

  >>> # This function keeps the initial learning rate for the first ten epochs
  >>> # and decreases it exponentially after that.
  >>> def scheduler(epoch, lr):
  ...   if epoch < 10:
  ...     return lr
  ...   else:
  ...     return lr * tf.math.exp(-0.1)
  >>>
  >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
  >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
  >>> round(model.optimizer.lr.numpy(), 5)
  0.01

  >>> callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
  >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
  ...                     epochs=15, callbacks=[callback], verbose=0)
  >>> round(model.optimizer.lr.numpy(), 5)
  0.00607

  """

  def __init__(self, schedule, verbose=0):
    super(LearningRateScheduler, self).__init__()
    self.verbose = verbose
    self.schedule = schedule

  def on_epoch_begin(self, epoch, logs=None):
    """Queries the schedule and writes the result to the optimizer."""
    optimizer = self.model.optimizer
    if not hasattr(optimizer, 'lr'):
      raise ValueError('Optimizer must have a "lr" attribute.')

    try:
      # Current API: schedule(epoch, lr).
      current_lr = float(backend.get_value(optimizer.lr))
      scheduled_lr = self.schedule(epoch, current_lr)
    except TypeError:
      # Legacy API kept for backward compatibility: schedule(epoch).
      scheduled_lr = self.schedule(epoch)

    accepted_types = (ops.Tensor, float, np.float32, np.float64)
    if not isinstance(scheduled_lr, accepted_types):
      raise ValueError('The output of the "schedule" function should be float.')
    is_tensor = isinstance(scheduled_lr, ops.Tensor)
    if is_tensor and not scheduled_lr.dtype.is_floating:
      raise ValueError('The dtype of Tensor should be float')

    backend.set_value(self.model.optimizer.lr, backend.get_value(scheduled_lr))
    if self.verbose > 0:
      message = ('\nEpoch %05d: LearningRateScheduler setting learning '
                 'rate to %s.')
      print(message % (epoch + 1, scheduled_lr))

  def on_epoch_end(self, epoch, logs=None):
    """Records the optimizer's learning rate under the `lr` key of `logs`."""
    if not logs:
      logs = {}
    logs['lr'] = backend.get_value(self.model.optimizer.lr)
1997 1998 Returns: 1999 True on success, or False if no summary was written because no default 2000 summary writer was available. 2001 2002 Raises: 2003 ValueError: if a default writer exists, but no step was provided and 2004 `tf.summary.experimental.get_step()` is None. 2005 """ 2006 summary_metadata = summary_pb2.SummaryMetadata() 2007 # Hard coding a plugin name. Please refer to go/tb-plugin-name-hardcode for 2008 # the rationale. 2009 summary_metadata.plugin_data.plugin_name = 'graph_keras_model' 2010 # version number = 1 2011 summary_metadata.plugin_data.content = b'1' 2012 2013 try: 2014 json_string = data.to_json() 2015 except Exception as exc: # pylint: disable=broad-except 2016 # An exception should not break a model code. 2017 logging.warning('Model failed to serialize as JSON. Ignoring... %s', exc) 2018 return False 2019 2020 with summary_ops_v2.summary_scope(name, 'graph_keras_model', 2021 [data, step]) as (tag, _): 2022 with ops.device('cpu:0'): 2023 tensor = constant_op.constant(json_string, dtype=dtypes.string) 2024 return summary_ops_v2.write( 2025 tag=tag, tensor=tensor, step=step, metadata=summary_metadata) 2026 2027 2028@keras_export('keras.callbacks.TensorBoard', v1=[]) 2029class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): 2030 # pylint: disable=line-too-long 2031 """Enable visualizations for TensorBoard. 2032 2033 TensorBoard is a visualization tool provided with TensorFlow. 2034 2035 This callback logs events for TensorBoard, including: 2036 2037 * Metrics summary plots 2038 * Training graph visualization 2039 * Activation histograms 2040 * Sampled profiling 2041 2042 When used in `Model.evaluate`, in addition to epoch summaries, there will be 2043 a summary that records evaluation metrics vs `Model.optimizer.iterations` 2044 written. The metric names will be prepended with `evaluation`, with 2045 `Model.optimizer.iterations` being the step in the visualized TensorBoard. 
2046 2047 If you have installed TensorFlow with pip, you should be able 2048 to launch TensorBoard from the command line: 2049 2050 ``` 2051 tensorboard --logdir=path_to_your_logs 2052 ``` 2053 2054 You can find more information about TensorBoard 2055 [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard). 2056 2057 Args: 2058 log_dir: the path of the directory where to save the log files to be 2059 parsed by TensorBoard. e.g. log_dir = os.path.join(working_dir, 'logs') 2060 This directory should not be reused by any other callbacks. 2061 histogram_freq: frequency (in epochs) at which to compute activation and 2062 weight histograms for the layers of the model. If set to 0, histograms 2063 won't be computed. Validation data (or split) must be specified for 2064 histogram visualizations. 2065 write_graph: whether to visualize the graph in TensorBoard. The log file 2066 can become quite large when write_graph is set to True. 2067 write_images: whether to write model weights to visualize as image in 2068 TensorBoard. 2069 write_steps_per_second: whether to log the training steps per second into 2070 Tensorboard. This supports both epoch and batch frequency logging. 2071 update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`, 2072 writes the losses and metrics to TensorBoard after each batch. The same 2073 applies for `'epoch'`. If using an integer, let's say `1000`, the 2074 callback will write the metrics and losses to TensorBoard every 1000 2075 batches. Note that writing too frequently to TensorBoard can slow down 2076 your training. 2077 profile_batch: Profile the batch(es) to sample compute characteristics. 2078 profile_batch must be a non-negative integer or a tuple of integers. 2079 A pair of positive integers signify a range of batches to profile. 2080 By default, it will profile the second batch. Set profile_batch=0 2081 to disable profiling. 
2082 embeddings_freq: frequency (in epochs) at which embedding layers will be 2083 visualized. If set to 0, embeddings won't be visualized. 2084 embeddings_metadata: Dictionary which maps embedding layer names to the 2085 filename of a file in which to save metadata for the embedding layer. 2086 In case the same metadata file is to be 2087 used for all embedding layers, a single filename can be passed. 2088 2089 Examples: 2090 2091 Basic usage: 2092 2093 ```python 2094 tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs") 2095 model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) 2096 # Then run the tensorboard command to view the visualizations. 2097 ``` 2098 2099 Custom batch-level summaries in a subclassed Model: 2100 2101 ```python 2102 class MyModel(tf.keras.Model): 2103 2104 def build(self, _): 2105 self.dense = tf.keras.layers.Dense(10) 2106 2107 def call(self, x): 2108 outputs = self.dense(x) 2109 tf.summary.histogram('outputs', outputs) 2110 return outputs 2111 2112 model = MyModel() 2113 model.compile('sgd', 'mse') 2114 2115 # Make sure to set `update_freq=N` to log a batch-level summary every N batches. 2116 # In addition to any `tf.summary` contained in `Model.call`, metrics added in 2117 # `Model.compile` will be logged every N batches. 2118 tb_callback = tf.keras.callbacks.TensorBoard('./logs', update_freq=1) 2119 model.fit(x_train, y_train, callbacks=[tb_callback]) 2120 ``` 2121 2122 Custom batch-level summaries in a Functional API Model: 2123 2124 ```python 2125 def my_summary(x): 2126 tf.summary.histogram('x', x) 2127 return x 2128 2129 inputs = tf.keras.Input(10) 2130 x = tf.keras.layers.Dense(10)(inputs) 2131 outputs = tf.keras.layers.Lambda(my_summary)(x) 2132 model = tf.keras.Model(inputs, outputs) 2133 model.compile('sgd', 'mse') 2134 2135 # Make sure to set `update_freq=N` to log a batch-level summary every N batches. 
2136 # In addition to any `tf.summary` contained in `Model.call`, metrics added in 2137 # `Model.compile` will be logged every N batches. 2138 tb_callback = tf.keras.callbacks.TensorBoard('./logs', update_freq=1) 2139 model.fit(x_train, y_train, callbacks=[tb_callback]) 2140 ``` 2141 2142 Profiling: 2143 2144 ```python 2145 # Profile a single batch, e.g. the 5th batch. 2146 tensorboard_callback = tf.keras.callbacks.TensorBoard( 2147 log_dir='./logs', profile_batch=5) 2148 model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) 2149 2150 # Profile a range of batches, e.g. from 10 to 20. 2151 tensorboard_callback = tf.keras.callbacks.TensorBoard( 2152 log_dir='./logs', profile_batch=(10,20)) 2153 model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) 2154 ``` 2155 """ 2156 2157 # pylint: enable=line-too-long 2158 2159 def __init__(self, 2160 log_dir='logs', 2161 histogram_freq=0, 2162 write_graph=True, 2163 write_images=False, 2164 write_steps_per_second=False, 2165 update_freq='epoch', 2166 profile_batch=2, 2167 embeddings_freq=0, 2168 embeddings_metadata=None, 2169 **kwargs): 2170 super(TensorBoard, self).__init__() 2171 self._supports_tf_logs = True 2172 self._validate_kwargs(kwargs) 2173 2174 self.log_dir = path_to_string(log_dir) 2175 self.histogram_freq = histogram_freq 2176 self.write_graph = write_graph 2177 self.write_images = write_images 2178 self.write_steps_per_second = write_steps_per_second 2179 self.update_freq = 1 if update_freq == 'batch' else update_freq 2180 self.embeddings_freq = embeddings_freq 2181 self.embeddings_metadata = embeddings_metadata 2182 self._init_profile_batch(profile_batch) 2183 self._global_train_batch = 0 2184 self._previous_epoch_iterations = 0 2185 self._train_accumulated_time = 0 2186 self._batch_start_time = 0 2187 2188 # Lazily initialized in order to avoid creating event files when 2189 # not needed. 
def _validate_kwargs(self, kwargs):
  """Warns about V1-only keyword arguments and rejects unknown ones.

  Args:
    kwargs: Dict of extra keyword arguments passed to the constructor.

  Raises:
    ValueError: If `kwargs` contains a key that was not supported in V1.
  """
  # Each V1-era argument paired with the warning emitted when it is set to
  # a truthy value. Order matters: warnings are logged in this order.
  legacy_warnings = (
      ('write_grads',
       '`write_grads` will be ignored in TensorFlow 2.0 '
       'for the `TensorBoard` Callback.'),
      ('batch_size',
       '`batch_size` is no longer needed in the '
       '`TensorBoard` Callback and will be ignored '
       'in TensorFlow 2.0.'),
      ('embeddings_layer_names',
       '`embeddings_layer_names` is not supported in '
       'TensorFlow 2.0. Instead, all `Embedding` layers '
       'will be visualized.'),
      ('embeddings_data',
       '`embeddings_data` is not supported in TensorFlow '
       '2.0. Instead, all `Embedding` variables will be '
       'visualized.'),
  )
  for legacy_key, message in legacy_warnings:
    if kwargs.get(legacy_key, False):
      logging.warning(message)

  # Only allow kwargs that were supported in V1.
  supported = {legacy_key for legacy_key, _ in legacy_warnings}
  unrecognized_kwargs = set(kwargs.keys()) - supported
  if not unrecognized_kwargs:
    return
  raise ValueError('Unrecognized arguments in `TensorBoard` '
                   'Callback: ' + str(unrecognized_kwargs))
def _configure_embeddings(self):
  """Configure the Projector for embeddings.

  Adds one projector entry per `Embedding` layer found in
  `self.model.layers` and writes the resulting config as
  `projector_config.pbtxt` inside the log write directory.

  `self.embeddings_metadata` may be `None`, a single filename applied to
  every embedding layer, or a dict mapping layer names to filenames. The
  dict is never modified, so the callback can be attached to a model more
  than once with the same metadata.

  Raises:
    ValueError: If `embeddings_metadata` is a dict containing names that do
      not match any `Embedding` layer of the model.
  """
  # TODO(omalleyt): Add integration tests.
  from google.protobuf import text_format
  from tensorflow.python.keras.layers import embeddings
  from tensorflow.python.keras.protobuf import projector_config_pb2

  metadata = self.embeddings_metadata
  metadata_is_path = isinstance(metadata, str)
  # Track not-yet-matched layer names on a private copy rather than popping
  # from the caller's dict, which would lose the metadata on the next call.
  unmatched = {}
  if metadata is not None and not metadata_is_path:
    unmatched = dict(metadata)

  config = projector_config_pb2.ProjectorConfig()
  for layer in self.model.layers:
    if not isinstance(layer, embeddings.Embedding):
      continue
    embedding = config.embeddings.add()
    # Embeddings are always the first layer, so this naming should be
    # consistent in any keras models checkpoints.
    name = 'layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE'
    embedding.tensor_name = name

    if metadata_is_path:
      embedding.metadata_path = metadata
    elif layer.name in unmatched:
      embedding.metadata_path = unmatched.pop(layer.name)

  if unmatched:
    raise ValueError('Unrecognized `Embedding` layer names passed to '
                     '`keras.callbacks.TensorBoard` `embeddings_metadata` '
                     'argument: ' + str(unmatched.keys()))

  config_pbtxt = text_format.MessageToString(config)
  path = os.path.join(self._log_write_dir, 'projector_config.pbtxt')
  with gfile.Open(path, 'w') as f:
    f.write(config_pbtxt)
def _init_profile_batch(self, profile_batch):
  """Validate profile_batch value and set the range of batches to profile.

  Sets values of _start_batch and _stop_batch attributes,
  specifying the start and stop batch to profile.
  Setting `profile_batch=0` disables profiling.

  Args:
    profile_batch: The range of batches to profile. Should be a non-negative
      integer, a 2-tuple/list of positive integers, or (legacy) a string of
      the form `"start"` or `"start,stop"`. A pair of positive integers
      signifies a range of batches to profile.

  Raises:
    ValueError: If profile_batch is not an integer or a comma separated pair
      of positive integers.

  """
  profile_batch_error_message = (
      'profile_batch must be a non-negative integer or 2-tuple of positive '
      'integers. A pair of positive integers signifies a range of batches '
      'to profile. Found: {}'.format(profile_batch))

  # Support legacy way of specifying "start,stop" or "start" as str.
  if isinstance(profile_batch, str):
    try:
      parsed = [int(part) for part in profile_batch.split(',')]
    except ValueError:
      # Non-numeric pieces get the same message as every other bad value.
      raise ValueError(profile_batch_error_message)
    # A lone "start" parses to a one-element list; unwrap it so that it is
    # handled exactly like the equivalent integer.
    profile_batch = parsed[0] if len(parsed) == 1 else parsed

  if isinstance(profile_batch, int):
    self._start_batch = profile_batch
    self._stop_batch = profile_batch
  elif isinstance(profile_batch, (tuple, list)) and len(profile_batch) == 2:
    self._start_batch, self._stop_batch = profile_batch
  else:
    raise ValueError(profile_batch_error_message)

  if self._start_batch < 0 or self._stop_batch < self._start_batch:
    raise ValueError(profile_batch_error_message)

  # True when the profiler was successfully started by this callback.
  # We track the status here to make sure callbacks do not interfere with
  # each other. The callback will only stop the profiler it started.
  self._profiler_started = False
  if self._start_batch > 0:
    # Warm up and improve the profiling accuracy.
    self._start_profiler(logdir='')
    self._stop_profiler(save=False)
  # True when a trace is running.
  self._is_tracing = False

  # Setting `profile_batch=0` disables profiling.
  self._should_trace = not (self._start_batch == 0 and self._stop_batch == 0)
def on_epoch_end(self, epoch, logs=None):
  """Writes epoch metrics, then any histograms or embeddings that are due.

  Metrics are logged every epoch. Weight histograms and embeddings are logged
  only when their frequency is non-zero and divides `epoch`.
  """
  self._log_epoch_metrics(epoch, logs)

  def is_due(frequency):
    # A falsy frequency disables the corresponding summary entirely.
    return bool(frequency) and epoch % frequency == 0

  if is_due(self.histogram_freq):
    self._log_weights(epoch)
  if is_due(self.embeddings_freq):
    self._log_embeddings(epoch)
def _compute_steps_per_second(self):
  """Returns steps taken since `_previous_epoch_iterations` per second.

  The step count is the optimizer's current iteration counter minus
  `self._previous_epoch_iterations`; the elapsed time is
  `self._train_accumulated_time`.
  """
  steps_taken = (
      self.model.optimizer.iterations.numpy() -
      self._previous_epoch_iterations)
  return steps_taken / self._train_accumulated_time
def _log_weight_as_image(self, weight, weight_name, epoch):
  """Logs a weight as a TensorBoard image.

  The weight is squeezed and then reshaped into a 4-D tensor according to the
  rank that remains (1, 2 or 3). The image summary is written only when the
  result is 4-D with a last dimension of 1, 3 or 4; otherwise nothing is
  logged.

  Args:
    weight: The weight to visualize.
    weight_name: Name under which the image summary is written.
    epoch: Step value attached to the summary.
  """
  image = array_ops.squeeze(weight)
  dims = backend.int_shape(image)
  rank = len(dims)

  if rank == 1:
    # Bias case.
    image = array_ops.reshape(image, [1, dims[0], 1, 1])
  elif rank == 2:
    # Dense layer kernel case: transpose when the first dimension is larger.
    if dims[0] > dims[1]:
      image = array_ops.transpose(image)
      dims = backend.int_shape(image)
    image = array_ops.reshape(image, [1, dims[0], dims[1], 1])
  elif rank == 3:
    # ConvNet case.
    if backend.image_data_format() == 'channels_last':
      # Switch to channels_first to display every kernel as a separate
      # image.
      image = array_ops.transpose(image, perm=[2, 0, 1])
      dims = backend.int_shape(image)
    image = array_ops.reshape(image, [dims[0], dims[1], dims[2], 1])

  final_dims = backend.int_shape(image)
  # Not possible to handle 3D convnets etc.
  if len(final_dims) != 4 or final_dims[-1] not in [1, 3, 4]:
    return
  summary_ops_v2.image(weight_name, image, step=epoch)
@keras_export('keras.callbacks.ReduceLROnPlateau')
class ReduceLROnPlateau(Callback):
  """Reduce learning rate when a metric has stopped improving.

  Models often benefit from reducing the learning rate by a factor
  of 2-10 once learning stagnates. This callback monitors a
  quantity and if no improvement is seen for a 'patience' number
  of epochs, the learning rate is reduced.

  Example:

  ```python
  reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                                patience=5, min_lr=0.001)
  model.fit(X_train, Y_train, callbacks=[reduce_lr])
  ```

  Args:
      monitor: quantity to be monitored.
      factor: factor by which the learning rate will be reduced.
        `new_lr = lr * factor`.
      patience: number of epochs with no improvement after which learning rate
        will be reduced.
      verbose: int. 0: quiet, 1: update messages.
      mode: one of `{'auto', 'min', 'max'}`. In `'min'` mode,
        the learning rate will be reduced when the
        quantity monitored has stopped decreasing; in `'max'` mode it will be
        reduced when the quantity monitored has stopped increasing; in `'auto'`
        mode, the direction is automatically inferred from the name of the
        monitored quantity.
      min_delta: threshold for measuring the new optimum, to only focus on
        significant changes.
      cooldown: number of epochs to wait before resuming normal operation after
        lr has been reduced.
      min_lr: lower bound on the learning rate.

  Raises:
      ValueError: If `factor` is greater than or equal to 1.0.
  """

  def __init__(self,
               monitor='val_loss',
               factor=0.1,
               patience=10,
               verbose=0,
               mode='auto',
               min_delta=1e-4,
               cooldown=0,
               min_lr=0,
               **kwargs):
    super(ReduceLROnPlateau, self).__init__()

    self.monitor = monitor
    if factor >= 1.0:
      raise ValueError('ReduceLROnPlateau does not support a factor >= 1.0.')
    if 'epsilon' in kwargs:
      # Deprecated alias of `min_delta`; it wins when both are supplied.
      min_delta = kwargs.pop('epsilon')
      logging.warning('`epsilon` argument is deprecated and '
                      'will be removed, use `min_delta` instead.')
    self.factor = factor
    self.min_lr = min_lr
    self.min_delta = min_delta
    self.patience = patience
    self.verbose = verbose
    self.cooldown = cooldown
    self.cooldown_counter = 0  # Cooldown counter.
    self.wait = 0
    self.best = 0
    self.mode = mode
    self.monitor_op = None
    self._reset()

  def _reset(self):
    """Resets wait counter and cooldown counter.

    Also re-derives `monitor_op` and the initial `best` value from `mode`,
    falling back to `'auto'` (with a warning) when `mode` is unknown.
    """
    if self.mode not in ['auto', 'min', 'max']:
      logging.warning('Learning rate reduction mode %s is unknown, '
                      'fallback to auto mode.', self.mode)
      self.mode = 'auto'
    if (self.mode == 'min' or
        (self.mode == 'auto' and 'acc' not in self.monitor)):
      self.monitor_op = lambda a, b: np.less(a, b - self.min_delta)
      # `np.inf` is the canonical spelling; the `np.Inf` alias no longer
      # exists in NumPy 2.0 and later.
      self.best = np.inf
    else:
      self.monitor_op = lambda a, b: np.greater(a, b + self.min_delta)
      self.best = -np.inf
    self.cooldown_counter = 0
    self.wait = 0

  def on_train_begin(self, logs=None):
    """Resets the tracking state so that an instance can be re-used."""
    self._reset()

  def on_epoch_end(self, epoch, logs=None):
    """Records the current lr in `logs` and reduces it after a plateau."""
    logs = logs or {}
    logs['lr'] = backend.get_value(self.model.optimizer.lr)
    current = logs.get(self.monitor)
    if current is None:
      logging.warning('Learning rate reduction is conditioned on metric `%s` '
                      'which is not available. Available metrics are: %s',
                      self.monitor, ','.join(list(logs.keys())))
      return

    if self.in_cooldown():
      self.cooldown_counter -= 1
      self.wait = 0

    if self.monitor_op(current, self.best):
      self.best = current
      self.wait = 0
    elif not self.in_cooldown():
      self.wait += 1
      if self.wait >= self.patience:
        old_lr = backend.get_value(self.model.optimizer.lr)
        # Once the lower bound is reached there is nothing left to reduce.
        if old_lr > np.float32(self.min_lr):
          new_lr = old_lr * self.factor
          new_lr = max(new_lr, self.min_lr)
          backend.set_value(self.model.optimizer.lr, new_lr)
          if self.verbose > 0:
            print('\nEpoch %05d: ReduceLROnPlateau reducing learning '
                  'rate to %s.' % (epoch + 1, new_lr))
          self.cooldown_counter = self.cooldown
          self.wait = 0

  def in_cooldown(self):
    """Returns True while the cooldown counter is still positive."""
    return self.cooldown_counter > 0
2725 2726 Supports all values that can be represented as a string, 2727 including 1D iterables such as `np.ndarray`. 2728 2729 Example: 2730 2731 ```python 2732 csv_logger = CSVLogger('training.log') 2733 model.fit(X_train, Y_train, callbacks=[csv_logger]) 2734 ``` 2735 2736 Args: 2737 filename: Filename of the CSV file, e.g. `'run/log.csv'`. 2738 separator: String used to separate elements in the CSV file. 2739 append: Boolean. True: append if file exists (useful for continuing 2740 training). False: overwrite existing file. 2741 """ 2742 2743 def __init__(self, filename, separator=',', append=False): 2744 self.sep = separator 2745 self.filename = path_to_string(filename) 2746 self.append = append 2747 self.writer = None 2748 self.keys = None 2749 self.append_header = True 2750 super(CSVLogger, self).__init__() 2751 2752 def on_train_begin(self, logs=None): 2753 if self.append: 2754 if file_io.file_exists_v2(self.filename): 2755 with gfile.GFile(self.filename, 'r') as f: 2756 self.append_header = not bool(len(f.readline())) 2757 mode = 'a' 2758 else: 2759 mode = 'w' 2760 self.csv_file = gfile.GFile(self.filename, mode) 2761 2762 def on_epoch_end(self, epoch, logs=None): 2763 logs = logs or {} 2764 2765 def handle_value(k): 2766 is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0 2767 if isinstance(k, str): 2768 return k 2769 elif isinstance(k, collections.abc.Iterable) and not is_zero_dim_ndarray: 2770 return '"[%s]"' % (', '.join(map(str, k))) 2771 else: 2772 return k 2773 2774 if self.keys is None: 2775 self.keys = sorted(logs.keys()) 2776 2777 if self.model.stop_training: 2778 # We set NA so that csv parsers do not fail for this last epoch. 
2779 logs = dict((k, logs[k]) if k in logs else (k, 'NA') for k in self.keys) 2780 2781 if not self.writer: 2782 2783 class CustomDialect(csv.excel): 2784 delimiter = self.sep 2785 2786 fieldnames = ['epoch'] + self.keys 2787 2788 self.writer = csv.DictWriter( 2789 self.csv_file, 2790 fieldnames=fieldnames, 2791 dialect=CustomDialect) 2792 if self.append_header: 2793 self.writer.writeheader() 2794 2795 row_dict = collections.OrderedDict({'epoch': epoch}) 2796 row_dict.update((key, handle_value(logs[key])) for key in self.keys) 2797 self.writer.writerow(row_dict) 2798 self.csv_file.flush() 2799 2800 def on_train_end(self, logs=None): 2801 self.csv_file.close() 2802 self.writer = None 2803 2804 2805@keras_export('keras.callbacks.LambdaCallback') 2806class LambdaCallback(Callback): 2807 r"""Callback for creating simple, custom callbacks on-the-fly. 2808 2809 This callback is constructed with anonymous functions that will be called 2810 at the appropriate time (during `Model.{fit | evaluate | predict}`). 2811 Note that the callbacks expects positional arguments, as: 2812 2813 - `on_epoch_begin` and `on_epoch_end` expect two positional arguments: 2814 `epoch`, `logs` 2815 - `on_batch_begin` and `on_batch_end` expect two positional arguments: 2816 `batch`, `logs` 2817 - `on_train_begin` and `on_train_end` expect one positional argument: 2818 `logs` 2819 2820 Args: 2821 on_epoch_begin: called at the beginning of every epoch. 2822 on_epoch_end: called at the end of every epoch. 2823 on_batch_begin: called at the beginning of every batch. 2824 on_batch_end: called at the end of every batch. 2825 on_train_begin: called at the beginning of model training. 2826 on_train_end: called at the end of model training. 2827 2828 Example: 2829 2830 ```python 2831 # Print the batch number at the beginning of every batch. 2832 batch_print_callback = LambdaCallback( 2833 on_batch_begin=lambda batch,logs: print(batch)) 2834 2835 # Stream the epoch loss to a file in JSON format. 
The file content 2836 # is not well-formed JSON but rather has a JSON object per line. 2837 import json 2838 json_log = open('loss_log.json', mode='wt', buffering=1) 2839 json_logging_callback = LambdaCallback( 2840 on_epoch_end=lambda epoch, logs: json_log.write( 2841 json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'), 2842 on_train_end=lambda logs: json_log.close() 2843 ) 2844 2845 # Terminate some processes after having finished model training. 2846 processes = ... 2847 cleanup_callback = LambdaCallback( 2848 on_train_end=lambda logs: [ 2849 p.terminate() for p in processes if p.is_alive()]) 2850 2851 model.fit(..., 2852 callbacks=[batch_print_callback, 2853 json_logging_callback, 2854 cleanup_callback]) 2855 ``` 2856 """ 2857 2858 def __init__(self, 2859 on_epoch_begin=None, 2860 on_epoch_end=None, 2861 on_batch_begin=None, 2862 on_batch_end=None, 2863 on_train_begin=None, 2864 on_train_end=None, 2865 **kwargs): 2866 super(LambdaCallback, self).__init__() 2867 self.__dict__.update(kwargs) 2868 if on_epoch_begin is not None: 2869 self.on_epoch_begin = on_epoch_begin 2870 else: 2871 self.on_epoch_begin = lambda epoch, logs: None 2872 if on_epoch_end is not None: 2873 self.on_epoch_end = on_epoch_end 2874 else: 2875 self.on_epoch_end = lambda epoch, logs: None 2876 if on_batch_begin is not None: 2877 self.on_batch_begin = on_batch_begin 2878 else: 2879 self.on_batch_begin = lambda batch, logs: None 2880 if on_batch_end is not None: 2881 self.on_batch_end = on_batch_end 2882 else: 2883 self.on_batch_end = lambda batch, logs: None 2884 if on_train_begin is not None: 2885 self.on_train_begin = on_train_begin 2886 else: 2887 self.on_train_begin = lambda logs: None 2888 if on_train_end is not None: 2889 self.on_train_end = on_train_end 2890 else: 2891 self.on_train_end = lambda logs: None 2892