# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-import-not-at-top
"""Callbacks: utilities called at certain points during model training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import csv
import io
import json
import os
import time

import numpy as np
import six

from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils.data_utils import Sequence
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export

try:
  import requests
except ImportError:
  requests = None


def configure_callbacks(callbacks,
                        model,
                        do_validation=False,
                        batch_size=None,
                        epochs=None,
                        steps_per_epoch=None,
                        samples=None,
                        verbose=1,
                        count_mode='steps',
                        mode=ModeKeys.TRAIN):
  """Configures callbacks for use in various training loops.

  Arguments:
      callbacks: List of Callbacks.
      model: Model being trained.
      do_validation: Whether or not validation loop will be run.
      batch_size: Number of samples per batch.
      epochs: Number of epochs to train.
      steps_per_epoch: Number of batches to run per training epoch.
      samples: Number of training samples.
      verbose: int, 0 or 1. Keras logging verbosity to pass to ProgbarLogger.
      count_mode: One of 'steps' or 'samples'. Per-batch or per-sample count.
      mode: String. One of ModeKeys.TRAIN, ModeKeys.TEST, or ModeKeys.PREDICT.
        Which loop mode to configure callbacks for.

  Returns:
      Instance of CallbackList used to control all Callbacks.
  """
  # Check if callbacks have already been configured.
  if isinstance(callbacks, CallbackList):
    return callbacks

  if not callbacks:
    callbacks = []

  # Add additional callbacks during training.
  if mode == ModeKeys.TRAIN:
    model.history = History()
    callbacks = [BaseLogger()] + (callbacks or []) + [model.history]
    if verbose:
      callbacks.append(ProgbarLogger(count_mode))
  callback_list = CallbackList(callbacks)

  # Set callback model
  callback_model = model._get_callback_model()  # pylint: disable=protected-access
  callback_list.set_model(callback_model)

  set_callback_parameters(
      callback_list,
      model,
      do_validation=do_validation,
      batch_size=batch_size,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      samples=samples,
      verbose=verbose,
      mode=mode)

  callback_list.model.stop_training = False
  return callback_list


def set_callback_parameters(callback_list,
                            model,
                            do_validation=False,
                            batch_size=None,
                            epochs=None,
                            steps_per_epoch=None,
                            samples=None,
                            verbose=1,
                            mode=ModeKeys.TRAIN):
  """Sets callback parameters.

  Arguments:
      callback_list: CallbackList instance.
      model: Model being trained.
      do_validation: Whether or not validation loop will be run.
      batch_size: Number of samples per batch.
      epochs: Number of epochs to train.
      steps_per_epoch: Number of batches to run per training epoch.
      samples: Number of training samples.
      verbose: int, 0 or 1. Keras logging verbosity to pass to ProgbarLogger.
      mode: String. One of ModeKeys.TRAIN, ModeKeys.TEST, or ModeKeys.PREDICT.
        Which loop mode to configure callbacks for.
  """
  for cbk in callback_list:
    if isinstance(cbk, (BaseLogger, ProgbarLogger)):
      cbk.stateful_metrics = model.metrics_names[1:]  # Exclude `loss`

  # Set callback parameters
  callback_metrics = []
  # When we have deferred build scenario with iterator input, we will compile
  # when we standardize first batch of data.
  if mode != ModeKeys.PREDICT and hasattr(model, 'metrics_names'):
    callback_metrics = copy.copy(model.metrics_names)
    if do_validation:
      callback_metrics += ['val_' + n for n in model.metrics_names]
  callback_params = {
      'batch_size': batch_size,
      'epochs': epochs,
      'steps': steps_per_epoch,
      'samples': samples,
      'verbose': verbose,
      'do_validation': do_validation,
      'metrics': callback_metrics,
  }
  callback_list.set_params(callback_params)


def _is_generator_like(data):
  """Checks if data is a generator, Sequence, or Iterator."""
  return (hasattr(data, 'next') or hasattr(data, '__next__') or isinstance(
      data, (Sequence, iterator_ops.Iterator, iterator_ops.EagerIterator)))


def make_logs(model, logs, outputs, mode, prefix=''):
  """Computes logs for sending to `on_batch_end` methods."""
  if mode in {ModeKeys.TRAIN, ModeKeys.TEST}:
    if hasattr(model, 'metrics_names'):
      for label, output in zip(model.metrics_names, outputs):
        logs[prefix + label] = output
  else:
    logs['outputs'] = outputs
  return logs


class CallbackList(object):
  """Container abstracting a list of callbacks.

  Arguments:
      callbacks: List of `Callback` instances.
      queue_length: Queue length for keeping
          running statistics over callback execution time.
  """

  def __init__(self, callbacks=None, queue_length=10):
    callbacks = callbacks or []
    self.callbacks = [c for c in callbacks]
    self.queue_length = queue_length
    self.params = {}
    self.model = None
    self._reset_batch_timing()

  def _reset_batch_timing(self):
    self._delta_t_batch = 0.
    self._delta_ts = collections.defaultdict(
        lambda: collections.deque([], maxlen=self.queue_length))

  def append(self, callback):
    self.callbacks.append(callback)

  def set_params(self, params):
    self.params = params
    for callback in self.callbacks:
      callback.set_params(params)

  def set_model(self, model):
    self.model = model
    for callback in self.callbacks:
      callback.set_model(model)

  def _call_batch_hook(self, mode, hook, batch, logs=None):
    """Helper function for all batch_{begin | end} methods."""
    if not self.callbacks:
      return
    hook_name = 'on_{mode}_batch_{hook}'.format(mode=mode, hook=hook)
    if hook == 'begin':
      self._t_enter_batch = time.time()
    if hook == 'end':
      # Batch is ending, calculate batch time.
      self._delta_t_batch = time.time() - self._t_enter_batch

    logs = logs or {}
    t_before_callbacks = time.time()
    for callback in self.callbacks:
      batch_hook = getattr(callback, hook_name)
      batch_hook(batch, logs)
    self._delta_ts[hook_name].append(time.time() - t_before_callbacks)

    delta_t_median = np.median(self._delta_ts[hook_name])
    if (self._delta_t_batch > 0. and
        delta_t_median > 0.95 * self._delta_t_batch and delta_t_median > 0.1):
      logging.warning(
          'Method (%s) is slow compared '
          'to the batch update (%f). Check your callbacks.', hook_name,
          delta_t_median)

  def _call_begin_hook(self, mode):
    """Helper function for on_{train|test|predict}_begin methods."""
    if mode == ModeKeys.TRAIN:
      self.on_train_begin()
    elif mode == ModeKeys.TEST:
      self.on_test_begin()
    else:
      self.on_predict_begin()

  def _call_end_hook(self, mode):
    """Helper function for on_{train|test|predict}_end methods."""
    if mode == ModeKeys.TRAIN:
      self.on_train_end()
    elif mode == ModeKeys.TEST:
      self.on_test_end()
    else:
      self.on_predict_end()

  def on_batch_begin(self, batch, logs=None):
    self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs)

  def on_batch_end(self, batch, logs=None):
    self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)

  def on_epoch_begin(self, epoch, logs=None):
    """Calls the `on_epoch_begin` methods of its callbacks.

    This function should only be called during TRAIN mode.

    Arguments:
        epoch: integer, index of epoch.
        logs: dict. Currently no data is passed to this argument for this
          method but that may change in the future.
    """
    logs = logs or {}
    for callback in self.callbacks:
      callback.on_epoch_begin(epoch, logs)
    self._reset_batch_timing()

  def on_epoch_end(self, epoch, logs=None):
    """Calls the `on_epoch_end` methods of its callbacks.

    This function should only be called during TRAIN mode.

    Arguments:
        epoch: integer, index of epoch.
        logs: dict, metric results for this training epoch, and for the
          validation epoch if validation is performed. Validation result keys
          are prefixed with `val_`.
    """
    logs = logs or {}
    for callback in self.callbacks:
      callback.on_epoch_end(epoch, logs)

  def on_train_batch_begin(self, batch, logs=None):
    """Calls the `on_train_batch_begin` methods of its callbacks.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Has keys `batch` and `size` representing the current batch
          number and the size of the batch.
    """
    self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs)

  def on_train_batch_end(self, batch, logs=None):
    """Calls the `on_train_batch_end` methods of its callbacks.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Metric results for this batch.
    """
    self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)

  def on_test_batch_begin(self, batch, logs=None):
    """Calls the `on_test_batch_begin` methods of its callbacks.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Has keys `batch` and `size` representing the current batch
          number and the size of the batch.
    """
    self._call_batch_hook(ModeKeys.TEST, 'begin', batch, logs=logs)

  def on_test_batch_end(self, batch, logs=None):
    """Calls the `on_test_batch_end` methods of its callbacks.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Metric results for this batch.
    """
    self._call_batch_hook(ModeKeys.TEST, 'end', batch, logs=logs)

  def on_predict_batch_begin(self, batch, logs=None):
    """Calls the `on_predict_batch_begin` methods of its callbacks.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Has keys `batch` and `size` representing the current batch
          number and the size of the batch.
    """
    self._call_batch_hook(ModeKeys.PREDICT, 'begin', batch, logs=logs)

  def on_predict_batch_end(self, batch, logs=None):
    """Calls the `on_predict_batch_end` methods of its callbacks.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Metric results for this batch.
    """
    self._call_batch_hook(ModeKeys.PREDICT, 'end', batch, logs=logs)

  def on_train_begin(self, logs=None):
    """Calls the `on_train_begin` methods of its callbacks.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this
          method but that may change in the future.
    """
    for callback in self.callbacks:
      callback.on_train_begin(logs)

  def on_train_end(self, logs=None):
    """Calls the `on_train_end` methods of its callbacks.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this
          method but that may change in the future.
    """
    for callback in self.callbacks:
      callback.on_train_end(logs)

  def on_test_begin(self, logs=None):
    """Calls the `on_test_begin` methods of its callbacks.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this
          method but that may change in the future.
    """
    for callback in self.callbacks:
      callback.on_test_begin(logs)

  def on_test_end(self, logs=None):
    """Calls the `on_test_end` methods of its callbacks.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this
          method but that may change in the future.
    """
    for callback in self.callbacks:
      callback.on_test_end(logs)

  def on_predict_begin(self, logs=None):
    """Calls the `on_predict_begin` methods of its callbacks.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this
          method but that may change in the future.
    """
    for callback in self.callbacks:
      callback.on_predict_begin(logs)

  def on_predict_end(self, logs=None):
    """Calls the `on_predict_end` methods of its callbacks.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this
          method but that may change in the future.
    """
    for callback in self.callbacks:
      callback.on_predict_end(logs)

  def __iter__(self):
    return iter(self.callbacks)


@keras_export('keras.callbacks.Callback')
class Callback(object):
  """Abstract base class used to build new callbacks.

  Attributes:
      params: dict. Training parameters
          (e.g. verbosity, batch size, number of epochs...).
      model: instance of `keras.models.Model`.
          Reference of the model being trained.

  The `logs` dictionary that callback methods
  take as argument will contain keys for quantities relevant to
  the current batch or epoch.

  Currently, the `.fit()` method of the `Model` class
  will include the following quantities in the `logs` that
  it passes to its callbacks:

  on_epoch_end: logs include `acc` and `loss`, and
      optionally include `val_loss`
      (if validation is enabled in `fit`), and `val_acc`
      (if validation and accuracy monitoring are enabled).
  on_batch_begin: logs include `size`,
      the number of samples in the current batch.
  on_batch_end: logs include `loss`, and optionally `acc`
      (if accuracy monitoring is enabled).
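
  Example of a minimal custom callback (an illustrative sketch only; the
  `LossHistory` name and the recorded quantity are arbitrary):

  ```python
  class LossHistory(Callback):

    def on_train_begin(self, logs=None):
      # Will hold one loss value per trained batch.
      self.batch_losses = []

    def on_batch_end(self, batch, logs=None):
      logs = logs or {}
      self.batch_losses.append(logs.get('loss'))
  ```

  Passing an instance of such a subclass in the `callbacks` argument of
  `model.fit` makes the collected values available after training.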
  """

  def __init__(self):
    self.validation_data = None
    self.model = None
    # Whether this Callback should only run on the chief worker in a
    # Multi-Worker setting.
    # TODO(omalleyt): Make this attr public once solution is stable.
    self._chief_worker_only = None

  def set_params(self, params):
    self.params = params

  def set_model(self, model):
    self.model = model

  def on_batch_begin(self, batch, logs=None):
    """A backwards compatibility alias for `on_train_batch_begin`."""

  def on_batch_end(self, batch, logs=None):
    """A backwards compatibility alias for `on_train_batch_end`."""

  def on_epoch_begin(self, epoch, logs=None):
    """Called at the start of an epoch.

    Subclasses should override for any actions to run. This function should
    only be called during TRAIN mode.

    Arguments:
        epoch: integer, index of epoch.
        logs: dict. Currently no data is passed to this argument for this
          method but that may change in the future.
    """

  def on_epoch_end(self, epoch, logs=None):
    """Called at the end of an epoch.

    Subclasses should override for any actions to run. This function should
    only be called during TRAIN mode.

    Arguments:
        epoch: integer, index of epoch.
        logs: dict, metric results for this training epoch, and for the
          validation epoch if validation is performed. Validation result keys
          are prefixed with `val_`.
    """

  def on_train_batch_begin(self, batch, logs=None):
    """Called at the beginning of a training batch in `fit` methods.

    Subclasses should override for any actions to run.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Has keys `batch` and `size` representing the current batch
          number and the size of the batch.
    """
    # For backwards compatibility.
    self.on_batch_begin(batch, logs=logs)

  def on_train_batch_end(self, batch, logs=None):
    """Called at the end of a training batch in `fit` methods.

    Subclasses should override for any actions to run.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Metric results for this batch.
    """
    # For backwards compatibility.
    self.on_batch_end(batch, logs=logs)

  def on_test_batch_begin(self, batch, logs=None):
    """Called at the beginning of a batch in `evaluate` methods.

    Also called at the beginning of a validation batch in the `fit`
    methods, if validation data is provided.

    Subclasses should override for any actions to run.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Has keys `batch` and `size` representing the current batch
          number and the size of the batch.
    """

  def on_test_batch_end(self, batch, logs=None):
    """Called at the end of a batch in `evaluate` methods.

    Also called at the end of a validation batch in the `fit`
    methods, if validation data is provided.

    Subclasses should override for any actions to run.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Metric results for this batch.
    """

  def on_predict_batch_begin(self, batch, logs=None):
    """Called at the beginning of a batch in `predict` methods.

    Subclasses should override for any actions to run.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Has keys `batch` and `size` representing the current batch
          number and the size of the batch.
    """

  def on_predict_batch_end(self, batch, logs=None):
    """Called at the end of a batch in `predict` methods.

    Subclasses should override for any actions to run.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Metric results for this batch.
    """

  def on_train_begin(self, logs=None):
    """Called at the beginning of training.

    Subclasses should override for any actions to run.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this
          method but that may change in the future.
    """

  def on_train_end(self, logs=None):
    """Called at the end of training.

    Subclasses should override for any actions to run.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this
          method but that may change in the future.
    """

  def on_test_begin(self, logs=None):
    """Called at the beginning of evaluation or validation.

    Subclasses should override for any actions to run.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this
          method but that may change in the future.
    """

  def on_test_end(self, logs=None):
    """Called at the end of evaluation or validation.

    Subclasses should override for any actions to run.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this
          method but that may change in the future.
    """

  def on_predict_begin(self, logs=None):
    """Called at the beginning of prediction.

    Subclasses should override for any actions to run.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this
          method but that may change in the future.
    """

  def on_predict_end(self, logs=None):
    """Called at the end of prediction.

    Subclasses should override for any actions to run.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this
          method but that may change in the future.
    """


@keras_export('keras.callbacks.BaseLogger')
class BaseLogger(Callback):
  """Callback that accumulates epoch averages of metrics.

  This callback is automatically applied to every Keras model.

  Arguments:
      stateful_metrics: Iterable of string names of metrics that
          should *not* be averaged over an epoch.
          Metrics in this list will be logged as-is in `on_epoch_end`.
          All others will be averaged in `on_epoch_end`.
  """

  def __init__(self, stateful_metrics=None):
    super(BaseLogger, self).__init__()
    self.stateful_metrics = set(stateful_metrics or [])

  def on_epoch_begin(self, epoch, logs=None):
    self.seen = 0
    self.totals = {}

  def on_batch_end(self, batch, logs=None):
    logs = logs or {}
    batch_size = logs.get('size', 0)
    # In case of distribution strategy we can potentially run multiple steps
    # at the same time, we should account for that in the `seen` calculation.
    num_steps = logs.get('num_steps', 1)
    self.seen += batch_size * num_steps

    for k, v in logs.items():
      if k in self.stateful_metrics:
        self.totals[k] = v
      else:
        if k in self.totals:
          self.totals[k] += v * batch_size
        else:
          self.totals[k] = v * batch_size

  def on_epoch_end(self, epoch, logs=None):
    if logs is not None:
      for k in self.params['metrics']:
        if k in self.totals:
          # Make value available to next callbacks.
          if k in self.stateful_metrics:
            logs[k] = self.totals[k]
          else:
            logs[k] = self.totals[k] / self.seen


@keras_export('keras.callbacks.TerminateOnNaN')
class TerminateOnNaN(Callback):
  """Callback that terminates training when a NaN loss is encountered."""

  def on_batch_end(self, batch, logs=None):
    logs = logs or {}
    loss = logs.get('loss')
    if loss is not None:
      if np.isnan(loss) or np.isinf(loss):
        print('Batch %d: Invalid loss, terminating training' % (batch))
        self.model.stop_training = True


@keras_export('keras.callbacks.ProgbarLogger')
class ProgbarLogger(Callback):
  """Callback that prints metrics to stdout.

  Arguments:
      count_mode: One of "steps" or "samples".
          Whether the progress bar should
          count samples seen or steps (batches) seen.
      stateful_metrics: Iterable of string names of metrics that
          should *not* be averaged over an epoch.
          Metrics in this list will be logged as-is.
          All others will be averaged over time (e.g. loss, etc).

  Raises:
      ValueError: In case of invalid `count_mode`.
  """

  def __init__(self, count_mode='samples', stateful_metrics=None):
    super(ProgbarLogger, self).__init__()
    if count_mode == 'samples':
      self.use_steps = False
    elif count_mode == 'steps':
      self.use_steps = True
    else:
      raise ValueError('Unknown `count_mode`: ' + str(count_mode))
    self.stateful_metrics = set(stateful_metrics or [])

  def on_train_begin(self, logs=None):
    self.verbose = self.params['verbose']
    self.epochs = self.params['epochs']

  def on_epoch_begin(self, epoch, logs=None):
    self.seen = 0
    if self.use_steps:
      self.target = self.params['steps']
    else:
      self.target = self.params['samples']

    if self.verbose:
      if self.epochs > 1:
        print('Epoch %d/%d' % (epoch + 1, self.epochs))
    self.progbar = Progbar(
        target=self.target,
        verbose=self.verbose,
        stateful_metrics=self.stateful_metrics,
        unit_name='step' if self.use_steps else 'sample')

  def on_batch_begin(self, batch, logs=None):
    self.log_values = []

  def on_batch_end(self, batch, logs=None):
    logs = logs or {}
    batch_size = logs.get('size', 0)
    # In case of distribution strategy we can potentially run multiple steps
    # at the same time, we should account for that in the `seen` calculation.
    num_steps = logs.get('num_steps', 1)
    if self.use_steps:
      self.seen += num_steps
    else:
      self.seen += batch_size * num_steps

    for k in self.params['metrics']:
      if k in logs:
        self.log_values.append((k, logs[k]))

    # Skip progbar update for the last batch;
    # will be handled by on_epoch_end.
    if self.verbose and (self.target is None or self.seen < self.target):
      self.progbar.update(self.seen, self.log_values)

  def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}
    for k in self.params['metrics']:
      if k in logs:
        self.log_values.append((k, logs[k]))
    if self.verbose:
      self.progbar.update(self.seen, self.log_values)


@keras_export('keras.callbacks.History')
class History(Callback):
  """Callback that records events into a `History` object.

  This callback is automatically applied to
  every Keras model. The `History` object
  gets returned by the `fit` method of models.
  """

  def on_train_begin(self, logs=None):
    self.epoch = []
    self.history = {}

  def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}
    self.epoch.append(epoch)
    for k, v in logs.items():
      self.history.setdefault(k, []).append(v)


@keras_export('keras.callbacks.ModelCheckpoint')
class ModelCheckpoint(Callback):
  """Save the model after every epoch.

  `filepath` can contain named formatting options,
  which will be filled with the value of `epoch` and
  keys in `logs` (passed in `on_epoch_end`).

  For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
  then the model checkpoints will be saved with the epoch number and
  the validation loss in the filename.

  Arguments:
      filepath: string, path to save the model file.
      monitor: quantity to monitor.
      verbose: verbosity mode, 0 or 1.
      save_best_only: if `save_best_only=True`,
          the latest best model according to
          the quantity monitored will not be overwritten.
      mode: one of {auto, min, max}.
          If `save_best_only=True`, the decision
          to overwrite the current save file is made
          based on either the maximization or the
          minimization of the monitored quantity. For `val_acc`,
          this should be `max`, for `val_loss` this should
          be `min`, etc. In `auto` mode, the direction is
          automatically inferred from the name of the monitored quantity.
      save_weights_only: if True, then only the model's weights will be
          saved (`model.save_weights(filepath)`), else the full model
          is saved (`model.save(filepath)`).
      period: Interval (number of epochs) between checkpoints.
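
  Example (an illustrative sketch; `model`, `x_train` and `y_train` are
  assumed to be defined elsewhere, and `val_loss` is available because a
  validation split is passed to `fit`):

  ```python
  checkpoint = ModelCheckpoint('weights.{epoch:02d}-{val_loss:.2f}.hdf5',
                               monitor='val_loss',
                               save_best_only=True)
  # After each epoch, a checkpoint is written only if `val_loss` improved.
  model.fit(x_train, y_train, validation_split=0.2, epochs=10,
            callbacks=[checkpoint])
  ```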
  """

  def __init__(self,
               filepath,
               monitor='val_loss',
               verbose=0,
               save_best_only=False,
               save_weights_only=False,
               mode='auto',
               period=1):
    super(ModelCheckpoint, self).__init__()
    self.monitor = monitor
    self.verbose = verbose
    self.filepath = filepath
    self.save_best_only = save_best_only
    self.save_weights_only = save_weights_only
    self.period = period
    self.epochs_since_last_save = 0

    if mode not in ['auto', 'min', 'max']:
      logging.warning('ModelCheckpoint mode %s is unknown, '
                      'fallback to auto mode.', mode)
      mode = 'auto'

    if mode == 'min':
      self.monitor_op = np.less
      self.best = np.Inf
    elif mode == 'max':
      self.monitor_op = np.greater
      self.best = -np.Inf
    else:
      if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
        self.monitor_op = np.greater
        self.best = -np.Inf
      else:
        self.monitor_op = np.less
        self.best = np.Inf

    # Only the chief worker writes model checkpoints.
    self._chief_worker_only = True

  def set_model(self, model):
    self.model = model
    # Use name matching rather than `isinstance` to avoid circular dependencies.
    if (not self.save_weights_only and
        not model._is_graph_network and  # pylint: disable=protected-access
        model.__class__.__name__ != 'Sequential'):
      self.save_weights_only = True

  def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}
    self.epochs_since_last_save += 1
    if self.epochs_since_last_save >= self.period:
      self.epochs_since_last_save = 0
      filepath = self.filepath.format(epoch=epoch + 1, **logs)
      if self.save_best_only:
        current = logs.get(self.monitor)
        if current is None:
          logging.warning('Can save best model only with %s available, '
                          'skipping.', self.monitor)
        else:
          if self.monitor_op(current, self.best):
            if self.verbose > 0:
              print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                    ' saving model to %s' % (epoch + 1, self.monitor, self.best,
                                             current, filepath))
            self.best = current
            if self.save_weights_only:
              self.model.save_weights(filepath, overwrite=True)
            else:
              self.model.save(filepath, overwrite=True)
          else:
            if self.verbose > 0:
              print('\nEpoch %05d: %s did not improve from %0.5f' %
                    (epoch + 1, self.monitor, self.best))
      else:
        if self.verbose > 0:
          print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
        if self.save_weights_only:
          self.model.save_weights(filepath, overwrite=True)
        else:
          self.model.save(filepath, overwrite=True)


@keras_export('keras.callbacks.EarlyStopping')
class EarlyStopping(Callback):
  """Stop training when a monitored quantity has stopped improving.

  Arguments:
      monitor: Quantity to be monitored.
      min_delta: Minimum change in the monitored quantity
          to qualify as an improvement, i.e. an absolute
          change of less than min_delta will count as no
          improvement.
      patience: Number of epochs with no improvement
          after which training will be stopped.
      verbose: verbosity mode.
      mode: One of `{"auto", "min", "max"}`. In `min` mode,
          training will stop when the quantity
          monitored has stopped decreasing; in `max`
          mode it will stop when the quantity
          monitored has stopped increasing; in `auto`
          mode, the direction is automatically inferred
          from the name of the monitored quantity.
      baseline: Baseline value for the monitored quantity.
          Training will stop if the model doesn't show improvement over the
          baseline.
      restore_best_weights: Whether to restore model weights from
          the epoch with the best value of the monitored quantity.
          If False, the model weights obtained at the last step of
          training are used.
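
  Example (an illustrative sketch; `model`, `x_train` and `y_train` are
  assumed to be defined elsewhere):

  ```python
  early_stopping = EarlyStopping(monitor='val_loss', patience=3)
  # Training stops once `val_loss` has failed to improve for 3 consecutive
  # epochs.
  model.fit(x_train, y_train, validation_split=0.2, epochs=100,
            callbacks=[early_stopping])
  ```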
  """

  def __init__(self,
               monitor='val_loss',
               min_delta=0,
               patience=0,
               verbose=0,
               mode='auto',
               baseline=None,
               restore_best_weights=False):
    super(EarlyStopping, self).__init__()

    self.monitor = monitor
    self.patience = patience
    self.verbose = verbose
    self.baseline = baseline
    self.min_delta = abs(min_delta)
    self.wait = 0
    self.stopped_epoch = 0
    self.restore_best_weights = restore_best_weights
    self.best_weights = None

    if mode not in ['auto', 'min', 'max']:
      logging.warning('EarlyStopping mode %s is unknown, '
                      'fallback to auto mode.', mode)
      mode = 'auto'

    if mode == 'min':
      self.monitor_op = np.less
    elif mode == 'max':
      self.monitor_op = np.greater
    else:
      if 'acc' in self.monitor:
        self.monitor_op = np.greater
      else:
        self.monitor_op = np.less

    if self.monitor_op == np.greater:
      self.min_delta *= 1
    else:
      self.min_delta *= -1

  def on_train_begin(self, logs=None):
    # Allow instances to be re-used
    self.wait = 0
    self.stopped_epoch = 0
    if self.baseline is not None:
      self.best = self.baseline
    else:
      self.best = np.Inf if self.monitor_op == np.less else -np.Inf

  def on_epoch_end(self, epoch, logs=None):
    current = self.get_monitor_value(logs)
    if current is None:
      return
    if self.monitor_op(current - self.min_delta, self.best):
      self.best = current
      self.wait = 0
      if self.restore_best_weights:
        self.best_weights = self.model.get_weights()
    else:
      self.wait += 1
      if self.wait >= self.patience:
        self.stopped_epoch = epoch
        self.model.stop_training = True
        if self.restore_best_weights:
          if self.verbose > 0:
            print('Restoring model weights from the end of the best epoch.')
          self.model.set_weights(self.best_weights)

  def on_train_end(self, logs=None):
    if self.stopped_epoch > 0 and self.verbose > 0:
      print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))

  def get_monitor_value(self, logs):
    logs = logs or {}
    monitor_value = logs.get(self.monitor)
    if monitor_value is None:
      logging.warning('Early stopping conditioned on metric `%s` '
                      'which is not available. Available metrics are: %s',
                      self.monitor, ','.join(list(logs.keys())))
    return monitor_value


@keras_export('keras.callbacks.RemoteMonitor')
class RemoteMonitor(Callback):
  """Callback used to stream events to a server.

  Requires the `requests` library.
  Events are sent to `root + '/publish/epoch/end/'` by default. Calls are
  HTTP POST, with a `data` argument which is a
  JSON-encoded dictionary of event data.
  If send_as_json is set to True, the content type of the request will be
  application/json. Otherwise the serialized JSON will be sent within a form.

  Arguments:
      root: String; root url of the target server.
      path: String; path relative to `root` to which the events will be sent.
      field: String; JSON field under which the data will be stored.
          The field is used only if the payload is sent within a form
          (i.e. send_as_json is set to False).
      headers: Dictionary; optional custom HTTP headers.
      send_as_json: Boolean; whether the request should be
          sent as application/json.
  """

  def __init__(self,
               root='http://localhost:9000',
               path='/publish/epoch/end/',
               field='data',
               headers=None,
               send_as_json=False):
    super(RemoteMonitor, self).__init__()

    self.root = root
    self.path = path
    self.field = field
    self.headers = headers
    self.send_as_json = send_as_json

  def on_epoch_end(self, epoch, logs=None):
    if requests is None:
      raise ImportError('RemoteMonitor requires the `requests` library.')
    logs = logs or {}
    send = {}
    send['epoch'] = epoch
    for k, v in logs.items():
      send[k] = v
    try:
      if self.send_as_json:
        requests.post(self.root + self.path, json=send, headers=self.headers)
      else:
        requests.post(
            self.root + self.path, {self.field: json.dumps(send)},
            headers=self.headers)
    except requests.exceptions.RequestException:
      logging.warning('Warning: could not reach RemoteMonitor '
                      'root server at ' + str(self.root))


@keras_export('keras.callbacks.LearningRateScheduler')
class LearningRateScheduler(Callback):
  """Learning rate scheduler.

  Arguments:
      schedule: a function that takes an epoch index as input
          (integer, indexed from 0) and returns a new
          learning rate as output (float).
      verbose: int. 0: quiet, 1: update messages.
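
  Example (an illustrative sketch; the decay policy and constants are
  arbitrary, and `model`, `x_train` and `y_train` are assumed to be defined
  elsewhere):

  ```python
  def schedule(epoch):
    # Keep the initial rate for the first 10 epochs, then halve it each epoch.
    if epoch < 10:
      return 0.01
    return 0.01 * (0.5 ** (epoch - 10))

  lr_scheduler = LearningRateScheduler(schedule, verbose=1)
  model.fit(x_train, y_train, epochs=20, callbacks=[lr_scheduler])
  ```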
  """

  def __init__(self, schedule, verbose=0):
    super(LearningRateScheduler, self).__init__()
    self.schedule = schedule
    self.verbose = verbose

  def on_epoch_begin(self, epoch, logs=None):
    if not hasattr(self.model.optimizer, 'lr'):
      raise ValueError('Optimizer must have a "lr" attribute.')
    try:  # new API
      lr = float(K.get_value(self.model.optimizer.lr))
      lr = self.schedule(epoch, lr)
    except TypeError:  # Support for old API for backward compatibility
      lr = self.schedule(epoch)
    if not isinstance(lr, (float, np.float32, np.float64)):
      raise ValueError('The output of the "schedule" function '
                       'should be float.')
    K.set_value(self.model.optimizer.lr, lr)
    if self.verbose > 0:
      print('\nEpoch %05d: LearningRateScheduler reducing learning '
            'rate to %s.' % (epoch + 1, lr))

  def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}
    logs['lr'] = K.get_value(self.model.optimizer.lr)


@keras_export('keras.callbacks.TensorBoard', v1=[])
class TensorBoard(Callback):
  # pylint: disable=line-too-long
  """Enable visualizations for TensorBoard.

  TensorBoard is a visualization tool provided with TensorFlow.

  This callback logs events for TensorBoard, including:
  * Metrics summary plots
  * Training graph visualization
  * Activation histograms
  * Sampled profiling

  If you have installed TensorFlow with pip, you should be able
  to launch TensorBoard from the command line:

  ```sh
  tensorboard --logdir=path_to_your_logs
  ```

  You can find more information about TensorBoard
  [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).

  Arguments:
      log_dir: the path of the directory where to save the log files to be
        parsed by TensorBoard.
      histogram_freq: frequency (in epochs) at which to compute activation and
        weight histograms for the layers of the model. If set to 0, histograms
        won't be computed. Validation data (or split) must be specified for
        histogram visualizations.
      write_graph: whether to visualize the graph in TensorBoard. The log file
        can become quite large when write_graph is set to True.
      write_images: whether to write model weights to visualize as image in
        TensorBoard.
      update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`,
        writes the losses and metrics to TensorBoard after each batch. The same
        applies for `'epoch'`. If using an integer, let's say `1000`, the
        callback will write the metrics and losses to TensorBoard every 1000
        samples. Note that writing too frequently to TensorBoard can slow down
        your training.
      profile_batch: Profile the batch to sample compute characteristics. By
        default, it will profile the second batch. Set profile_batch=0 to
        disable profiling. Must run in TensorFlow eager mode.

  Raises:
      ValueError: If histogram_freq is set and no validation data is provided.
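
  Basic usage (an illustrative sketch; the log directory name is arbitrary and
  `model`, `x_train` and `y_train` are assumed to be defined elsewhere):

  ```python
  tensorboard = TensorBoard(log_dir='./logs', update_freq='epoch')
  model.fit(x_train, y_train, epochs=5, callbacks=[tensorboard])
  # Afterwards, run `tensorboard --logdir=./logs` to inspect the summaries.
  ```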
  """

  # pylint: enable=line-too-long

  def __init__(self,
               log_dir='logs',
               histogram_freq=0,
               write_graph=True,
               write_images=False,
               update_freq='epoch',
               profile_batch=2,
               **kwargs):
    super(TensorBoard, self).__init__()
    self._validate_kwargs(kwargs)

    self.log_dir = log_dir
    self.histogram_freq = histogram_freq
    self.write_graph = write_graph
    self.write_images = write_images
    if update_freq == 'batch':
      self.update_freq = 1
    else:
      self.update_freq = update_freq

    self._samples_seen = 0
    self._samples_seen_at_last_write = 0
    self._current_batch = 0
    self._total_batches_seen = 0
    self._total_val_batches_seen = 0

    # A collection of file writers currently in use, to be closed when
    # training ends for this callback. Writers are keyed by the
    # directory name under the root logdir: e.g., "train" or
    # "validation".
    self._writers = {}
    self._train_run_name = 'train'
    self._validation_run_name = 'validation'

    self._profile_batch = profile_batch
    # True when a trace is running.
    self._is_tracing = False

    # TensorBoard should only write summaries on the chief when in a
    # Multi-Worker setting.
    self._chief_worker_only = True

  def _validate_kwargs(self, kwargs):
    """Handles arguments that were supported in V1."""
    if kwargs.get('write_grads', False):
      logging.warning('`write_grads` will be ignored in TensorFlow 2.0 '
                      'for the `TensorBoard` Callback.')
    if kwargs.get('embeddings_freq', False):
      logging.warning('Embeddings will be ignored in TensorFlow 2.0 '
                      'for the `TensorBoard` Callback.')
    if kwargs.get('batch_size', False):
      logging.warning('`batch_size` is no longer needed in the '
                      '`TensorBoard` Callback and will be ignored '
                      'in TensorFlow 2.0.')

    unrecognized_kwargs = set(kwargs.keys()) - {
        'write_grads', 'embeddings_freq', 'embeddings_layer_names',
        'embeddings_metadata', 'embeddings_data', 'batch_size'
    }

    # Only allow kwargs that were supported in V1.
    if unrecognized_kwargs:
      raise ValueError('Unrecognized arguments in `TensorBoard` '
                       'Callback: ' + str(unrecognized_kwargs))

  def set_model(self, model):
    """Sets Keras model and writes graph if specified."""
    self.model = model
    with context.eager_mode():
      self._close_writers()
      if self.write_graph:
        with self._get_writer(self._train_run_name).as_default():
          with summary_ops_v2.always_record_summaries():
            if not model.run_eagerly:
              summary_ops_v2.graph(K.get_graph())

            summary_writable = (
                self.model._is_graph_network or  # pylint: disable=protected-access
                self.model.__class__.__name__ == 'Sequential')  # pylint: disable=protected-access
            if summary_writable:
              summary_ops_v2.keras_model('keras', self.model, step=0)

  def _close_writers(self):
    """Close all remaining open file writers owned by this callback.

    If there are no such file writers, this is a no-op.
    """
    with context.eager_mode():
      for writer in six.itervalues(self._writers):
        writer.close()
      self._writers.clear()

  def _get_writer(self, writer_name):
    """Get a summary writer for the given subdirectory under the logdir.

    A writer will be created if it does not yet exist.

    Args:
        writer_name: The name of the directory for which to create or
          retrieve a writer. Should be either `self._train_run_name` or
          `self._validation_run_name`.

    Returns:
        A `SummaryWriter` object.
    """
    if writer_name not in self._writers:
      path = os.path.join(self.log_dir, writer_name)
      writer = summary_ops_v2.create_file_writer_v2(path)
      self._writers[writer_name] = writer
    return self._writers[writer_name]

  def on_train_begin(self, logs=None):
    if self._profile_batch == 1:
      summary_ops_v2.trace_on(graph=True, profiler=True)
      self._is_tracing = True

  def on_batch_end(self, batch, logs=None):
    """Writes scalar summaries for metrics on every training batch.

    Performs profiling if the current batch is the batch to profile.
    """
    # Don't output batch_size and batch number as TensorBoard summaries
    logs = logs or {}
    self._samples_seen += logs.get('size', 1)
    samples_seen_since = self._samples_seen - self._samples_seen_at_last_write
    if self.update_freq != 'epoch' and samples_seen_since >= self.update_freq:
      self._log_metrics(logs, prefix='batch_', step=self._total_batches_seen)
      self._samples_seen_at_last_write = self._samples_seen
    self._total_batches_seen += 1
    if self._is_tracing:
      self._log_trace()
    elif (not self._is_tracing and
          self._total_batches_seen == self._profile_batch - 1):
      self._enable_trace()

  def on_epoch_end(self, epoch, logs=None):
    """Runs metrics and histogram summaries at epoch end."""
    step = epoch if self.update_freq == 'epoch' else self._samples_seen
    self._log_metrics(logs, prefix='epoch_', step=step)

    if self.histogram_freq and epoch % self.histogram_freq == 0:
      self._log_weights(epoch)

  def on_train_end(self, logs=None):
    if self._is_tracing:
      self._log_trace()
    self._close_writers()

  def _enable_trace(self):
    if context.executing_eagerly():
      summary_ops_v2.trace_on(graph=True, profiler=True)
      self._is_tracing = True

  def _log_trace(self):
    if context.executing_eagerly():
      with self._get_writer(self._train_run_name).as_default(), \
          summary_ops_v2.always_record_summaries():
        # TODO(b/126388999): Remove step info in the summary name.
        summary_ops_v2.trace_export(
            name='batch_%d' % self._total_batches_seen,
            step=self._total_batches_seen,
            profiler_outdir=os.path.join(self.log_dir, 'train'))
      self._is_tracing = False

  def _log_metrics(self, logs, prefix, step):
    """Writes metrics out as custom scalar summaries.

    Arguments:
        logs: Dict. Keys are scalar summary names, values are NumPy scalars.
        prefix: String. The prefix to apply to the scalar summary names.
        step: Int. The global step to use for TensorBoard.
    """
    if logs is None:
      logs = {}

    # Group metrics by the name of their associated file writer. Values
    # are lists of metrics, as (name, scalar_value) pairs.
    logs_by_writer = {
        self._train_run_name: [],
        self._validation_run_name: [],
    }
    validation_prefix = 'val_'
    for (name, value) in logs.items():
      if name in ('batch', 'size', 'num_steps'):
        # Scrub non-metric items.
        continue
      if name.startswith(validation_prefix):
        name = name[len(validation_prefix):]
        writer_name = self._validation_run_name
      else:
        writer_name = self._train_run_name
      name = prefix + name  # assign batch or epoch prefix
      logs_by_writer[writer_name].append((name, value))

    with context.eager_mode():
      with summary_ops_v2.always_record_summaries():
        for writer_name in logs_by_writer:
          these_logs = logs_by_writer[writer_name]
          if not these_logs:
            # Don't create a "validation" events file if we don't
            # actually have any validation data.
            continue
          writer = self._get_writer(writer_name)
          with writer.as_default():
            for (name, value) in these_logs:
              summary_ops_v2.scalar(name, value, step=step)

  def _log_weights(self, epoch):
    """Logs the weights of the Model to TensorBoard."""
    writer = self._get_writer(self._train_run_name)
    with context.eager_mode(), \
        writer.as_default(), \
        summary_ops_v2.always_record_summaries():
      for layer in self.model.layers:
        for weight in layer.weights:
          weight_name = weight.name.replace(':', '_')
          with ops.init_scope():
            weight = K.get_value(weight)
          summary_ops_v2.histogram(weight_name, weight, step=epoch)
          if self.write_images:
            self._log_weight_as_image(weight, weight_name, epoch)
      writer.flush()

  def _log_weight_as_image(self, weight, weight_name, epoch):
    """Logs a weight as a TensorBoard image."""
    w_img = array_ops.squeeze(weight)
    shape = K.int_shape(w_img)
    if len(shape) == 1:  # Bias case
      w_img = array_ops.reshape(w_img, [1, shape[0], 1, 1])
    elif len(shape) == 2:  # Dense layer kernel case
      if shape[0] > shape[1]:
        w_img = array_ops.transpose(w_img)
        shape = K.int_shape(w_img)
      w_img = array_ops.reshape(w_img, [1, shape[0], shape[1], 1])
    elif len(shape) == 3:  # ConvNet case
      if K.image_data_format() == 'channels_last':
        # Switch to channels_first to display every kernel as a separate
        # image.
        w_img = array_ops.transpose(w_img, perm=[2, 0, 1])
        shape = K.int_shape(w_img)
      w_img = array_ops.reshape(w_img, [shape[0], shape[1], shape[2], 1])

    shape = K.int_shape(w_img)
    # Not possible to handle 3D convnets etc.
    if len(shape) == 4 and shape[-1] in [1, 3, 4]:
      summary_ops_v2.image(weight_name, w_img, step=epoch)


@keras_export('keras.callbacks.ReduceLROnPlateau')
class ReduceLROnPlateau(Callback):
  """Reduce learning rate when a metric has stopped improving.

  Models often benefit from reducing the learning rate by a factor
  of 2-10 once learning stagnates. This callback monitors a
  quantity and if no improvement is seen for a 'patience' number
  of epochs, the learning rate is reduced.

  Example:

  ```python
  reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                                patience=5, min_lr=0.001)
  model.fit(X_train, Y_train, callbacks=[reduce_lr])
  ```

  Arguments:
      monitor: quantity to be monitored.
      factor: factor by which the learning rate will
          be reduced. new_lr = lr * factor
      patience: number of epochs with no improvement
          after which learning rate will be reduced.
      verbose: int. 0: quiet, 1: update messages.
      mode: one of {auto, min, max}. In `min` mode,
          lr will be reduced when the quantity
          monitored has stopped decreasing; in `max`
          mode it will be reduced when the quantity
          monitored has stopped increasing; in `auto`
          mode, the direction is automatically inferred
          from the name of the monitored quantity.
      min_delta: threshold for measuring the new optimum,
          to only focus on significant changes.
      cooldown: number of epochs to wait before resuming
          normal operation after lr has been reduced.
      min_lr: lower bound on the learning rate.
  """

  def __init__(self,
               monitor='val_loss',
               factor=0.1,
               patience=10,
               verbose=0,
               mode='auto',
               min_delta=1e-4,
               cooldown=0,
               min_lr=0,
               **kwargs):
    super(ReduceLROnPlateau, self).__init__()

    self.monitor = monitor
    if factor >= 1.0:
      raise ValueError('ReduceLROnPlateau '
                       'does not support a factor >= 1.0.')
    if 'epsilon' in kwargs:
      min_delta = kwargs.pop('epsilon')
      logging.warning('`epsilon` argument is deprecated and '
                      'will be removed, use `min_delta` instead.')
    self.factor = factor
    self.min_lr = min_lr
    self.min_delta = min_delta
    self.patience = patience
    self.verbose = verbose
    self.cooldown = cooldown
    self.cooldown_counter = 0  # Cooldown counter.
    self.wait = 0
    self.best = 0
    self.mode = mode
    self.monitor_op = None
    self._reset()

  def _reset(self):
    """Resets wait counter and cooldown counter."""
    if self.mode not in ['auto', 'min', 'max']:
      logging.warning('Learning Rate Plateau Reducing mode %s is unknown, '
                      'fallback to auto mode.', self.mode)
      self.mode = 'auto'
    if (self.mode == 'min' or
        (self.mode == 'auto' and 'acc' not in self.monitor)):
      self.monitor_op = lambda a, b: np.less(a, b - self.min_delta)
      self.best = np.Inf
    else:
      self.monitor_op = lambda a, b: np.greater(a, b + self.min_delta)
      self.best = -np.Inf
    self.cooldown_counter = 0
    self.wait = 0

  def on_train_begin(self, logs=None):
    self._reset()

  def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}
    logs['lr'] = K.get_value(self.model.optimizer.lr)
    current = logs.get(self.monitor)
    if current is None:
      logging.warning('Reduce LR on plateau conditioned on metric `%s` '
                      'which is not available. Available metrics are: %s',
                      self.monitor, ','.join(list(logs.keys())))

    else:
      if self.in_cooldown():
        self.cooldown_counter -= 1
        self.wait = 0

      if self.monitor_op(current, self.best):
        self.best = current
        self.wait = 0
      elif not self.in_cooldown():
        self.wait += 1
        if self.wait >= self.patience:
          old_lr = float(K.get_value(self.model.optimizer.lr))
          if old_lr > self.min_lr:
            new_lr = old_lr * self.factor
            new_lr = max(new_lr, self.min_lr)
            K.set_value(self.model.optimizer.lr, new_lr)
            if self.verbose > 0:
              print('\nEpoch %05d: ReduceLROnPlateau reducing learning '
                    'rate to %s.' % (epoch + 1, new_lr))
            self.cooldown_counter = self.cooldown
            self.wait = 0

  def in_cooldown(self):
    return self.cooldown_counter > 0


@keras_export('keras.callbacks.CSVLogger')
class CSVLogger(Callback):
  """Callback that streams epoch results to a csv file.

  Supports all values that can be represented as a string,
  including 1D iterables such as np.ndarray.

  Example:

  ```python
  csv_logger = CSVLogger('training.log')
  model.fit(X_train, Y_train, callbacks=[csv_logger])
  ```

  Arguments:
      filename: filename of the csv file, e.g. 'run/log.csv'.
      separator: string used to separate elements in the csv file.
      append: True: append if file exists (useful for continuing
          training). False: overwrite existing file.
  """

  def __init__(self, filename, separator=',', append=False):
    self.sep = separator
    self.filename = filename
    self.append = append
    self.writer = None
    self.keys = None
    self.append_header = True
    if six.PY2:
      self.file_flags = 'b'
      self._open_args = {}
    else:
      self.file_flags = ''
      self._open_args = {'newline': '\n'}
    super(CSVLogger, self).__init__()

  def on_train_begin(self, logs=None):
    if self.append:
      if os.path.exists(self.filename):
        with open(self.filename, 'r' + self.file_flags) as f:
          self.append_header = not bool(len(f.readline()))
      mode = 'a'
    else:
      mode = 'w'
    self.csv_file = io.open(self.filename,
                            mode + self.file_flags,
                            **self._open_args)

  def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}

    def handle_value(k):
      is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
      if isinstance(k, six.string_types):
        return k
      elif isinstance(k, collections.Iterable) and not is_zero_dim_ndarray:
        return '"[%s]"' % (', '.join(map(str, k)))
      else:
        return k

    if self.keys is None:
      self.keys = sorted(logs.keys())

    if self.model.stop_training:
      # We set NA so that csv parsers do not fail for this last epoch.
      logs = dict([(k, logs[k]) if k in logs else (k, 'NA') for k in self.keys])

    if not self.writer:

      class CustomDialect(csv.excel):
        delimiter = self.sep

      fieldnames = ['epoch'] + self.keys
      if six.PY2:
        fieldnames = [unicode(x) for x in fieldnames]

      self.writer = csv.DictWriter(
          self.csv_file,
          fieldnames=fieldnames,
          dialect=CustomDialect)
      if self.append_header:
        self.writer.writeheader()

    row_dict = collections.OrderedDict({'epoch': epoch})
    row_dict.update((key, handle_value(logs[key])) for key in self.keys)
    self.writer.writerow(row_dict)
    self.csv_file.flush()

  def on_train_end(self, logs=None):
    self.csv_file.close()
    self.writer = None


@keras_export('keras.callbacks.LambdaCallback')
class LambdaCallback(Callback):
  r"""Callback for creating simple, custom callbacks on-the-fly.

  This callback is constructed with anonymous functions that will be called
  at the appropriate time. Note that the callbacks expect positional
  arguments, as:

  - `on_epoch_begin` and `on_epoch_end` expect two positional arguments:
    `epoch`, `logs`
  - `on_batch_begin` and `on_batch_end` expect two positional arguments:
    `batch`, `logs`
  - `on_train_begin` and `on_train_end` expect one positional argument:
    `logs`

  Arguments:
      on_epoch_begin: called at the beginning of every epoch.
      on_epoch_end: called at the end of every epoch.
      on_batch_begin: called at the beginning of every batch.
      on_batch_end: called at the end of every batch.
      on_train_begin: called at the beginning of model training.
      on_train_end: called at the end of model training.

  Example:

  ```python
  # Print the batch number at the beginning of every batch.
  batch_print_callback = LambdaCallback(
      on_batch_begin=lambda batch, logs: print(batch))

  # Stream the epoch loss to a file in JSON format. The file content
  # is not well-formed JSON but rather has a JSON object per line.
  import json
  json_log = open('loss_log.json', mode='wt', buffering=1)
  json_logging_callback = LambdaCallback(
      on_epoch_end=lambda epoch, logs: json_log.write(
          json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'),
      on_train_end=lambda logs: json_log.close()
  )

  # Terminate some processes after having finished model training.
  processes = ...
  cleanup_callback = LambdaCallback(
      on_train_end=lambda logs: [
          p.terminate() for p in processes if p.is_alive()])

  model.fit(...,
            callbacks=[batch_print_callback,
                       json_logging_callback,
                       cleanup_callback])
  ```
  """

  def __init__(self,
               on_epoch_begin=None,
               on_epoch_end=None,
               on_batch_begin=None,
               on_batch_end=None,
               on_train_begin=None,
               on_train_end=None,
               **kwargs):
    super(LambdaCallback, self).__init__()
    self.__dict__.update(kwargs)
    if on_epoch_begin is not None:
      self.on_epoch_begin = on_epoch_begin
    else:
      self.on_epoch_begin = lambda epoch, logs: None
    if on_epoch_end is not None:
      self.on_epoch_end = on_epoch_end
    else:
      self.on_epoch_end = lambda epoch, logs: None
    if on_batch_begin is not None:
      self.on_batch_begin = on_batch_begin
    else:
      self.on_batch_begin = lambda batch, logs: None
    if on_batch_end is not None:
      self.on_batch_end = on_batch_end
    else:
      self.on_batch_end = lambda batch, logs: None
    if on_train_begin is not None:
      self.on_train_begin = on_train_begin
    else:
      self.on_train_begin = lambda logs: None
    if on_train_end is not None:
      self.on_train_end = on_train_end
    else:
      self.on_train_end = lambda logs: None