1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=g-import-not-at-top 16"""Callbacks: utilities called at certain points during model training. 17""" 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import collections 23import copy 24import csv 25import io 26import json 27import os 28import re 29import tempfile 30import time 31 32import numpy as np 33import six 34 35from tensorflow.python.data.ops import iterator_ops 36from tensorflow.python.distribute import distributed_file_utils 37from tensorflow.python.distribute import multi_worker_util 38from tensorflow.python.eager import context 39from tensorflow.python.framework import ops 40from tensorflow.python.keras import backend as K 41from tensorflow.python.keras.distribute import multi_worker_training_state as training_state 42from tensorflow.python.keras.utils.data_utils import Sequence 43from tensorflow.python.keras.utils.generic_utils import Progbar 44from tensorflow.python.keras.utils.mode_keys import ModeKeys 45from tensorflow.python.lib.io import file_io 46from tensorflow.python.ops import array_ops 47from tensorflow.python.ops import math_ops 48from tensorflow.python.ops import summary_ops_v2 49from tensorflow.python.ops import variables 50from tensorflow.python.platform import tf_logging as logging 51from tensorflow.python.training import checkpoint_management 52from tensorflow.python.util.compat import collections_abc 53from tensorflow.python.util.tf_export import keras_export 54from tensorflow.tools.docs import doc_controls 55 56try: 57 import requests 58except ImportError: 59 requests = None 60 61 62def configure_callbacks(callbacks, 63 model, 64 do_validation=False, 65 batch_size=None, 66 epochs=None, 67 steps_per_epoch=None, 68 samples=None, 69 verbose=1, 70 count_mode='steps', 71 mode=ModeKeys.TRAIN): 72 """Configures callbacks for use in various training loops. 73 74 Arguments: 75 callbacks: List of Callbacks. 76 model: Model being trained. 77 do_validation: Whether or not validation loop will be run. 78 batch_size: Number of samples per batch. 79 epochs: Number of epoch to train. 80 steps_per_epoch: Number of batches to run per training epoch. 81 samples: Number of training samples. 82 verbose: int, 0 or 1. Keras logging verbosity to pass to ProgbarLogger. 83 count_mode: One of 'steps' or 'samples'. Per-batch or per-sample count. 84 mode: String. One of ModeKeys.TRAIN, ModeKeys.TEST, or ModeKeys.PREDICT. 85 Which loop mode to configure callbacks for. 86 87 Returns: 88 Instance of CallbackList used to control all Callbacks. 89 """ 90 # Check if callbacks have already been configured. 91 if isinstance(callbacks, CallbackList): 92 return callbacks 93 94 if not callbacks: 95 callbacks = [] 96 97 # Add additional callbacks during training. 98 if mode == ModeKeys.TRAIN: 99 model.history = History() 100 callbacks = [BaseLogger()] + (callbacks or []) + [model.history] 101 if verbose: 102 callbacks.append(ProgbarLogger(count_mode)) 103 callback_list = CallbackList(callbacks) 104 105 # Set callback model 106 callback_model = model._get_callback_model() # pylint: disable=protected-access 107 callback_list.set_model(callback_model) 108 109 set_callback_parameters( 110 callback_list, 111 model, 112 do_validation=do_validation, 113 batch_size=batch_size, 114 epochs=epochs, 115 steps_per_epoch=steps_per_epoch, 116 samples=samples, 117 verbose=verbose, 118 mode=mode) 119 120 callback_list.model.stop_training = False 121 return callback_list 122 123 124def set_callback_parameters(callback_list, 125 model, 126 do_validation=False, 127 batch_size=None, 128 epochs=None, 129 steps_per_epoch=None, 130 samples=None, 131 verbose=1, 132 mode=ModeKeys.TRAIN): 133 """Sets callback parameters. 134 135 Arguments: 136 callback_list: CallbackList instance. 137 model: Model being trained. 138 do_validation: Whether or not validation loop will be run. 139 batch_size: Number of samples per batch. 140 epochs: Number of epoch to train. 141 steps_per_epoch: Number of batches to run per training epoch. 142 samples: Number of training samples. 143 verbose: int, 0 or 1. Keras logging verbosity to pass to ProgbarLogger. 144 mode: String. One of ModeKeys.TRAIN, ModeKeys.TEST, or ModeKeys.PREDICT. 145 Which loop mode to configure callbacks for. 146 """ 147 metric_names = model.metrics_names 148 for cbk in callback_list: 149 if isinstance(cbk, (BaseLogger, ProgbarLogger)): 150 cbk.stateful_metrics = metric_names[1:] # Exclude `loss` 151 152 # Set callback parameters 153 callback_metrics = [] 154 # When we have deferred build scenario with iterator input, we will compile 155 # when we standardize first batch of data. 156 if mode != ModeKeys.PREDICT: 157 callback_metrics = copy.copy(metric_names) 158 if do_validation: 159 callback_metrics += ['val_' + n for n in metric_names] 160 callback_params = { 161 'batch_size': batch_size, 162 'epochs': epochs, 163 'steps': steps_per_epoch, 164 'samples': samples, 165 'verbose': verbose, 166 'do_validation': do_validation, 167 'metrics': callback_metrics, 168 } 169 callback_list.set_params(callback_params) 170 171 172def _is_generator_like(data): 173 """Checks if data is a generator, Sequence, or Iterator.""" 174 return (hasattr(data, 'next') or hasattr(data, '__next__') or isinstance( 175 data, (Sequence, iterator_ops.Iterator, iterator_ops.OwnedIterator))) 176 177 178def make_logs(model, logs, outputs, mode, prefix=''): 179 """Computes logs for sending to `on_batch_end` methods.""" 180 metric_names = model.metrics_names 181 if mode in {ModeKeys.TRAIN, ModeKeys.TEST} and metric_names: 182 for label, output in zip(metric_names, outputs): 183 logs[prefix + label] = output 184 else: 185 logs['outputs'] = outputs 186 return logs 187 188 189class CallbackList(object): 190 """Container abstracting a list of callbacks. 191 192 Arguments: 193 callbacks: List of `Callback` instances. 194 queue_length: Queue length for keeping 195 running statistics over callback execution time. 196 """ 197 198 def __init__(self, callbacks=None, queue_length=10): 199 callbacks = callbacks or [] 200 self.callbacks = [c for c in callbacks] 201 self.queue_length = queue_length 202 self.params = {} 203 self.model = None 204 self._reset_batch_timing() 205 206 def _reset_batch_timing(self): 207 self._delta_t_batch = 0. 208 self._delta_ts = collections.defaultdict( 209 lambda: collections.deque([], maxlen=self.queue_length)) 210 211 def append(self, callback): 212 self.callbacks.append(callback) 213 214 def set_params(self, params): 215 self.params = params 216 for callback in self.callbacks: 217 callback.set_params(params) 218 219 def set_model(self, model): 220 self.model = model 221 for callback in self.callbacks: 222 callback.set_model(model) 223 224 def _call_batch_hook(self, mode, hook, batch, logs=None): 225 """Helper function for all batch_{begin | end} methods.""" 226 if not self.callbacks: 227 return 228 hook_name = 'on_{mode}_batch_{hook}'.format(mode=mode, hook=hook) 229 if hook == 'begin': 230 self._t_enter_batch = time.time() 231 if hook == 'end': 232 # Batch is ending, calculate batch time. 233 self._delta_t_batch = time.time() - self._t_enter_batch 234 235 logs = logs or {} 236 t_before_callbacks = time.time() 237 for callback in self.callbacks: 238 batch_hook = getattr(callback, hook_name) 239 batch_hook(batch, logs) 240 self._delta_ts[hook_name].append(time.time() - t_before_callbacks) 241 242 delta_t_median = np.median(self._delta_ts[hook_name]) 243 if (self._delta_t_batch > 0. and 244 delta_t_median > 0.95 * self._delta_t_batch and delta_t_median > 0.1): 245 logging.warning( 246 'Method (%s) is slow compared ' 247 'to the batch update (%f). Check your callbacks.', hook_name, 248 delta_t_median) 249 250 def _call_begin_hook(self, mode): 251 """Helper function for on_{train|test|predict}_begin methods.""" 252 if mode == ModeKeys.TRAIN: 253 self.on_train_begin() 254 elif mode == ModeKeys.TEST: 255 self.on_test_begin() 256 else: 257 self.on_predict_begin() 258 259 def _call_end_hook(self, mode): 260 """Helper function for on_{train|test|predict}_end methods.""" 261 if mode == ModeKeys.TRAIN: 262 self.on_train_end() 263 elif mode == ModeKeys.TEST: 264 self.on_test_end() 265 else: 266 self.on_predict_end() 267 268 def on_batch_begin(self, batch, logs=None): 269 self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs) 270 271 def on_batch_end(self, batch, logs=None): 272 self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs) 273 274 def on_epoch_begin(self, epoch, logs=None): 275 """Calls the `on_epoch_begin` methods of its callbacks. 276 277 This function should only be called during TRAIN mode. 278 279 Arguments: 280 epoch: integer, index of epoch. 281 logs: dict. Currently no data is passed to this argument for this method 282 but that may change in the future. 283 """ 284 logs = logs or {} 285 for callback in self.callbacks: 286 callback.on_epoch_begin(epoch, logs) 287 self._reset_batch_timing() 288 289 def on_epoch_end(self, epoch, logs=None): 290 """Calls the `on_epoch_end` methods of its callbacks. 291 292 This function should only be called during TRAIN mode. 293 294 Arguments: 295 epoch: integer, index of epoch. 296 logs: dict, metric results for this training epoch, and for the 297 validation epoch if validation is performed. Validation result keys 298 are prefixed with `val_`. 299 """ 300 logs = logs or {} 301 for callback in self.callbacks: 302 callback.on_epoch_end(epoch, logs) 303 304 def on_train_batch_begin(self, batch, logs=None): 305 """Calls the `on_train_batch_begin` methods of its callbacks. 306 307 Arguments: 308 batch: integer, index of batch within the current epoch. 309 logs: dict. Has keys `batch` and `size` representing the current batch 310 number and the size of the batch. 311 """ 312 self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs) 313 314 def on_train_batch_end(self, batch, logs=None): 315 """Calls the `on_train_batch_end` methods of its callbacks. 316 317 Arguments: 318 batch: integer, index of batch within the current epoch. 319 logs: dict. Metric results for this batch. 320 """ 321 self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs) 322 323 def on_test_batch_begin(self, batch, logs=None): 324 """Calls the `on_test_batch_begin` methods of its callbacks. 325 326 Arguments: 327 batch: integer, index of batch within the current epoch. 328 logs: dict. Has keys `batch` and `size` representing the current batch 329 number and the size of the batch. 330 """ 331 self._call_batch_hook(ModeKeys.TEST, 'begin', batch, logs=logs) 332 333 def on_test_batch_end(self, batch, logs=None): 334 """Calls the `on_test_batch_end` methods of its callbacks. 335 336 Arguments: 337 batch: integer, index of batch within the current epoch. 338 logs: dict. Metric results for this batch. 339 """ 340 self._call_batch_hook(ModeKeys.TEST, 'end', batch, logs=logs) 341 342 def on_predict_batch_begin(self, batch, logs=None): 343 """Calls the `on_predict_batch_begin` methods of its callbacks. 344 345 Arguments: 346 batch: integer, index of batch within the current epoch. 347 logs: dict. Has keys `batch` and `size` representing the current batch 348 number and the size of the batch. 349 """ 350 self._call_batch_hook(ModeKeys.PREDICT, 'begin', batch, logs=logs) 351 352 def on_predict_batch_end(self, batch, logs=None): 353 """Calls the `on_predict_batch_end` methods of its callbacks. 354 355 Arguments: 356 batch: integer, index of batch within the current epoch. 357 logs: dict. Metric results for this batch. 358 """ 359 self._call_batch_hook(ModeKeys.PREDICT, 'end', batch, logs=logs) 360 361 def on_train_begin(self, logs=None): 362 """Calls the `on_train_begin` methods of its callbacks. 363 364 Arguments: 365 logs: dict. Currently no data is passed to this argument for this method 366 but that may change in the future. 367 """ 368 for callback in self.callbacks: 369 callback.on_train_begin(logs) 370 371 def on_train_end(self, logs=None): 372 """Calls the `on_train_end` methods of its callbacks. 373 374 Arguments: 375 logs: dict. Currently no data is passed to this argument for this method 376 but that may change in the future. 377 """ 378 for callback in self.callbacks: 379 callback.on_train_end(logs) 380 381 def on_test_begin(self, logs=None): 382 """Calls the `on_test_begin` methods of its callbacks. 383 384 Arguments: 385 logs: dict. Currently no data is passed to this argument for this method 386 but that may change in the future. 387 """ 388 for callback in self.callbacks: 389 callback.on_test_begin(logs) 390 391 def on_test_end(self, logs=None): 392 """Calls the `on_test_end` methods of its callbacks. 393 394 Arguments: 395 logs: dict. Currently no data is passed to this argument for this method 396 but that may change in the future. 397 """ 398 for callback in self.callbacks: 399 callback.on_test_end(logs) 400 401 def on_predict_begin(self, logs=None): 402 """Calls the 'on_predict_begin` methods of its callbacks. 403 404 Arguments: 405 logs: dict. Currently no data is passed to this argument for this method 406 but that may change in the future. 407 """ 408 for callback in self.callbacks: 409 callback.on_predict_begin(logs) 410 411 def on_predict_end(self, logs=None): 412 """Calls the `on_predict_end` methods of its callbacks. 413 414 Arguments: 415 logs: dict. Currently no data is passed to this argument for this method 416 but that may change in the future. 417 """ 418 for callback in self.callbacks: 419 callback.on_predict_end(logs) 420 421 def __iter__(self): 422 return iter(self.callbacks) 423 424 425@keras_export('keras.callbacks.Callback') 426class Callback(object): 427 """Abstract base class used to build new callbacks. 428 429 Attributes: 430 params: dict. Training parameters 431 (eg. verbosity, batch size, number of epochs...). 432 model: instance of `keras.models.Model`. 433 Reference of the model being trained. 434 validation_data: Deprecated. Do not use. 435 436 The `logs` dictionary that callback methods 437 take as argument will contain keys for quantities relevant to 438 the current batch or epoch. 439 440 Currently, the `.fit()` method of the `Model` class 441 will include the following quantities in the `logs` that 442 it passes to its callbacks: 443 444 on_epoch_end: logs include `acc` and `loss`, and 445 optionally include `val_loss` 446 (if validation is enabled in `fit`), and `val_acc` 447 (if validation and accuracy monitoring are enabled). 448 on_batch_begin: logs include `size`, 449 the number of samples in the current batch. 450 on_batch_end: logs include `loss`, and optionally `acc` 451 (if accuracy monitoring is enabled). 452 """ 453 454 def __init__(self): 455 self.validation_data = None 456 self.model = None 457 # Whether this Callback should only run on the chief worker in a 458 # Multi-Worker setting. 459 # TODO(omalleyt): Make this attr public once solution is stable. 460 self._chief_worker_only = None 461 462 def set_params(self, params): 463 self.params = params 464 465 def set_model(self, model): 466 self.model = model 467 468 @doc_controls.for_subclass_implementers 469 def on_batch_begin(self, batch, logs=None): 470 """A backwards compatibility alias for `on_train_batch_begin`.""" 471 472 @doc_controls.for_subclass_implementers 473 def on_batch_end(self, batch, logs=None): 474 """A backwards compatibility alias for `on_train_batch_end`.""" 475 476 @doc_controls.for_subclass_implementers 477 def on_epoch_begin(self, epoch, logs=None): 478 """Called at the start of an epoch. 479 480 Subclasses should override for any actions to run. This function should only 481 be called during TRAIN mode. 482 483 Arguments: 484 epoch: integer, index of epoch. 485 logs: dict. Currently no data is passed to this argument for this method 486 but that may change in the future. 487 """ 488 489 @doc_controls.for_subclass_implementers 490 def on_epoch_end(self, epoch, logs=None): 491 """Called at the end of an epoch. 492 493 Subclasses should override for any actions to run. This function should only 494 be called during TRAIN mode. 495 496 Arguments: 497 epoch: integer, index of epoch. 498 logs: dict, metric results for this training epoch, and for the 499 validation epoch if validation is performed. Validation result keys 500 are prefixed with `val_`. 501 """ 502 503 @doc_controls.for_subclass_implementers 504 def on_train_batch_begin(self, batch, logs=None): 505 """Called at the beginning of a training batch in `fit` methods. 506 507 Subclasses should override for any actions to run. 508 509 Arguments: 510 batch: integer, index of batch within the current epoch. 511 logs: dict. Has keys `batch` and `size` representing the current batch 512 number and the size of the batch. 513 """ 514 # For backwards compatibility. 515 self.on_batch_begin(batch, logs=logs) 516 517 @doc_controls.for_subclass_implementers 518 def on_train_batch_end(self, batch, logs=None): 519 """Called at the end of a training batch in `fit` methods. 520 521 Subclasses should override for any actions to run. 522 523 Arguments: 524 batch: integer, index of batch within the current epoch. 525 logs: dict. Metric results for this batch. 526 """ 527 # For backwards compatibility. 528 self.on_batch_end(batch, logs=logs) 529 530 @doc_controls.for_subclass_implementers 531 def on_test_batch_begin(self, batch, logs=None): 532 """Called at the beginning of a batch in `evaluate` methods. 533 534 Also called at the beginning of a validation batch in the `fit` 535 methods, if validation data is provided. 536 537 Subclasses should override for any actions to run. 538 539 Arguments: 540 batch: integer, index of batch within the current epoch. 541 logs: dict. Has keys `batch` and `size` representing the current batch 542 number and the size of the batch. 543 """ 544 545 @doc_controls.for_subclass_implementers 546 def on_test_batch_end(self, batch, logs=None): 547 """Called at the end of a batch in `evaluate` methods. 548 549 Also called at the end of a validation batch in the `fit` 550 methods, if validation data is provided. 551 552 Subclasses should override for any actions to run. 553 554 Arguments: 555 batch: integer, index of batch within the current epoch. 556 logs: dict. Metric results for this batch. 557 """ 558 559 @doc_controls.for_subclass_implementers 560 def on_predict_batch_begin(self, batch, logs=None): 561 """Called at the beginning of a batch in `predict` methods. 562 563 Subclasses should override for any actions to run. 564 565 Arguments: 566 batch: integer, index of batch within the current epoch. 567 logs: dict. Has keys `batch` and `size` representing the current batch 568 number and the size of the batch. 569 """ 570 571 @doc_controls.for_subclass_implementers 572 def on_predict_batch_end(self, batch, logs=None): 573 """Called at the end of a batch in `predict` methods. 574 575 Subclasses should override for any actions to run. 576 577 Arguments: 578 batch: integer, index of batch within the current epoch. 579 logs: dict. Metric results for this batch. 580 """ 581 582 @doc_controls.for_subclass_implementers 583 def on_train_begin(self, logs=None): 584 """Called at the beginning of training. 585 586 Subclasses should override for any actions to run. 587 588 Arguments: 589 logs: dict. Currently no data is passed to this argument for this method 590 but that may change in the future. 591 """ 592 593 @doc_controls.for_subclass_implementers 594 def on_train_end(self, logs=None): 595 """Called at the end of training. 596 597 Subclasses should override for any actions to run. 598 599 Arguments: 600 logs: dict. Currently no data is passed to this argument for this method 601 but that may change in the future. 602 """ 603 604 @doc_controls.for_subclass_implementers 605 def on_test_begin(self, logs=None): 606 """Called at the beginning of evaluation or validation. 607 608 Subclasses should override for any actions to run. 609 610 Arguments: 611 logs: dict. Currently no data is passed to this argument for this method 612 but that may change in the future. 613 """ 614 615 @doc_controls.for_subclass_implementers 616 def on_test_end(self, logs=None): 617 """Called at the end of evaluation or validation. 618 619 Subclasses should override for any actions to run. 620 621 Arguments: 622 logs: dict. Currently no data is passed to this argument for this method 623 but that may change in the future. 624 """ 625 626 @doc_controls.for_subclass_implementers 627 def on_predict_begin(self, logs=None): 628 """Called at the beginning of prediction. 629 630 Subclasses should override for any actions to run. 631 632 Arguments: 633 logs: dict. Currently no data is passed to this argument for this method 634 but that may change in the future. 635 """ 636 637 @doc_controls.for_subclass_implementers 638 def on_predict_end(self, logs=None): 639 """Called at the end of prediction. 640 641 Subclasses should override for any actions to run. 642 643 Arguments: 644 logs: dict. Currently no data is passed to this argument for this method 645 but that may change in the future. 646 """ 647 648 649@keras_export('keras.callbacks.BaseLogger') 650class BaseLogger(Callback): 651 """Callback that accumulates epoch averages of metrics. 652 653 This callback is automatically applied to every Keras model. 654 655 Arguments: 656 stateful_metrics: Iterable of string names of metrics that 657 should *not* be averaged over an epoch. 658 Metrics in this list will be logged as-is in `on_epoch_end`. 659 All others will be averaged in `on_epoch_end`. 660 """ 661 662 def __init__(self, stateful_metrics=None): 663 super(BaseLogger, self).__init__() 664 self.stateful_metrics = set(stateful_metrics or []) 665 666 def on_epoch_begin(self, epoch, logs=None): 667 self.seen = 0 668 self.totals = {} 669 670 def on_batch_end(self, batch, logs=None): 671 logs = logs or {} 672 batch_size = logs.get('size', 0) 673 # In case of distribution strategy we can potentially run multiple steps 674 # at the same time, we should account for that in the `seen` calculation. 675 num_steps = logs.get('num_steps', 1) 676 self.seen += batch_size * num_steps 677 678 for k, v in logs.items(): 679 if k in self.stateful_metrics: 680 self.totals[k] = v 681 else: 682 if k in self.totals: 683 self.totals[k] += v * batch_size 684 else: 685 self.totals[k] = v * batch_size 686 687 def on_epoch_end(self, epoch, logs=None): 688 if logs is not None: 689 for k in self.params['metrics']: 690 if k in self.totals: 691 # Make value available to next callbacks. 692 if k in self.stateful_metrics: 693 logs[k] = self.totals[k] 694 else: 695 logs[k] = self.totals[k] / self.seen 696 697 698@keras_export('keras.callbacks.TerminateOnNaN') 699class TerminateOnNaN(Callback): 700 """Callback that terminates training when a NaN loss is encountered. 701 """ 702 703 def on_batch_end(self, batch, logs=None): 704 logs = logs or {} 705 loss = logs.get('loss') 706 if loss is not None: 707 if np.isnan(loss) or np.isinf(loss): 708 print('Batch %d: Invalid loss, terminating training' % (batch)) 709 self.model.stop_training = True 710 711 712@keras_export('keras.callbacks.ProgbarLogger') 713class ProgbarLogger(Callback): 714 """Callback that prints metrics to stdout. 715 716 Arguments: 717 count_mode: One of "steps" or "samples". 718 Whether the progress bar should 719 count samples seen or steps (batches) seen. 720 stateful_metrics: Iterable of string names of metrics that 721 should *not* be averaged over an epoch. 722 Metrics in this list will be logged as-is. 723 All others will be averaged over time (e.g. loss, etc). 724 725 Raises: 726 ValueError: In case of invalid `count_mode`. 727 """ 728 729 def __init__(self, count_mode='samples', stateful_metrics=None): 730 super(ProgbarLogger, self).__init__() 731 if count_mode == 'samples': 732 self.use_steps = False 733 elif count_mode == 'steps': 734 self.use_steps = True 735 else: 736 raise ValueError('Unknown `count_mode`: ' + str(count_mode)) 737 self.stateful_metrics = set(stateful_metrics or []) 738 self.log_values = None 739 740 def on_train_begin(self, logs=None): 741 self.verbose = self.params['verbose'] 742 self.epochs = self.params['epochs'] 743 744 def on_epoch_begin(self, epoch, logs=None): 745 self.seen = 0 746 if self.use_steps: 747 self.target = self.params['steps'] 748 else: 749 self.target = self.params['samples'] 750 751 if self.verbose: 752 if self.epochs > 1: 753 print('Epoch %d/%d' % (epoch + 1, self.epochs)) 754 self.progbar = Progbar( 755 target=self.target, 756 verbose=self.verbose, 757 stateful_metrics=self.stateful_metrics, 758 unit_name='step' if self.use_steps else 'sample') 759 760 def on_batch_begin(self, batch, logs=None): 761 self.log_values = [] 762 763 def on_batch_end(self, batch, logs=None): 764 logs = logs or {} 765 batch_size = logs.get('size', 0) 766 # In case of distribution strategy we can potentially run multiple steps 767 # at the same time, we should account for that in the `seen` calculation. 768 num_steps = logs.get('num_steps', 1) 769 if self.use_steps: 770 self.seen += num_steps 771 else: 772 self.seen += batch_size * num_steps 773 774 for k in self.params['metrics']: 775 if k in logs: 776 self.log_values.append((k, logs[k])) 777 778 # Skip progbar update for the last batch; 779 # will be handled by on_epoch_end. 780 if self.verbose and (self.target is None or self.seen < self.target): 781 self.progbar.update(self.seen, self.log_values) 782 783 def on_epoch_end(self, epoch, logs=None): 784 logs = logs or {} 785 for k in self.params['metrics']: 786 if k in logs: 787 self.log_values.append((k, logs[k])) 788 if self.verbose: 789 self.progbar.update(self.seen, self.log_values) 790 791 792@keras_export('keras.callbacks.History') 793class History(Callback): 794 """Callback that records events into a `History` object. 795 796 This callback is automatically applied to 797 every Keras model. The `History` object 798 gets returned by the `fit` method of models. 799 """ 800 801 def on_train_begin(self, logs=None): 802 self.epoch = [] 803 self.history = {} 804 805 def on_epoch_end(self, epoch, logs=None): 806 logs = logs or {} 807 self.epoch.append(epoch) 808 for k, v in logs.items(): 809 self.history.setdefault(k, []).append(v) 810 811 812@keras_export('keras.callbacks.ModelCheckpoint') 813class ModelCheckpoint(Callback): 814 """Callback to save the Keras model or model weights at some frequency. 815 816 `ModelCheckpoint` callback is used in conjunction with training using 817 `model.fit()` to save a model or weights (in a checkpoint file) at some 818 interval, so the model or weights can be loaded later to continue the training 819 from the state saved. 820 821 A few options this callback provides include: 822 823 - Whether to only keep the model that has achieved the "best performance" so 824 far, or whether to save the model at the end of every epoch regardless of 825 performance. 826 - Definition of 'best'; which quantity to monitor and whether it should be 827 maximized or minimized. 828 - The frequency it should save at. Currently, the callback supports saving at 829 the end of every epoch, or after a fixed number of training samples. 830 - Whether only weights are saved, or the whole model is saved. 831 832 Example: 833 834 ```python 835 EPOCHS = 10 836 checkpoint_filepath = '/tmp/checkpoint' 837 model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( 838 filepath=checkpoint_filepath, 839 save_weights_only=True, 840 monitor='val_acc', 841 mode='max', 842 save_best_only=True) 843 844 # Model weights are saved at the end of every epoch, if it's the best seen 845 # so far. 846 model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback]) 847 848 # The model weights (that are considered the best) are loaded into the model. 849 model.load_weights(checkpoint_filepath) 850 ``` 851 852 Arguments: 853 filepath: string, path to save the model file. `filepath` can contain 854 named formatting options, which will be filled the value of `epoch` and 855 keys in `logs` (passed in `on_epoch_end`). For example: if `filepath` is 856 `weights.{epoch:02d}-{val_loss:.2f}.hdf5`, then the model checkpoints 857 will be saved with the epoch number and the validation loss in the 858 filename. 859 monitor: quantity to monitor. 860 verbose: verbosity mode, 0 or 1. 861 save_best_only: if `save_best_only=True`, the latest best model according 862 to the quantity monitored will not be overwritten. 863 If `filepath` doesn't contain formatting options like `{epoch}` then 864 `filepath` will be overwritten by each new better model. 865 mode: one of {auto, min, max}. If `save_best_only=True`, the decision to 866 overwrite the current save file is made based on either the maximization 867 or the minimization of the monitored quantity. For `val_acc`, this 868 should be `max`, for `val_loss` this should be `min`, etc. In `auto` 869 mode, the direction is automatically inferred from the name of the 870 monitored quantity. 871 save_weights_only: if True, then only the model's weights will be saved 872 (`model.save_weights(filepath)`), else the full model is saved 873 (`model.save(filepath)`). 874 save_freq: `'epoch'` or integer. When using `'epoch'`, the callback saves 875 the model after each epoch. When using integer, the callback saves the 876 model at end of a batch at which this many samples have been seen since 877 last saving. Note that if the saving isn't aligned to epochs, the 878 monitored metric may potentially be less reliable (it could reflect as 879 little as 1 batch, since the metrics get reset every epoch). Defaults to 880 `'epoch'` 881 **kwargs: Additional arguments for backwards compatibility. Possible key 882 is `period`. 883 """ 884 885 def __init__(self, 886 filepath, 887 monitor='val_loss', 888 verbose=0, 889 save_best_only=False, 890 save_weights_only=False, 891 mode='auto', 892 save_freq='epoch', 893 **kwargs): 894 super(ModelCheckpoint, self).__init__() 895 self.monitor = monitor 896 self.verbose = verbose 897 self.filepath = filepath 898 self.save_best_only = save_best_only 899 self.save_weights_only = save_weights_only 900 self.save_freq = save_freq 901 self.epochs_since_last_save = 0 902 self._samples_seen_since_last_saving = 0 903 904 # Deprecated field `load_weights_on_restart` is for loading the checkpoint 905 # file from `filepath` at the start of `model.fit()` 906 # TODO(rchao): Remove the arg during next breaking release. 907 if 'load_weights_on_restart' in kwargs: 908 self.load_weights_on_restart = kwargs['load_weights_on_restart'] 909 logging.warning('`load_weights_on_restart` argument is deprecated. ' 910 'Please use `model.load_weights()` for loading weights ' 911 'before the start of `model.fit()`.') 912 else: 913 self.load_weights_on_restart = False 914 915 # Deprecated field `period` is for the number of epochs between which 916 # the model is saved. 917 if 'period' in kwargs: 918 self.period = kwargs['period'] 919 logging.warning('`period` argument is deprecated. Please use `save_freq` ' 920 'to specify the frequency in number of samples seen.') 921 else: 922 self.period = 1 923 924 if mode not in ['auto', 'min', 'max']: 925 logging.warning('ModelCheckpoint mode %s is unknown, ' 926 'fallback to auto mode.', mode) 927 mode = 'auto' 928 929 if mode == 'min': 930 self.monitor_op = np.less 931 self.best = np.Inf 932 elif mode == 'max': 933 self.monitor_op = np.greater 934 self.best = -np.Inf 935 else: 936 if 'acc' in self.monitor or self.monitor.startswith('fmeasure'): 937 self.monitor_op = np.greater 938 self.best = -np.Inf 939 else: 940 self.monitor_op = np.less 941 self.best = np.Inf 942 943 if self.save_freq != 'epoch' and not isinstance(self.save_freq, int): 944 raise ValueError('Unrecognized save_freq: {}'.format(self.save_freq)) 945 946 # Only the chief worker writes model checkpoints, but all workers 947 # restore checkpoint at on_train_begin(). 948 self._chief_worker_only = False 949 950 def set_model(self, model): 951 self.model = model 952 # Use name matching rather than `isinstance` to avoid circular dependencies. 953 if (not self.save_weights_only and 954 not model._is_graph_network and # pylint: disable=protected-access 955 model.__class__.__name__ != 'Sequential'): 956 self.save_weights_only = True 957 958 def on_train_begin(self, logs=None): 959 # pylint: disable=protected-access 960 if self.model._in_multi_worker_mode(): 961 # MultiWorkerTrainingState is used to manage the training state needed 962 # for preemption-recovery of a worker in multi-worker training. 963 self.model._training_state = ( 964 training_state.MultiWorkerTrainingState(self.model, self.filepath)) 965 self._training_state = self.model._training_state 966 if self._training_state.restore(): 967 # If the training state needs to be and is successfully restored, 968 # it is recovering from a previous failure (or preemption). In such 969 # case, do not load the weights from user specified file path. 970 return 971 972 # If this is not multi worker training, restoring is not needed, or 973 # restoring failed, check if it should load weights on restart. 974 if self.load_weights_on_restart: 975 if (not self.model._in_multi_worker_mode() or 976 multi_worker_util.should_load_checkpoint()): 977 filepath_to_load = ( 978 self._get_most_recently_modified_file_matching_pattern( 979 self.filepath)) 980 if (filepath_to_load is not None and 981 training_state.checkpoint_exists(filepath_to_load)): 982 try: 983 # `filepath` may contain placeholders such as `{epoch:02d}`, and 984 # thus it attempts to load the most recently modified file with file 985 # name matching the pattern. 986 self.model.load_weights(filepath_to_load) 987 except (IOError, ValueError) as e: 988 raise ValueError('Error loading file from {}. Reason: {}'.format( 989 filepath_to_load, e)) 990 991 def on_train_end(self, logs=None): 992 # pylint: disable=protected-access 993 if self.model._in_multi_worker_mode(): 994 if self.model.stop_training or getattr( 995 self.model, '_successful_loop_finish', False): 996 # In multi-worker training, on successful exit of training, delete the 997 # training state backup file that was saved for the purpose of worker 998 # recovery. 999 self._training_state.delete_backup() 1000 # Restore the training state so the model is ready for next (possible) 1001 # multi worker training. 1002 del self._training_state 1003 del self.model._training_state 1004 1005 def on_batch_end(self, batch, logs=None): 1006 logs = logs or {} 1007 if isinstance(self.save_freq, int): 1008 self._samples_seen_since_last_saving += logs.get('size', 1) 1009 if self._samples_seen_since_last_saving >= self.save_freq: 1010 self._save_model(epoch=self._current_epoch, logs=logs) 1011 self._samples_seen_since_last_saving = 0 1012 1013 def on_epoch_begin(self, epoch, logs=None): 1014 self._current_epoch = epoch 1015 1016 def on_epoch_end(self, epoch, logs=None): 1017 self.epochs_since_last_save += 1 1018 # pylint: disable=protected-access 1019 if self.save_freq == 'epoch': 1020 if self.model._in_multi_worker_mode(): 1021 # Exclude training state variables in user-requested checkpoint file. 1022 with self._training_state.untrack_vars(): 1023 self._save_model(epoch=epoch, logs=logs) 1024 else: 1025 self._save_model(epoch=epoch, logs=logs) 1026 if self.model._in_multi_worker_mode(): 1027 # For multi-worker training, back up the weights and current training 1028 # state for possible future recovery. 1029 # TODO(rchao): Call `back_up` at finer period such as N steps. 1030 self._training_state.back_up(epoch) 1031 1032 def _save_model(self, epoch, logs): 1033 """Saves the model. 1034 1035 Arguments: 1036 epoch: the epoch this iteration is in. 1037 logs: the `logs` dict passed in to `on_batch_end` or `on_epoch_end`. 1038 """ 1039 logs = logs or {} 1040 1041 if isinstance(self.save_freq, 1042 int) or self.epochs_since_last_save >= self.period: 1043 self.epochs_since_last_save = 0 1044 filepath = self._get_file_path(epoch, logs) 1045 1046 try: 1047 if self.save_best_only: 1048 current = logs.get(self.monitor) 1049 if current is None: 1050 logging.warning('Can save best model only with %s available, ' 1051 'skipping.', self.monitor) 1052 else: 1053 if self.monitor_op(current, self.best): 1054 if self.verbose > 0: 1055 print('\nEpoch %05d: %s improved from %0.5f to %0.5f,' 1056 ' saving model to %s' % (epoch + 1, self.monitor, 1057 self.best, current, filepath)) 1058 self.best = current 1059 if self.save_weights_only: 1060 self.model.save_weights(filepath, overwrite=True) 1061 else: 1062 self.model.save(filepath, overwrite=True) 1063 else: 1064 if self.verbose > 0: 1065 print('\nEpoch %05d: %s did not improve from %0.5f' % 1066 (epoch + 1, self.monitor, self.best)) 1067 else: 1068 if self.verbose > 0: 1069 print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath)) 1070 if self.save_weights_only: 1071 self.model.save_weights(filepath, overwrite=True) 1072 else: 1073 self.model.save(filepath, overwrite=True) 1074 1075 self._maybe_remove_file() 1076 except IOError as e: 1077 # `e.errno` appears to be `None` so checking the content of `e.args[0]`. 1078 if 'is a directory' in six.ensure_str(e.args[0]): 1079 raise IOError('Please specify a non-directory filepath for ' 1080 'ModelCheckpoint. Filepath used is an existing ' 1081 'directory: {}'.format(filepath)) 1082 1083 def _get_file_path(self, epoch, logs): 1084 """Returns the file path for checkpoint.""" 1085 # pylint: disable=protected-access 1086 if not self.model._in_multi_worker_mode( 1087 ) or multi_worker_util.should_save_checkpoint(): 1088 try: 1089 # `filepath` may contain placeholders such as `{epoch:02d}` and 1090 # `{mape:.2f}`. A mismatch between logged metrics and the path's 1091 # placeholders can cause formatting to fail. 1092 return self.filepath.format(epoch=epoch + 1, **logs) 1093 except KeyError as e: 1094 raise KeyError('Failed to format this callback filepath: "{}". ' 1095 'Reason: {}'.format(self.filepath, e)) 1096 else: 1097 # If this is multi-worker training, and this worker should not 1098 # save checkpoint, we use a temp filepath to store a dummy checkpoint, so 1099 # it writes to a file that will be removed at the end of `_save_model()` 1100 # call. This is because the SyncOnReadVariable needs to be synced across 1101 # all the workers in order to be read, and all workers need to initiate 1102 # that. 1103 self._temp_file_dir = tempfile.mkdtemp() 1104 extension = os.path.splitext(self.filepath)[1] 1105 return os.path.join(self._temp_file_dir, 'temp' + extension) 1106 1107 def _maybe_remove_file(self): 1108 # Remove the checkpoint directory in multi-worker training where this worker 1109 # should not checkpoint. It is a dummy directory previously saved for sync 1110 # distributed training. 1111 1112 if (self.model._in_multi_worker_mode() and # pylint: disable=protected-access 1113 not multi_worker_util.should_save_checkpoint()): 1114 file_io.delete_recursively(self._temp_file_dir) 1115 del self._temp_file_dir 1116 1117 def _get_most_recently_modified_file_matching_pattern(self, pattern): 1118 """Returns the most recently modified filepath matching pattern. 1119 1120 Pattern may contain python formatting placeholder. If 1121 `tf.train.latest_checkpoint()` does not return None, use that; otherwise, 1122 check for most recently modified one that matches the pattern. 1123 1124 In the rare case where there are more than one pattern-matching file having 1125 the same modified time that is most recent among all, return the filepath 1126 that is largest (by `>` operator, lexicographically using the numeric 1127 equivalents). This provides a tie-breaker when multiple files are most 1128 recent. Note that a larger `filepath` can sometimes indicate a later time of 1129 modification (for instance, when epoch/batch is used as formatting option), 1130 but not necessarily (when accuracy or loss is used). The tie-breaker is 1131 put in the logic as best effort to return the most recent, and to avoid 1132 undeterministic result. 1133 1134 Modified time of a file is obtained with `os.path.getmtime()`. 1135 1136 This utility function is best demonstrated via an example: 1137 1138 ```python 1139 file_pattern = 'f.batch{batch:02d}epoch{epoch:02d}.h5' 1140 test_dir = self.get_temp_dir() 1141 path_pattern = os.path.join(test_dir, file_pattern) 1142 file_paths = [ 1143 os.path.join(test_dir, file_name) for file_name in 1144 ['f.batch03epoch02.h5', 'f.batch02epoch02.h5', 'f.batch01epoch01.h5'] 1145 ] 1146 for file_path in file_paths: 1147 # Write something to each of the files 1148 self.assertEqual( 1149 _get_most_recently_modified_file_matching_pattern(path_pattern), 1150 file_paths[-1]) 1151 ``` 1152 1153 Arguments: 1154 pattern: The file pattern that may optionally contain python placeholder 1155 such as `{epoch:02d}`. 1156 1157 Returns: 1158 The most recently modified file's full filepath matching `pattern`. If 1159 `pattern` does not contain any placeholder, this returns the filepath 1160 that 1161 exactly matches `pattern`. Returns `None` if no match is found. 1162 """ 1163 dir_name = os.path.dirname(pattern) 1164 base_name = os.path.basename(pattern) 1165 base_name_regex = '^' + re.sub(r'{.*}', r'.*', base_name) + '$' 1166 1167 # If tf.train.latest_checkpoint tells us there exists a latest checkpoint, 1168 # use that as it is more robust than `os.path.getmtime()`. 1169 latest_tf_checkpoint = checkpoint_management.latest_checkpoint(dir_name) 1170 if latest_tf_checkpoint is not None and re.match( 1171 base_name_regex, os.path.basename(latest_tf_checkpoint)): 1172 return latest_tf_checkpoint 1173 1174 latest_mod_time = 0 1175 file_path_with_latest_mod_time = None 1176 n_file_with_latest_mod_time = 0 1177 file_path_with_largest_file_name = None 1178 1179 if file_io.file_exists(dir_name): 1180 for file_name in os.listdir(dir_name): 1181 # Only consider if `file_name` matches the pattern. 1182 if re.match(base_name_regex, file_name): 1183 file_path = os.path.join(dir_name, file_name) 1184 mod_time = os.path.getmtime(file_path) 1185 if (file_path_with_largest_file_name is None or 1186 file_path > file_path_with_largest_file_name): 1187 file_path_with_largest_file_name = file_path 1188 if mod_time > latest_mod_time: 1189 latest_mod_time = mod_time 1190 file_path_with_latest_mod_time = file_path 1191 # In the case a file with later modified time is found, reset 1192 # the counter for the number of files with latest modified time. 1193 n_file_with_latest_mod_time = 1 1194 elif mod_time == latest_mod_time: 1195 # In the case a file has modified time tied with the most recent, 1196 # increment the counter for the number of files with latest modified 1197 # time by 1. 1198 n_file_with_latest_mod_time += 1 1199 1200 if n_file_with_latest_mod_time == 1: 1201 # Return the sole file that has most recent modified time. 1202 return file_path_with_latest_mod_time 1203 else: 1204 # If there are more than one file having latest modified time, return 1205 # the file path with the largest file name. 1206 return file_path_with_largest_file_name 1207 1208 1209@keras_export('keras.callbacks.EarlyStopping') 1210class EarlyStopping(Callback): 1211 """Stop training when a monitored metric has stopped improving. 1212 1213 Assuming the goal of a training is to minimize the loss. With this, the 1214 metric to be monitored would be 'loss', and mode would be 'min'. A 1215 `model.fit()` training loop will check at end of every epoch whether 1216 the loss is no longer decreasing, considering the `min_delta` and 1217 `patience` if applicable. Once it's found no longer decreasing, 1218 `model.stop_training` is marked True and the training terminates. 1219 1220 The quantity to be monitored needs to be available in `logs` dict. 1221 To make it so, pass the loss or metrics at `model.compile()`. 1222 1223 Example: 1224 1225 >>> callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3) 1226 >>> # This callback will stop the training when there is no improvement in 1227 >>> # the validation loss for three consecutive epochs. 1228 >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)]) 1229 >>> model.compile(tf.keras.optimizers.SGD(), loss='mse') 1230 >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5), 1231 ... epochs=10, callbacks=[callback]) 1232 Train on 5 samples 1233 Epoch 1/10 1234 5/5 [==============================] - ... loss: 6533.1904 1235 Epoch 2/10 1236 5/5 [==============================] - ... loss: 110183360.0000 1237 Epoch 3/10 1238 5/5 [==============================] - ... loss: 1862575718400.0000 1239 Epoch 4/10 1240 5/5 [==============================] - ... loss: 31485597793124352.0000 1241 """ 1242 1243 def __init__(self, 1244 monitor='val_loss', 1245 min_delta=0, 1246 patience=0, 1247 verbose=0, 1248 mode='auto', 1249 baseline=None, 1250 restore_best_weights=False): 1251 """Initialize an EarlyStopping callback. 1252 1253 Arguments: 1254 monitor: Quantity to be monitored. 1255 min_delta: Minimum change in the monitored quantity 1256 to qualify as an improvement, i.e. an absolute 1257 change of less than min_delta, will count as no 1258 improvement. 1259 patience: Number of epochs with no improvement 1260 after which training will be stopped. 1261 verbose: verbosity mode. 1262 mode: One of `{"auto", "min", "max"}`. In `min` mode, 1263 training will stop when the quantity 1264 monitored has stopped decreasing; in `max` 1265 mode it will stop when the quantity 1266 monitored has stopped increasing; in `auto` 1267 mode, the direction is automatically inferred 1268 from the name of the monitored quantity. 1269 baseline: Baseline value for the monitored quantity. 1270 Training will stop if the model doesn't show improvement over the 1271 baseline. 1272 restore_best_weights: Whether to restore model weights from 1273 the epoch with the best value of the monitored quantity. 1274 If False, the model weights obtained at the last step of 1275 training are used. 1276 """ 1277 super(EarlyStopping, self).__init__() 1278 1279 self.monitor = monitor 1280 self.patience = patience 1281 self.verbose = verbose 1282 self.baseline = baseline 1283 self.min_delta = abs(min_delta) 1284 self.wait = 0 1285 self.stopped_epoch = 0 1286 self.restore_best_weights = restore_best_weights 1287 self.best_weights = None 1288 1289 if mode not in ['auto', 'min', 'max']: 1290 logging.warning('EarlyStopping mode %s is unknown, ' 1291 'fallback to auto mode.', mode) 1292 mode = 'auto' 1293 1294 if mode == 'min': 1295 self.monitor_op = np.less 1296 elif mode == 'max': 1297 self.monitor_op = np.greater 1298 else: 1299 if 'acc' in self.monitor: 1300 self.monitor_op = np.greater 1301 else: 1302 self.monitor_op = np.less 1303 1304 if self.monitor_op == np.greater: 1305 self.min_delta *= 1 1306 else: 1307 self.min_delta *= -1 1308 1309 def on_train_begin(self, logs=None): 1310 # Allow instances to be re-used 1311 self.wait = 0 1312 self.stopped_epoch = 0 1313 if self.baseline is not None: 1314 self.best = self.baseline 1315 else: 1316 self.best = np.Inf if self.monitor_op == np.less else -np.Inf 1317 1318 def on_epoch_end(self, epoch, logs=None): 1319 current = self.get_monitor_value(logs) 1320 if current is None: 1321 return 1322 if self.monitor_op(current - self.min_delta, self.best): 1323 self.best = current 1324 self.wait = 0 1325 if self.restore_best_weights: 1326 self.best_weights = self.model.get_weights() 1327 else: 1328 self.wait += 1 1329 if self.wait >= self.patience: 1330 self.stopped_epoch = epoch 1331 self.model.stop_training = True 1332 if self.restore_best_weights: 1333 if self.verbose > 0: 1334 print('Restoring model weights from the end of the best epoch.') 1335 self.model.set_weights(self.best_weights) 1336 1337 def on_train_end(self, logs=None): 1338 if self.stopped_epoch > 0 and self.verbose > 0: 1339 print('Epoch %05d: early stopping' % (self.stopped_epoch + 1)) 1340 1341 def get_monitor_value(self, logs): 1342 logs = logs or {} 1343 monitor_value = logs.get(self.monitor) 1344 if monitor_value is None: 1345 logging.warning('Early stopping conditioned on metric `%s` ' 1346 'which is not available. Available metrics are: %s', 1347 self.monitor, ','.join(list(logs.keys()))) 1348 return monitor_value 1349 1350 1351@keras_export('keras.callbacks.RemoteMonitor') 1352class RemoteMonitor(Callback): 1353 """Callback used to stream events to a server. 1354 1355 Requires the `requests` library. 1356 Events are sent to `root + '/publish/epoch/end/'` by default. Calls are 1357 HTTP POST, with a `data` argument which is a 1358 JSON-encoded dictionary of event data. 1359 If send_as_json is set to True, the content type of the request will be 1360 application/json. Otherwise the serialized JSON will be sent within a form. 1361 1362 Arguments: 1363 root: String; root url of the target server. 1364 path: String; path relative to `root` to which the events will be sent. 1365 field: String; JSON field under which the data will be stored. 1366 The field is used only if the payload is sent within a form 1367 (i.e. send_as_json is set to False). 1368 headers: Dictionary; optional custom HTTP headers. 1369 send_as_json: Boolean; whether the request should be 1370 sent as application/json. 1371 """ 1372 1373 def __init__(self, 1374 root='http://localhost:9000', 1375 path='/publish/epoch/end/', 1376 field='data', 1377 headers=None, 1378 send_as_json=False): 1379 super(RemoteMonitor, self).__init__() 1380 1381 self.root = root 1382 self.path = path 1383 self.field = field 1384 self.headers = headers 1385 self.send_as_json = send_as_json 1386 1387 def on_epoch_end(self, epoch, logs=None): 1388 if requests is None: 1389 raise ImportError('RemoteMonitor requires the `requests` library.') 1390 logs = logs or {} 1391 send = {} 1392 send['epoch'] = epoch 1393 for k, v in logs.items(): 1394 # np.ndarray and np.generic are not scalar types 1395 # therefore we must unwrap their scalar values and 1396 # pass to the json-serializable dict 'send' 1397 if isinstance(v, (np.ndarray, np.generic)): 1398 send[k] = v.item() 1399 else: 1400 send[k] = v 1401 try: 1402 if self.send_as_json: 1403 requests.post(self.root + self.path, json=send, headers=self.headers) 1404 else: 1405 requests.post( 1406 self.root + self.path, {self.field: json.dumps(send)}, 1407 headers=self.headers) 1408 except requests.exceptions.RequestException: 1409 logging.warning('Warning: could not reach RemoteMonitor ' 1410 'root server at ' + str(self.root)) 1411 1412 1413@keras_export('keras.callbacks.LearningRateScheduler') 1414class LearningRateScheduler(Callback): 1415 """Learning rate scheduler. 1416 1417 Arguments: 1418 schedule: a function that takes an epoch index as input 1419 (integer, indexed from 0) and returns a new 1420 learning rate as output (float). 1421 verbose: int. 0: quiet, 1: update messages. 1422 1423 ```python 1424 # This function keeps the learning rate at 0.001 for the first ten epochs 1425 # and decreases it exponentially after that. 1426 def scheduler(epoch): 1427 if epoch < 10: 1428 return 0.001 1429 else: 1430 return 0.001 * tf.math.exp(0.1 * (10 - epoch)) 1431 1432 callback = tf.keras.callbacks.LearningRateScheduler(scheduler) 1433 model.fit(data, labels, epochs=100, callbacks=[callback], 1434 validation_data=(val_data, val_labels)) 1435 ``` 1436 """ 1437 1438 def __init__(self, schedule, verbose=0): 1439 super(LearningRateScheduler, self).__init__() 1440 self.schedule = schedule 1441 self.verbose = verbose 1442 1443 def on_epoch_begin(self, epoch, logs=None): 1444 if not hasattr(self.model.optimizer, 'lr'): 1445 raise ValueError('Optimizer must have a "lr" attribute.') 1446 try: # new API 1447 lr = float(K.get_value(self.model.optimizer.lr)) 1448 lr = self.schedule(epoch, lr) 1449 except TypeError: # Support for old API for backward compatibility 1450 lr = self.schedule(epoch) 1451 if not isinstance(lr, (ops.Tensor, float, np.float32, np.float64)): 1452 raise ValueError('The output of the "schedule" function ' 1453 'should be float.') 1454 if isinstance(lr, ops.Tensor) and not lr.dtype.is_floating: 1455 raise ValueError('The dtype of Tensor should be float') 1456 K.set_value(self.model.optimizer.lr, K.get_value(lr)) 1457 if self.verbose > 0: 1458 print('\nEpoch %05d: LearningRateScheduler reducing learning ' 1459 'rate to %s.' % (epoch + 1, lr)) 1460 1461 def on_epoch_end(self, epoch, logs=None): 1462 logs = logs or {} 1463 logs['lr'] = K.get_value(self.model.optimizer.lr) 1464 1465 1466@keras_export('keras.callbacks.TensorBoard', v1=[]) 1467class TensorBoard(Callback): 1468 # pylint: disable=line-too-long 1469 """Enable visualizations for TensorBoard. 1470 1471 TensorBoard is a visualization tool provided with TensorFlow. 1472 1473 This callback logs events for TensorBoard, including: 1474 1475 * Metrics summary plots 1476 * Training graph visualization 1477 * Activation histograms 1478 * Sampled profiling 1479 1480 If you have installed TensorFlow with pip, you should be able 1481 to launch TensorBoard from the command line: 1482 1483 ```sh 1484 tensorboard --logdir=path_to_your_logs 1485 ``` 1486 1487 You can find more information about TensorBoard 1488 [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard). 1489 1490 Example: 1491 ```python 1492 tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs") 1493 model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) 1494 #run the tensorboard command to view the visualizations 1495 ``` 1496 1497 Arguments: 1498 log_dir: the path of the directory where to save the log files to be 1499 parsed by TensorBoard. 1500 histogram_freq: frequency (in epochs) at which to compute activation and 1501 weight histograms for the layers of the model. If set to 0, histograms 1502 won't be computed. Validation data (or split) must be specified for 1503 histogram visualizations. 1504 write_graph: whether to visualize the graph in TensorBoard. The log file 1505 can become quite large when write_graph is set to True. 1506 write_images: whether to write model weights to visualize as image in 1507 TensorBoard. 1508 update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`, 1509 writes the losses and metrics to TensorBoard after each batch. The same 1510 applies for `'epoch'`. If using an integer, let's say `1000`, the 1511 callback will write the metrics and losses to TensorBoard every 1000 1512 batches. Note that writing too frequently to TensorBoard can slow down 1513 your training. 1514 profile_batch: Profile the batch to sample compute characteristics. By 1515 default, it will profile the second batch. Set profile_batch=0 to 1516 disable profiling. Must run in TensorFlow eager mode. 1517 embeddings_freq: frequency (in epochs) at which embedding layers will 1518 be visualized. If set to 0, embeddings won't be visualized. 1519 embeddings_metadata: a dictionary which maps layer name to a file name in 1520 which metadata for this embedding layer is saved. See the 1521 [details]( 1522 https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional) 1523 about metadata files format. In case if the same metadata file is 1524 used for all embedding layers, string can be passed. 1525 1526 Raises: 1527 ValueError: If histogram_freq is set and no validation data is provided. 1528 """ 1529 1530 # pylint: enable=line-too-long 1531 1532 def __init__(self, 1533 log_dir='logs', 1534 histogram_freq=0, 1535 write_graph=True, 1536 write_images=False, 1537 update_freq='epoch', 1538 profile_batch=2, 1539 embeddings_freq=0, 1540 embeddings_metadata=None, 1541 **kwargs): 1542 super(TensorBoard, self).__init__() 1543 self._validate_kwargs(kwargs) 1544 1545 self.log_dir = log_dir 1546 self.histogram_freq = histogram_freq 1547 self.write_graph = write_graph 1548 self.write_images = write_images 1549 if update_freq == 'batch': 1550 self.update_freq = 1 1551 else: 1552 self.update_freq = update_freq 1553 self.embeddings_freq = embeddings_freq 1554 self.embeddings_metadata = embeddings_metadata 1555 1556 self._samples_seen = 0 1557 self._samples_seen_at_last_write = 0 1558 self._current_batch = 0 1559 1560 # A collection of file writers currently in use, to be closed when 1561 # training ends for this callback. Writers are keyed by the 1562 # directory name under the root logdir: e.g., "train" or 1563 # "validation". 1564 self._train_run_name = 'train' 1565 self._validation_run_name = 'validation' 1566 self._writers = {} 1567 1568 self._profile_batch = profile_batch 1569 # True when a trace is running. 1570 self._is_tracing = False 1571 1572 def _validate_kwargs(self, kwargs): 1573 """Handle arguments were supported in V1.""" 1574 if kwargs.get('write_grads', False): 1575 logging.warning('`write_grads` will be ignored in TensorFlow 2.0 ' 1576 'for the `TensorBoard` Callback.') 1577 if kwargs.get('batch_size', False): 1578 logging.warning('`batch_size` is no longer needed in the ' 1579 '`TensorBoard` Callback and will be ignored ' 1580 'in TensorFlow 2.0.') 1581 if kwargs.get('embeddings_layer_names', False): 1582 logging.warning('`embeddings_layer_names` is not supported in ' 1583 'TensorFlow 2.0. Instead, all `Embedding` layers ' 1584 'will be visualized.') 1585 if kwargs.get('embeddings_data', False): 1586 logging.warning('`embeddings_data` is not supported in TensorFlow ' 1587 '2.0. Instead, all `Embedding` variables will be ' 1588 'visualized.') 1589 1590 unrecognized_kwargs = set(kwargs.keys()) - { 1591 'write_grads', 'embeddings_layer_names', 'embeddings_data', 'batch_size' 1592 } 1593 1594 # Only allow kwargs that were supported in V1. 1595 if unrecognized_kwargs: 1596 raise ValueError('Unrecognized arguments in `TensorBoard` ' 1597 'Callback: ' + str(unrecognized_kwargs)) 1598 1599 def set_model(self, model): 1600 """Sets Keras model and writes graph if specified.""" 1601 self.model = model 1602 1603 # TensorBoard callback involves writing a summary file in a 1604 # possibly distributed settings. 1605 self._log_write_dir = distributed_file_utils.write_dirpath( 1606 self.log_dir, self.model._get_distribution_strategy()) # pylint: disable=protected-access 1607 1608 with context.eager_mode(): 1609 self._close_writers() 1610 if self.write_graph: 1611 with self._get_writer(self._train_run_name).as_default(): 1612 with summary_ops_v2.always_record_summaries(): 1613 if not model.run_eagerly: 1614 summary_ops_v2.graph(K.get_graph(), step=0) 1615 1616 summary_writable = ( 1617 self.model._is_graph_network or # pylint: disable=protected-access 1618 self.model.__class__.__name__ == 'Sequential') # pylint: disable=protected-access 1619 if summary_writable: 1620 summary_ops_v2.keras_model('keras', self.model, step=0) 1621 1622 if self.embeddings_freq: 1623 self._configure_embeddings() 1624 1625 summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access 1626 self._prev_summary_recording = summary_state.is_recording 1627 self._prev_summary_writer = summary_state.writer 1628 self._prev_summary_step = summary_state.step 1629 1630 def _configure_embeddings(self): 1631 """Configure the Projector for embeddings.""" 1632 # TODO(omalleyt): Add integration tests. 1633 from tensorflow.python.keras.layers import embeddings 1634 try: 1635 from tensorboard.plugins import projector 1636 except ImportError: 1637 raise ImportError('Failed to import TensorBoard. Please make sure that ' 1638 'TensorBoard integration is complete."') 1639 config = projector.ProjectorConfig() 1640 for layer in self.model.layers: 1641 if isinstance(layer, embeddings.Embedding): 1642 embedding = config.embeddings.add() 1643 embedding.tensor_name = layer.embeddings.name 1644 1645 if self.embeddings_metadata is not None: 1646 if isinstance(self.embeddings_metadata, str): 1647 embedding.metadata_path = self.embeddings_metadata 1648 else: 1649 if layer.name in embedding.metadata_path: 1650 embedding.metadata_path = self.embeddings_metadata.pop(layer.name) 1651 1652 if self.embeddings_metadata: 1653 raise ValueError('Unrecognized `Embedding` layer names passed to ' 1654 '`keras.callbacks.TensorBoard` `embeddings_metadata` ' 1655 'argument: ' + str(self.embeddings_metadata.keys())) 1656 1657 class DummyWriter(object): 1658 """Dummy writer to conform to `Projector` API.""" 1659 1660 def __init__(self, logdir): 1661 self.logdir = logdir 1662 1663 def get_logdir(self): 1664 return self.logdir 1665 1666 writer = DummyWriter(self._log_write_dir) 1667 projector.visualize_embeddings(writer, config) 1668 1669 def _close_writers(self): 1670 """Close all remaining open file writers owned by this callback. 1671 1672 If there are no such file writers, this is a no-op. 1673 """ 1674 with context.eager_mode(): 1675 for writer in six.itervalues(self._writers): 1676 writer.close() 1677 self._writers.clear() 1678 1679 def _get_writer(self, writer_name): 1680 """Get a summary writer for the given subdirectory under the logdir. 1681 1682 A writer will be created if it does not yet exist. 1683 1684 Arguments: 1685 writer_name: The name of the directory for which to create or 1686 retrieve a writer. Should be either `self._train_run_name` or 1687 `self._validation_run_name`. 1688 1689 Returns: 1690 A `SummaryWriter` object. 1691 """ 1692 if writer_name not in self._writers: 1693 path = os.path.join(self._log_write_dir, writer_name) 1694 writer = summary_ops_v2.create_file_writer_v2(path) 1695 self._writers[writer_name] = writer 1696 return self._writers[writer_name] 1697 1698 def _set_default_writer(self, writer_name): 1699 """Sets the default writer for custom batch-level summaries.""" 1700 if self.update_freq == 'epoch': 1701 # Writer is only used for custom summaries, which are written 1702 # batch-by-batch. 1703 return 1704 1705 step = self._total_batches_seen[writer_name] 1706 1707 def _should_record(): 1708 return math_ops.equal(step % self.update_freq, 0) 1709 1710 summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access 1711 summary_state.is_recording = _should_record 1712 summary_state.writer = self._get_writer(writer_name) 1713 summary_ops_v2.set_step(step) 1714 1715 def _init_batch_steps(self): 1716 """Create the total batch counters.""" 1717 if ops.executing_eagerly_outside_functions(): 1718 # Variables are needed for the `step` value of custom tf.summaries 1719 # to be updated inside a tf.function. 1720 self._total_batches_seen = { 1721 self._train_run_name: variables.Variable(0, dtype='int64'), 1722 self._validation_run_name: variables.Variable(0, dtype='int64') 1723 } 1724 else: 1725 # Custom tf.summaries are not supported in legacy graph mode. 1726 self._total_batches_seen = { 1727 self._train_run_name: 0, 1728 self._validation_run_name: 0 1729 } 1730 1731 def _increment_step(self, writer_name): 1732 step = self._total_batches_seen[writer_name] 1733 if isinstance(step, variables.Variable): 1734 step.assign_add(1) 1735 else: 1736 self._total_batches_seen[writer_name] += 1 1737 1738 def on_train_begin(self, logs=None): 1739 self._init_batch_steps() 1740 if self._profile_batch == 1: 1741 summary_ops_v2.trace_on(graph=True, profiler=True) 1742 self._is_tracing = True 1743 1744 def on_test_begin(self, logs=None): 1745 self._set_default_writer(self._validation_run_name) 1746 1747 def on_train_batch_end(self, batch, logs=None): 1748 """Writes scalar summaries for metrics on every training batch. 1749 1750 Performs profiling if current batch is in profiler_batches. 1751 1752 Arguments: 1753 batch: Integer, index of batch within the current epoch. 1754 logs: Dict. Metric results for this batch. 1755 """ 1756 if self.update_freq == 'epoch' and self._profile_batch is None: 1757 return 1758 1759 # Don't output batch_size and batch number as TensorBoard summaries 1760 logs = logs or {} 1761 train_batches = self._total_batches_seen[self._train_run_name] 1762 if self.update_freq != 'epoch' and batch % self.update_freq == 0: 1763 self._log_metrics(logs, prefix='batch_', step=train_batches) 1764 1765 self._increment_step(self._train_run_name) 1766 1767 if context.executing_eagerly(): 1768 if self._is_tracing: 1769 self._log_trace() 1770 elif (not self._is_tracing and 1771 math_ops.equal(train_batches, self._profile_batch - 1)): 1772 self._enable_trace() 1773 1774 def on_test_batch_end(self, batch, logs=None): 1775 if self.update_freq == 'epoch': 1776 return 1777 self._increment_step(self._validation_run_name) 1778 1779 def on_epoch_begin(self, epoch, logs=None): 1780 self._set_default_writer(self._train_run_name) 1781 1782 def on_epoch_end(self, epoch, logs=None): 1783 """Runs metrics and histogram summaries at epoch end.""" 1784 self._log_metrics(logs, prefix='epoch_', step=epoch) 1785 1786 if self.histogram_freq and epoch % self.histogram_freq == 0: 1787 self._log_weights(epoch) 1788 1789 if self.embeddings_freq and epoch % self.embeddings_freq == 0: 1790 self._log_embeddings(epoch) 1791 1792 def on_train_end(self, logs=None): 1793 if self._is_tracing: 1794 self._log_trace() 1795 self._close_writers() 1796 1797 summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access 1798 summary_state.is_recording = self._prev_summary_recording 1799 summary_state.writer = self._prev_summary_writer 1800 summary_state.step = self._prev_summary_step 1801 1802 # Safely remove the unneeded temp files. 1803 distributed_file_utils.remove_temp_dirpath( 1804 self.log_dir, self.model._get_distribution_strategy()) # pylint: disable=protected-access 1805 1806 def _enable_trace(self): 1807 if context.executing_eagerly(): 1808 summary_ops_v2.trace_on(graph=True, profiler=True) 1809 self._is_tracing = True 1810 1811 def _log_trace(self): 1812 """Logs the trace graph to TensorBoard.""" 1813 if context.executing_eagerly(): 1814 with self._get_writer(self._train_run_name).as_default(), \ 1815 summary_ops_v2.always_record_summaries(): 1816 # TODO(b/126388999): Remove step info in the summary name. 1817 step = K.get_value(self._total_batches_seen[self._train_run_name]) 1818 summary_ops_v2.trace_export( 1819 name='batch_%d' % step, 1820 step=step, 1821 profiler_outdir=os.path.join(self._log_write_dir, 'train')) 1822 self._is_tracing = False 1823 1824 def _log_metrics(self, logs, prefix, step): 1825 """Writes metrics out as custom scalar summaries. 1826 1827 Arguments: 1828 logs: Dict. Keys are scalar summary names, values are NumPy scalars. 1829 prefix: String. The prefix to apply to the scalar summary names. 1830 step: Int. The global step to use for TensorBoard. 1831 """ 1832 if logs is None: 1833 logs = {} 1834 1835 # Group metrics by the name of their associated file writer. Values 1836 # are lists of metrics, as (name, scalar_value) pairs. 1837 logs_by_writer = { 1838 self._train_run_name: [], 1839 self._validation_run_name: [], 1840 } 1841 validation_prefix = 'val_' 1842 for (name, value) in logs.items(): 1843 if name in ('batch', 'size', 'num_steps'): 1844 # Scrub non-metric items. 1845 continue 1846 if name.startswith(validation_prefix): 1847 name = name[len(validation_prefix):] 1848 writer_name = self._validation_run_name 1849 else: 1850 writer_name = self._train_run_name 1851 name = prefix + name # assign batch or epoch prefix 1852 logs_by_writer[writer_name].append((name, value)) 1853 1854 with context.eager_mode(): 1855 with summary_ops_v2.always_record_summaries(): 1856 for writer_name in logs_by_writer: 1857 these_logs = logs_by_writer[writer_name] 1858 if not these_logs: 1859 # Don't create a "validation" events file if we don't 1860 # actually have any validation data. 1861 continue 1862 writer = self._get_writer(writer_name) 1863 with writer.as_default(): 1864 for (name, value) in these_logs: 1865 summary_ops_v2.scalar(name, value, step=step) 1866 1867 def _log_weights(self, epoch): 1868 """Logs the weights of the Model to TensorBoard.""" 1869 writer = self._get_writer(self._train_run_name) 1870 with context.eager_mode(), \ 1871 writer.as_default(), \ 1872 summary_ops_v2.always_record_summaries(): 1873 for layer in self.model.layers: 1874 for weight in layer.weights: 1875 weight_name = weight.name.replace(':', '_') 1876 with ops.init_scope(): 1877 weight = K.get_value(weight) 1878 summary_ops_v2.histogram(weight_name, weight, step=epoch) 1879 if self.write_images: 1880 self._log_weight_as_image(weight, weight_name, epoch) 1881 writer.flush() 1882 1883 def _log_weight_as_image(self, weight, weight_name, epoch): 1884 """Logs a weight as a TensorBoard image.""" 1885 w_img = array_ops.squeeze(weight) 1886 shape = K.int_shape(w_img) 1887 if len(shape) == 1: # Bias case 1888 w_img = array_ops.reshape(w_img, [1, shape[0], 1, 1]) 1889 elif len(shape) == 2: # Dense layer kernel case 1890 if shape[0] > shape[1]: 1891 w_img = array_ops.transpose(w_img) 1892 shape = K.int_shape(w_img) 1893 w_img = array_ops.reshape(w_img, [1, shape[0], shape[1], 1]) 1894 elif len(shape) == 3: # ConvNet case 1895 if K.image_data_format() == 'channels_last': 1896 # Switch to channels_first to display every kernel as a separate 1897 # image. 1898 w_img = array_ops.transpose(w_img, perm=[2, 0, 1]) 1899 shape = K.int_shape(w_img) 1900 w_img = array_ops.reshape(w_img, [shape[0], shape[1], shape[2], 1]) 1901 1902 shape = K.int_shape(w_img) 1903 # Not possible to handle 3D convnets etc. 1904 if len(shape) == 4 and shape[-1] in [1, 3, 4]: 1905 summary_ops_v2.image(weight_name, w_img, step=epoch) 1906 1907 def _log_embeddings(self, epoch): 1908 embeddings_ckpt = os.path.join(self._log_write_dir, 'train', 1909 'keras_embedding.ckpt-{}'.format(epoch)) 1910 self.model.save_weights(embeddings_ckpt) 1911 1912 1913@keras_export('keras.callbacks.ReduceLROnPlateau') 1914class ReduceLROnPlateau(Callback): 1915 """Reduce learning rate when a metric has stopped improving. 1916 1917 Models often benefit from reducing the learning rate by a factor 1918 of 2-10 once learning stagnates. This callback monitors a 1919 quantity and if no improvement is seen for a 'patience' number 1920 of epochs, the learning rate is reduced. 1921 1922 Example: 1923 1924 ```python 1925 reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, 1926 patience=5, min_lr=0.001) 1927 model.fit(X_train, Y_train, callbacks=[reduce_lr]) 1928 ``` 1929 1930 Arguments: 1931 monitor: quantity to be monitored. 1932 factor: factor by which the learning rate will be reduced. new_lr = lr * 1933 factor 1934 patience: number of epochs with no improvement after which learning rate 1935 will be reduced. 1936 verbose: int. 0: quiet, 1: update messages. 1937 mode: one of {auto, min, max}. In `min` mode, lr will be reduced when the 1938 quantity monitored has stopped decreasing; in `max` mode it will be 1939 reduced when the quantity monitored has stopped increasing; in `auto` 1940 mode, the direction is automatically inferred from the name of the 1941 monitored quantity. 1942 min_delta: threshold for measuring the new optimum, to only focus on 1943 significant changes. 1944 cooldown: number of epochs to wait before resuming normal operation after 1945 lr has been reduced. 1946 min_lr: lower bound on the learning rate. 1947 """ 1948 1949 def __init__(self, 1950 monitor='val_loss', 1951 factor=0.1, 1952 patience=10, 1953 verbose=0, 1954 mode='auto', 1955 min_delta=1e-4, 1956 cooldown=0, 1957 min_lr=0, 1958 **kwargs): 1959 super(ReduceLROnPlateau, self).__init__() 1960 1961 self.monitor = monitor 1962 if factor >= 1.0: 1963 raise ValueError('ReduceLROnPlateau ' 'does not support a factor >= 1.0.') 1964 if 'epsilon' in kwargs: 1965 min_delta = kwargs.pop('epsilon') 1966 logging.warning('`epsilon` argument is deprecated and ' 1967 'will be removed, use `min_delta` instead.') 1968 self.factor = factor 1969 self.min_lr = min_lr 1970 self.min_delta = min_delta 1971 self.patience = patience 1972 self.verbose = verbose 1973 self.cooldown = cooldown 1974 self.cooldown_counter = 0 # Cooldown counter. 1975 self.wait = 0 1976 self.best = 0 1977 self.mode = mode 1978 self.monitor_op = None 1979 self._reset() 1980 1981 def _reset(self): 1982 """Resets wait counter and cooldown counter. 1983 """ 1984 if self.mode not in ['auto', 'min', 'max']: 1985 logging.warning('Learning Rate Plateau Reducing mode %s is unknown, ' 1986 'fallback to auto mode.', self.mode) 1987 self.mode = 'auto' 1988 if (self.mode == 'min' or 1989 (self.mode == 'auto' and 'acc' not in self.monitor)): 1990 self.monitor_op = lambda a, b: np.less(a, b - self.min_delta) 1991 self.best = np.Inf 1992 else: 1993 self.monitor_op = lambda a, b: np.greater(a, b + self.min_delta) 1994 self.best = -np.Inf 1995 self.cooldown_counter = 0 1996 self.wait = 0 1997 1998 def on_train_begin(self, logs=None): 1999 self._reset() 2000 2001 def on_epoch_end(self, epoch, logs=None): 2002 logs = logs or {} 2003 logs['lr'] = K.get_value(self.model.optimizer.lr) 2004 current = logs.get(self.monitor) 2005 if current is None: 2006 logging.warning('Reduce LR on plateau conditioned on metric `%s` ' 2007 'which is not available. Available metrics are: %s', 2008 self.monitor, ','.join(list(logs.keys()))) 2009 2010 else: 2011 if self.in_cooldown(): 2012 self.cooldown_counter -= 1 2013 self.wait = 0 2014 2015 if self.monitor_op(current, self.best): 2016 self.best = current 2017 self.wait = 0 2018 elif not self.in_cooldown(): 2019 self.wait += 1 2020 if self.wait >= self.patience: 2021 old_lr = float(K.get_value(self.model.optimizer.lr)) 2022 if old_lr > self.min_lr: 2023 new_lr = old_lr * self.factor 2024 new_lr = max(new_lr, self.min_lr) 2025 K.set_value(self.model.optimizer.lr, new_lr) 2026 if self.verbose > 0: 2027 print('\nEpoch %05d: ReduceLROnPlateau reducing learning ' 2028 'rate to %s.' % (epoch + 1, new_lr)) 2029 self.cooldown_counter = self.cooldown 2030 self.wait = 0 2031 2032 def in_cooldown(self): 2033 return self.cooldown_counter > 0 2034 2035 2036@keras_export('keras.callbacks.CSVLogger') 2037class CSVLogger(Callback): 2038 """Callback that streams epoch results to a csv file. 2039 2040 Supports all values that can be represented as a string, 2041 including 1D iterables such as np.ndarray. 2042 2043 Example: 2044 2045 ```python 2046 csv_logger = CSVLogger('training.log') 2047 model.fit(X_train, Y_train, callbacks=[csv_logger]) 2048 ``` 2049 2050 Arguments: 2051 filename: filename of the csv file, e.g. 'run/log.csv'. 2052 separator: string used to separate elements in the csv file. 2053 append: True: append if file exists (useful for continuing 2054 training). False: overwrite existing file, 2055 """ 2056 2057 def __init__(self, filename, separator=',', append=False): 2058 self.sep = separator 2059 self.filename = filename 2060 self.append = append 2061 self.writer = None 2062 self.keys = None 2063 self.append_header = True 2064 if six.PY2: 2065 self.file_flags = 'b' 2066 self._open_args = {} 2067 else: 2068 self.file_flags = '' 2069 self._open_args = {'newline': '\n'} 2070 super(CSVLogger, self).__init__() 2071 2072 def on_train_begin(self, logs=None): 2073 if self.append: 2074 if file_io.file_exists(self.filename): 2075 with open(self.filename, 'r' + self.file_flags) as f: 2076 self.append_header = not bool(len(f.readline())) 2077 mode = 'a' 2078 else: 2079 mode = 'w' 2080 self.csv_file = io.open(self.filename, 2081 mode + self.file_flags, 2082 **self._open_args) 2083 2084 def on_epoch_end(self, epoch, logs=None): 2085 logs = logs or {} 2086 2087 def handle_value(k): 2088 is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0 2089 if isinstance(k, six.string_types): 2090 return k 2091 elif isinstance(k, collections_abc.Iterable) and not is_zero_dim_ndarray: 2092 return '"[%s]"' % (', '.join(map(str, k))) 2093 else: 2094 return k 2095 2096 if self.keys is None: 2097 self.keys = sorted(logs.keys()) 2098 2099 if self.model.stop_training: 2100 # We set NA so that csv parsers do not fail for this last epoch. 2101 logs = dict([(k, logs[k]) if k in logs else (k, 'NA') for k in self.keys]) 2102 2103 if not self.writer: 2104 2105 class CustomDialect(csv.excel): 2106 delimiter = self.sep 2107 2108 fieldnames = ['epoch'] + self.keys 2109 if six.PY2: 2110 fieldnames = [unicode(x) for x in fieldnames] 2111 2112 self.writer = csv.DictWriter( 2113 self.csv_file, 2114 fieldnames=fieldnames, 2115 dialect=CustomDialect) 2116 if self.append_header: 2117 self.writer.writeheader() 2118 2119 row_dict = collections.OrderedDict({'epoch': epoch}) 2120 row_dict.update((key, handle_value(logs[key])) for key in self.keys) 2121 self.writer.writerow(row_dict) 2122 self.csv_file.flush() 2123 2124 def on_train_end(self, logs=None): 2125 self.csv_file.close() 2126 self.writer = None 2127 2128 2129@keras_export('keras.callbacks.LambdaCallback') 2130class LambdaCallback(Callback): 2131 r"""Callback for creating simple, custom callbacks on-the-fly. 2132 2133 This callback is constructed with anonymous functions that will be called 2134 at the appropriate time. Note that the callbacks expects positional 2135 arguments, as: 2136 2137 - `on_epoch_begin` and `on_epoch_end` expect two positional arguments: 2138 `epoch`, `logs` 2139 - `on_batch_begin` and `on_batch_end` expect two positional arguments: 2140 `batch`, `logs` 2141 - `on_train_begin` and `on_train_end` expect one positional argument: 2142 `logs` 2143 2144 Arguments: 2145 on_epoch_begin: called at the beginning of every epoch. 2146 on_epoch_end: called at the end of every epoch. 2147 on_batch_begin: called at the beginning of every batch. 2148 on_batch_end: called at the end of every batch. 2149 on_train_begin: called at the beginning of model training. 2150 on_train_end: called at the end of model training. 2151 2152 Example: 2153 2154 ```python 2155 # Print the batch number at the beginning of every batch. 2156 batch_print_callback = LambdaCallback( 2157 on_batch_begin=lambda batch,logs: print(batch)) 2158 2159 # Stream the epoch loss to a file in JSON format. The file content 2160 # is not well-formed JSON but rather has a JSON object per line. 2161 import json 2162 json_log = open('loss_log.json', mode='wt', buffering=1) 2163 json_logging_callback = LambdaCallback( 2164 on_epoch_end=lambda epoch, logs: json_log.write( 2165 json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'), 2166 on_train_end=lambda logs: json_log.close() 2167 ) 2168 2169 # Terminate some processes after having finished model training. 2170 processes = ... 2171 cleanup_callback = LambdaCallback( 2172 on_train_end=lambda logs: [ 2173 p.terminate() for p in processes if p.is_alive()]) 2174 2175 model.fit(..., 2176 callbacks=[batch_print_callback, 2177 json_logging_callback, 2178 cleanup_callback]) 2179 ``` 2180 """ 2181 2182 def __init__(self, 2183 on_epoch_begin=None, 2184 on_epoch_end=None, 2185 on_batch_begin=None, 2186 on_batch_end=None, 2187 on_train_begin=None, 2188 on_train_end=None, 2189 **kwargs): 2190 super(LambdaCallback, self).__init__() 2191 self.__dict__.update(kwargs) 2192 if on_epoch_begin is not None: 2193 self.on_epoch_begin = on_epoch_begin 2194 else: 2195 self.on_epoch_begin = lambda epoch, logs: None 2196 if on_epoch_end is not None: 2197 self.on_epoch_end = on_epoch_end 2198 else: 2199 self.on_epoch_end = lambda epoch, logs: None 2200 if on_batch_begin is not None: 2201 self.on_batch_begin = on_batch_begin 2202 else: 2203 self.on_batch_begin = lambda batch, logs: None 2204 if on_batch_end is not None: 2205 self.on_batch_end = on_batch_end 2206 else: 2207 self.on_batch_end = lambda batch, logs: None 2208 if on_train_begin is not None: 2209 self.on_train_begin = on_train_begin 2210 else: 2211 self.on_train_begin = lambda logs: None 2212 if on_train_end is not None: 2213 self.on_train_end = on_train_end 2214 else: 2215 self.on_train_end = lambda logs: None 2216