# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-import-not-at-top
# pylint: disable=g-classes-have-attributes
"""Callbacks: utilities called at certain points during model training.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import csv
import io
import json
import os
import re
import sys
import time

import numpy as np
import six

from tensorflow.core.framework import summary_pb2
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.distribute import distributed_file_utils
from tensorflow.python.keras.distribute import worker_training_state
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.utils.data_utils import Sequence
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import profiler_v2 as profiler
from tensorflow.python.saved_model import save_options as save_options_lib
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.saving import checkpoint_options as checkpoint_options_lib
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls

try:
  import requests
except ImportError:
  requests = None


# Note: `configure_callbacks` is only used in TF1.
def configure_callbacks(callbacks,
                        model,
                        do_validation=False,
                        batch_size=None,
                        epochs=None,
                        steps_per_epoch=None,
                        samples=None,
                        verbose=1,
                        count_mode='steps',
                        mode=ModeKeys.TRAIN):
  """Configures callbacks for use in various training loops.

  Args:
      callbacks: List of Callbacks.
      model: Model being trained.
      do_validation: Whether or not validation loop will be run.
      batch_size: Number of samples per batch.
      epochs: Number of epochs to train.
      steps_per_epoch: Number of batches to run per training epoch.
      samples: Number of training samples.
      verbose: int, 0 or 1. Keras logging verbosity to pass to ProgbarLogger.
      count_mode: One of 'steps' or 'samples'. Per-batch or per-sample count.
      mode: String. One of ModeKeys.TRAIN, ModeKeys.TEST, or ModeKeys.PREDICT.
        Which loop mode to configure callbacks for.

  Returns:
      Instance of CallbackList used to control all Callbacks.
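
  Example (a minimal, illustrative sketch of the TF1-only usage; assumes a
  compiled `model` and a user-defined `MyCallback` subclass of `Callback`):

  ```python
  callback_list = configure_callbacks(
      [MyCallback()],
      model,
      do_validation=True,
      epochs=5,
      steps_per_epoch=100)
  callback_list.on_train_begin()
  ```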
  """
  # Check if callbacks have already been configured.
  if isinstance(callbacks, CallbackList):
    return callbacks

  if not callbacks:
    callbacks = []

  # Add additional callbacks during training.
  if mode == ModeKeys.TRAIN:
    model.history = History()
    callbacks = [BaseLogger()] + (callbacks or []) + [model.history]
    if verbose:
      callbacks.append(ProgbarLogger(count_mode))
  callback_list = CallbackList(callbacks)

  # Set callback model
  callback_model = model._get_callback_model()  # pylint: disable=protected-access
  callback_list.set_model(callback_model)

  set_callback_parameters(
      callback_list,
      model,
      do_validation=do_validation,
      batch_size=batch_size,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      samples=samples,
      verbose=verbose,
      mode=mode)

  callback_list.model.stop_training = False
  return callback_list


def set_callback_parameters(callback_list,
                            model,
                            do_validation=False,
                            batch_size=None,
                            epochs=None,
                            steps_per_epoch=None,
                            samples=None,
                            verbose=1,
                            mode=ModeKeys.TRAIN):
  """Sets callback parameters.

  Args:
      callback_list: CallbackList instance.
      model: Model being trained.
      do_validation: Whether or not validation loop will be run.
      batch_size: Number of samples per batch.
      epochs: Number of epochs to train.
      steps_per_epoch: Number of batches to run per training epoch.
      samples: Number of training samples.
      verbose: int, 0 or 1. Keras logging verbosity to pass to ProgbarLogger.
      mode: String. One of ModeKeys.TRAIN, ModeKeys.TEST, or ModeKeys.PREDICT.
        Which loop mode to configure callbacks for.
  """
  metric_names = model.metrics_names
  for cbk in callback_list:
    if isinstance(cbk, (BaseLogger, ProgbarLogger)):
      cbk.stateful_metrics = metric_names[1:]  # Exclude `loss`

  # Set callback parameters
  callback_metrics = []
  # When we have deferred build scenario with iterator input, we will compile
  # when we standardize first batch of data.
  if mode != ModeKeys.PREDICT:
    callback_metrics = copy.copy(metric_names)
    if do_validation:
      callback_metrics += ['val_' + n for n in metric_names]
  callback_params = {
      'batch_size': batch_size,
      'epochs': epochs,
      'steps': steps_per_epoch,
      'samples': samples,
      'verbose': verbose,
      'do_validation': do_validation,
      'metrics': callback_metrics,
  }
  callback_list.set_params(callback_params)


def _is_generator_like(data):
  """Checks if data is a generator, Sequence, or Iterator."""
  return (hasattr(data, '__next__') or hasattr(data, 'next') or isinstance(
      data, (Sequence, iterator_ops.Iterator, iterator_ops.IteratorBase)))


def make_logs(model, logs, outputs, mode, prefix=''):
  """Computes logs for sending to `on_batch_end` methods."""
  metric_names = model.metrics_names
  if mode in {ModeKeys.TRAIN, ModeKeys.TEST} and metric_names:
    for label, output in zip(metric_names, outputs):
      logs[prefix + label] = output
  else:
    logs['outputs'] = outputs
  return logs


@keras_export('keras.callbacks.CallbackList')
class CallbackList(object):
  """Container abstracting a list of callbacks."""

  def __init__(self,
               callbacks=None,
               add_history=False,
               add_progbar=False,
               model=None,
               **params):
    """Container for `Callback` instances.

    This object wraps a list of `Callback` instances, making it possible
    to call them all at once via a single endpoint
    (e.g. `callback_list.on_epoch_end(...)`).

    Args:
      callbacks: List of `Callback` instances.
      add_history: Whether a `History` callback should be added, if one does not
        already exist in the `callbacks` list.
      add_progbar: Whether a `ProgbarLogger` callback should be added, if one
        does not already exist in the `callbacks` list.
      model: The `Model` these callbacks are used with.
      **params: If provided, parameters will be passed to each `Callback` via
        `Callback.set_params`.
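
    Example (a minimal, illustrative sketch; assumes an already-built `model`
    and a user-defined `MyCallback` subclass of `tf.keras.callbacks.Callback`):

    ```python
    callbacks = tf.keras.callbacks.CallbackList(
        [MyCallback()], add_history=True, model=model)
    callbacks.on_train_begin()
    callbacks.on_epoch_begin(0)
    ```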
    """
    self.callbacks = nest.flatten(callbacks) if callbacks else []
    self._add_default_callbacks(add_history, add_progbar)

    if model:
      self.set_model(model)
    if params:
      self.set_params(params)

    # Performance optimization: determines if batch hooks need to be called.
    # pylint: disable=protected-access
    self._should_call_train_batch_hooks = any(
        cb._implements_train_batch_hooks() for cb in self.callbacks)
    self._should_call_test_batch_hooks = any(
        cb._implements_test_batch_hooks() for cb in self.callbacks)
    self._should_call_predict_batch_hooks = any(
        cb._implements_predict_batch_hooks() for cb in self.callbacks)
    # pylint: enable=protected-access

    # Performance check: Check batch hooks for slowness compared to batch time.
    # Only run check for custom callbacks (i.e. not present in this file).
    self._check_timing = any([cbk.__class__.__name__ not in globals()
                              for cbk in self.callbacks])
    self._num_batches_for_timing_check = 5
    self._hook_times = {}
    self._batch_start_time = None
    self._batch_times = []

  def _add_default_callbacks(self, add_history, add_progbar):
    """Adds `Callback`s that are always present."""
    self._progbar = None
    self._history = None

    for cb in self.callbacks:
      if isinstance(cb, ProgbarLogger):
        self._progbar = cb
      elif isinstance(cb, History):
        self._history = cb

    if self._progbar is None and add_progbar:
      self._progbar = ProgbarLogger(count_mode='steps')
      self.callbacks.insert(0, self._progbar)

    if self._history is None and add_history:
      self._history = History()
      self.callbacks.append(self._history)

  def append(self, callback):
    self.callbacks.append(callback)

  def set_params(self, params):
    self.params = params
    for callback in self.callbacks:
      callback.set_params(params)

  def set_model(self, model):
    self.model = model
    if self._history:
      model.history = self._history
    for callback in self.callbacks:
      callback.set_model(model)

  def _call_batch_hook(self, mode, hook, batch, logs=None):
    """Helper function for all batch_{begin | end} methods."""
    if not self.callbacks:
      return

    if hook == 'begin':
      self._call_batch_begin_hook(mode, batch, logs)
    elif hook == 'end':
      self._call_batch_end_hook(mode, batch, logs)
    else:
      raise ValueError('Unrecognized hook: {}'.format(hook))

  def _call_batch_begin_hook(self, mode, batch, logs):
    """Helper function for `on_*_batch_begin` methods."""
    hook_name = 'on_{mode}_batch_begin'.format(mode=mode)
    self._call_batch_hook_helper(hook_name, batch, logs)

    if self._check_timing:
      self._batch_start_time = time.time()

  def _call_batch_end_hook(self, mode, batch, logs):
    """Helper function for `on_*_batch_end` methods."""
    hook_name = 'on_{mode}_batch_end'.format(mode=mode)

    if self._check_timing and batch >= 1:
      batch_time = time.time() - self._batch_start_time
      self._batch_times.append(batch_time)

    self._call_batch_hook_helper(hook_name, batch, logs)

    if len(self._batch_times) >= self._num_batches_for_timing_check:
      end_hook_name = hook_name
      begin_hook_name = 'on_{mode}_batch_begin'.format(mode=mode)
      avg_batch_time = sum(self._batch_times) / len(self._batch_times)
      avg_end_hook_time = sum(self._hook_times[end_hook_name]) / len(
          self._hook_times[end_hook_name])
      avg_begin_hook_time = sum(self._hook_times[begin_hook_name]) / len(
          self._hook_times[begin_hook_name])

      threshold_time = 1.0 * avg_batch_time
      warning_msg = ('Callback method `{hook}` is slow compared to '
                     'the batch time (batch time: {batch_time:.4f}s vs '
                     '`{hook}` time: {hook_time:.4f}s). Check your callbacks.')
      if avg_begin_hook_time > threshold_time:
        logging.warning(warning_msg.format(
            hook=begin_hook_name,
            batch_time=avg_batch_time,
            hook_time=avg_begin_hook_time))
      if avg_end_hook_time > threshold_time:
        logging.warning(warning_msg.format(
            hook=end_hook_name,
            batch_time=avg_batch_time,
            hook_time=avg_end_hook_time))
      self._check_timing = False
      self._batch_start_time = None
      self._batch_times = []
      self._hook_times = {}

  def _call_batch_hook_helper(self, hook_name, batch, logs):
    """Helper function for `on_*_batch_*` methods."""
    logs = logs or {}
    numpy_logs = None
    if self._check_timing:
      start_time = time.time()

    for callback in self.callbacks:
      hook = getattr(callback, hook_name)
      if getattr(callback, '_supports_tf_logs', False):
        hook(batch, logs)
      else:
        if numpy_logs is None:  # Only convert once.
          numpy_logs = tf_utils.to_numpy_or_python_type(logs)
        hook(batch, numpy_logs)

    if self._check_timing:
      if hook_name not in self._hook_times:
        self._hook_times[hook_name] = []
      self._hook_times[hook_name].append(time.time() - start_time)

  def _call_begin_hook(self, mode):
    """Helper function for on_{train|test|predict}_begin methods."""
    if mode == ModeKeys.TRAIN:
      self.on_train_begin()
    elif mode == ModeKeys.TEST:
      self.on_test_begin()
    else:
      self.on_predict_begin()

  def _call_end_hook(self, mode):
    """Helper function for on_{train|test|predict}_end methods."""
    if mode == ModeKeys.TRAIN:
      self.on_train_end()
    elif mode == ModeKeys.TEST:
      self.on_test_end()
    else:
      self.on_predict_end()

  def on_batch_begin(self, batch, logs=None):
    if self._should_call_train_batch_hooks:
      self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs)

  def on_batch_end(self, batch, logs=None):
    if self._should_call_train_batch_hooks:
      self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)

  def on_epoch_begin(self, epoch, logs=None):
    """Calls the `on_epoch_begin` methods of its callbacks.

    This function should only be called during TRAIN mode.

    Args:
        epoch: Integer, index of epoch.
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """
    logs = logs or {}
    numpy_logs = None
    for callback in self.callbacks:
      if getattr(callback, '_supports_tf_logs', False):
        callback.on_epoch_begin(epoch, logs)
      else:
        if numpy_logs is None:  # Only convert once.
          numpy_logs = tf_utils.to_numpy_or_python_type(logs)
        callback.on_epoch_begin(epoch, numpy_logs)

  def on_epoch_end(self, epoch, logs=None):
    """Calls the `on_epoch_end` methods of its callbacks.

    This function should only be called during TRAIN mode.

    Args:
        epoch: Integer, index of epoch.
        logs: Dict, metric results for this training epoch, and for the
          validation epoch if validation is performed. Validation result keys
          are prefixed with `val_`.
    """
    logs = logs or {}
    numpy_logs = None
    for callback in self.callbacks:
      if getattr(callback, '_supports_tf_logs', False):
        callback.on_epoch_end(epoch, logs)
      else:
        if numpy_logs is None:  # Only convert once.
          numpy_logs = tf_utils.to_numpy_or_python_type(logs)
        callback.on_epoch_end(epoch, numpy_logs)

  def on_train_batch_begin(self, batch, logs=None):
    """Calls the `on_train_batch_begin` methods of its callbacks.

    Args:
        batch: Integer, index of batch within the current epoch.
        logs: Dict, contains the return value of `model.train_step`. Typically,
          the values of the `Model`'s metrics are returned.  Example:
          `{'loss': 0.2, 'accuracy': 0.7}`.
    """
    if self._should_call_train_batch_hooks:
      self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs)

  def on_train_batch_end(self, batch, logs=None):
    """Calls the `on_train_batch_end` methods of its callbacks.

    Args:
        batch: Integer, index of batch within the current epoch.
        logs: Dict. Aggregated metric results up until this batch.
    """
    if self._should_call_train_batch_hooks:
      self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)

  def on_test_batch_begin(self, batch, logs=None):
    """Calls the `on_test_batch_begin` methods of its callbacks.

    Args:
        batch: Integer, index of batch within the current epoch.
        logs: Dict, contains the return value of `model.test_step`. Typically,
          the values of the `Model`'s metrics are returned.  Example:
          `{'loss': 0.2, 'accuracy': 0.7}`.
    """
    if self._should_call_test_batch_hooks:
      self._call_batch_hook(ModeKeys.TEST, 'begin', batch, logs=logs)

  def on_test_batch_end(self, batch, logs=None):
    """Calls the `on_test_batch_end` methods of its callbacks.

    Args:
        batch: Integer, index of batch within the current epoch.
        logs: Dict. Aggregated metric results up until this batch.
    """
    if self._should_call_test_batch_hooks:
      self._call_batch_hook(ModeKeys.TEST, 'end', batch, logs=logs)

  def on_predict_batch_begin(self, batch, logs=None):
    """Calls the `on_predict_batch_begin` methods of its callbacks.

    Args:
        batch: Integer, index of batch within the current epoch.
        logs: Dict, contains the return value of `model.predict_step`;
          it typically returns a dict with a key 'outputs' containing
          the model's outputs.
    """
    if self._should_call_predict_batch_hooks:
      self._call_batch_hook(ModeKeys.PREDICT, 'begin', batch, logs=logs)

  def on_predict_batch_end(self, batch, logs=None):
    """Calls the `on_predict_batch_end` methods of its callbacks.

    Args:
        batch: Integer, index of batch within the current epoch.
        logs: Dict. Aggregated metric results up until this batch.
    """
    if self._should_call_predict_batch_hooks:
      self._call_batch_hook(ModeKeys.PREDICT, 'end', batch, logs=logs)

  def on_train_begin(self, logs=None):
    """Calls the `on_train_begin` methods of its callbacks.

    Args:
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """
    logs = logs or {}
    numpy_logs = None
    for callback in self.callbacks:
      if getattr(callback, '_supports_tf_logs', False):
        callback.on_train_begin(logs)
      else:
        if numpy_logs is None:  # Only convert once.
          numpy_logs = tf_utils.to_numpy_or_python_type(logs)
        callback.on_train_begin(numpy_logs)

  def on_train_end(self, logs=None):
    """Calls the `on_train_end` methods of its callbacks.

    Args:
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """
    logs = logs or {}
    numpy_logs = None
    for callback in self.callbacks:
      if getattr(callback, '_supports_tf_logs', False):
        callback.on_train_end(logs)
      else:
        if numpy_logs is None:  # Only convert once.
          numpy_logs = tf_utils.to_numpy_or_python_type(logs)
        callback.on_train_end(numpy_logs)

  def on_test_begin(self, logs=None):
    """Calls the `on_test_begin` methods of its callbacks.

    Args:
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """
    logs = logs or {}
    numpy_logs = None
    for callback in self.callbacks:
      if getattr(callback, '_supports_tf_logs', False):
        callback.on_test_begin(logs)
      else:
        if numpy_logs is None:  # Only convert once.
          numpy_logs = tf_utils.to_numpy_or_python_type(logs)
        callback.on_test_begin(numpy_logs)

  def on_test_end(self, logs=None):
    """Calls the `on_test_end` methods of its callbacks.

    Args:
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """
    logs = logs or {}
    numpy_logs = None
    for callback in self.callbacks:
      if getattr(callback, '_supports_tf_logs', False):
        callback.on_test_end(logs)
      else:
        if numpy_logs is None:  # Only convert once.
          numpy_logs = tf_utils.to_numpy_or_python_type(logs)
        callback.on_test_end(numpy_logs)

  def on_predict_begin(self, logs=None):
    """Calls the `on_predict_begin` methods of its callbacks.

    Args:
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """
    logs = logs or {}
    numpy_logs = None
    for callback in self.callbacks:
      if getattr(callback, '_supports_tf_logs', False):
        callback.on_predict_begin(logs)
      else:
        if numpy_logs is None:  # Only convert once.
          numpy_logs = tf_utils.to_numpy_or_python_type(logs)
        callback.on_predict_begin(numpy_logs)

  def on_predict_end(self, logs=None):
    """Calls the `on_predict_end` methods of its callbacks.

    Args:
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """
    logs = logs or {}
    numpy_logs = None
    for callback in self.callbacks:
      if getattr(callback, '_supports_tf_logs', False):
        callback.on_predict_end(logs)
      else:
        if numpy_logs is None:  # Only convert once.
          numpy_logs = tf_utils.to_numpy_or_python_type(logs)
        callback.on_predict_end(numpy_logs)

  def __iter__(self):
    return iter(self.callbacks)


@keras_export('keras.callbacks.Callback')
class Callback(object):
  """Abstract base class used to build new callbacks.

  Attributes:
      params: Dict. Training parameters
          (e.g. verbosity, batch size, number of epochs...).
      model: Instance of `keras.models.Model`.
          Reference of the model being trained.

  The `logs` dictionary that callback methods
  take as argument will contain keys for quantities relevant to
  the current batch or epoch (see method-specific docstrings).
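
  Example (a minimal, illustrative sketch of a custom callback; the class name
  and printed message are arbitrary, and a compiled `model` and a `dataset`
  are assumed):

  ```python
  class LossPrinter(tf.keras.callbacks.Callback):

    def on_epoch_end(self, epoch, logs=None):
      logs = logs or {}
      print('Epoch %d finished, loss=%s' % (epoch, logs.get('loss')))

  model.fit(dataset, epochs=2, callbacks=[LossPrinter()])
  ```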
  """

  def __init__(self):
    self.validation_data = None  # pylint: disable=g-missing-from-attributes
    self.model = None
    # Whether this Callback should only run on the chief worker in a
    # Multi-Worker setting.
    # TODO(omalleyt): Make this attr public once solution is stable.
    self._chief_worker_only = None
    self._supports_tf_logs = False

  def set_params(self, params):
    self.params = params

  def set_model(self, model):
    self.model = model

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_batch_begin(self, batch, logs=None):
    """A backwards compatibility alias for `on_train_batch_begin`."""

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_batch_end(self, batch, logs=None):
    """A backwards compatibility alias for `on_train_batch_end`."""

  @doc_controls.for_subclass_implementers
  def on_epoch_begin(self, epoch, logs=None):
    """Called at the start of an epoch.

    Subclasses should override for any actions to run. This function should only
    be called during TRAIN mode.

    Args:
        epoch: Integer, index of epoch.
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_epoch_end(self, epoch, logs=None):
    """Called at the end of an epoch.

    Subclasses should override for any actions to run. This function should only
    be called during TRAIN mode.

    Args:
        epoch: Integer, index of epoch.
        logs: Dict, metric results for this training epoch, and for the
          validation epoch if validation is performed. Validation result keys
          are prefixed with `val_`. For the training epoch, the values of the
          `Model`'s metrics are returned. Example: `{'loss': 0.2, 'accuracy':
          0.7}`.
    """

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_train_batch_begin(self, batch, logs=None):
    """Called at the beginning of a training batch in `fit` methods.

    Subclasses should override for any actions to run.

    Note that if the `steps_per_execution` argument to `compile` in
    `tf.keras.Model` is set to `N`, this method will only be called every `N`
    batches.

    Args:
        batch: Integer, index of batch within the current epoch.
        logs: Dict, contains the return value of `model.train_step`. Typically,
          the values of the `Model`'s metrics are returned.  Example:
          `{'loss': 0.2, 'accuracy': 0.7}`.
    """
    # For backwards compatibility.
    self.on_batch_begin(batch, logs=logs)

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_train_batch_end(self, batch, logs=None):
    """Called at the end of a training batch in `fit` methods.

    Subclasses should override for any actions to run.

    Note that if the `steps_per_execution` argument to `compile` in
    `tf.keras.Model` is set to `N`, this method will only be called every `N`
    batches.

    Args:
        batch: Integer, index of batch within the current epoch.
        logs: Dict. Aggregated metric results up until this batch.
    """
    # For backwards compatibility.
    self.on_batch_end(batch, logs=logs)

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_test_batch_begin(self, batch, logs=None):
    """Called at the beginning of a batch in `evaluate` methods.

    Also called at the beginning of a validation batch in the `fit`
    methods, if validation data is provided.

    Subclasses should override for any actions to run.

    Note that if the `steps_per_execution` argument to `compile` in
    `tf.keras.Model` is set to `N`, this method will only be called every `N`
    batches.

    Args:
        batch: Integer, index of batch within the current epoch.
        logs: Dict, contains the return value of `model.test_step`. Typically,
          the values of the `Model`'s metrics are returned.  Example:
          `{'loss': 0.2, 'accuracy': 0.7}`.
    """

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_test_batch_end(self, batch, logs=None):
    """Called at the end of a batch in `evaluate` methods.

    Also called at the end of a validation batch in the `fit`
    methods, if validation data is provided.

    Subclasses should override for any actions to run.

    Note that if the `steps_per_execution` argument to `compile` in
    `tf.keras.Model` is set to `N`, this method will only be called every `N`
    batches.

    Args:
        batch: Integer, index of batch within the current epoch.
        logs: Dict. Aggregated metric results up until this batch.
    """

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_predict_batch_begin(self, batch, logs=None):
    """Called at the beginning of a batch in `predict` methods.

    Subclasses should override for any actions to run.

    Note that if the `steps_per_execution` argument to `compile` in
    `tf.keras.Model` is set to `N`, this method will only be called every `N`
    batches.

    Args:
        batch: Integer, index of batch within the current epoch.
        logs: Dict, contains the return value of `model.predict_step`;
          it typically returns a dict with a key 'outputs' containing
          the model's outputs.
    """

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_predict_batch_end(self, batch, logs=None):
    """Called at the end of a batch in `predict` methods.

    Subclasses should override for any actions to run.

    Note that if the `steps_per_execution` argument to `compile` in
    `tf.keras.Model` is set to `N`, this method will only be called every `N`
    batches.

    Args:
        batch: Integer, index of batch within the current epoch.
        logs: Dict. Aggregated metric results up until this batch.
    """

  @doc_controls.for_subclass_implementers
  def on_train_begin(self, logs=None):
    """Called at the beginning of training.

    Subclasses should override for any actions to run.

    Args:
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_train_end(self, logs=None):
    """Called at the end of training.

    Subclasses should override for any actions to run.

    Args:
        logs: Dict. Currently the output of the last call to `on_epoch_end()`
          is passed to this argument for this method but that may change in
          the future.
    """

  @doc_controls.for_subclass_implementers
  def on_test_begin(self, logs=None):
    """Called at the beginning of evaluation or validation.

    Subclasses should override for any actions to run.

    Args:
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_test_end(self, logs=None):
    """Called at the end of evaluation or validation.

    Subclasses should override for any actions to run.

    Args:
        logs: Dict. Currently the output of the last call to
          `on_test_batch_end()` is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_predict_begin(self, logs=None):
    """Called at the beginning of prediction.

    Subclasses should override for any actions to run.

    Args:
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_predict_end(self, logs=None):
    """Called at the end of prediction.

    Subclasses should override for any actions to run.

    Args:
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  def _implements_train_batch_hooks(self):
    """Determines if this Callback should be called for each train batch."""
    return (not generic_utils.is_default(self.on_batch_begin) or
            not generic_utils.is_default(self.on_batch_end) or
            not generic_utils.is_default(self.on_train_batch_begin) or
            not generic_utils.is_default(self.on_train_batch_end))

  def _implements_test_batch_hooks(self):
    """Determines if this Callback should be called for each test batch."""
    return (not generic_utils.is_default(self.on_test_batch_begin) or
            not generic_utils.is_default(self.on_test_batch_end))

  def _implements_predict_batch_hooks(self):
    """Determines if this Callback should be called for each predict batch."""
    return (not generic_utils.is_default(self.on_predict_batch_begin) or
            not generic_utils.is_default(self.on_predict_batch_end))


@keras_export('keras.callbacks.BaseLogger')
class BaseLogger(Callback):
  """Callback that accumulates epoch averages of metrics.

  This callback is automatically applied to every Keras model.

  Args:
      stateful_metrics: Iterable of string names of metrics that
          should *not* be averaged over an epoch.
          Metrics in this list will be logged as-is in `on_epoch_end`.
          All others will be averaged in `on_epoch_end`.
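
  Example (a minimal, illustrative sketch; assumes a compiled `model`,
  training data `x`, `y`, and that `'accuracy'` should be logged as-is rather
  than averaged):

  ```python
  base_logger = tf.keras.callbacks.BaseLogger(stateful_metrics=['accuracy'])
  model.fit(x, y, callbacks=[base_logger])
  ```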
  """

  def __init__(self, stateful_metrics=None):
    super(BaseLogger, self).__init__()
    self.stateful_metrics = set(stateful_metrics or [])

  def on_epoch_begin(self, epoch, logs=None):
    self.seen = 0
    self.totals = {}

  def on_batch_end(self, batch, logs=None):
    logs = logs or {}
    batch_size = logs.get('size', 0)
    # In case of distribution strategy we can potentially run multiple steps
    # at the same time, we should account for that in the `seen` calculation.
    num_steps = logs.get('num_steps', 1)
    self.seen += batch_size * num_steps

    for k, v in logs.items():
      if k in self.stateful_metrics:
        self.totals[k] = v
      else:
        if k in self.totals:
          self.totals[k] += v * batch_size
        else:
          self.totals[k] = v * batch_size

  def on_epoch_end(self, epoch, logs=None):
    if logs is not None:
      for k in self.params['metrics']:
        if k in self.totals:
          # Make value available to next callbacks.
          if k in self.stateful_metrics:
            logs[k] = self.totals[k]
          else:
            logs[k] = self.totals[k] / self.seen


@keras_export('keras.callbacks.TerminateOnNaN')
class TerminateOnNaN(Callback):
  """Callback that terminates training when a NaN loss is encountered.
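
  Example (a minimal usage sketch; assumes a compiled `model` and training
  data `x`, `y`):

  ```python
  model.fit(x, y, callbacks=[tf.keras.callbacks.TerminateOnNaN()])
  ```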
  """

  def __init__(self):
    super(TerminateOnNaN, self).__init__()
    self._supports_tf_logs = True

  def on_batch_end(self, batch, logs=None):
    logs = logs or {}
    loss = logs.get('loss')
    if loss is not None:
      loss = tf_utils.to_numpy_or_python_type(loss)
      if np.isnan(loss) or np.isinf(loss):
        print('Batch %d: Invalid loss, terminating training' % (batch))
        self.model.stop_training = True


@keras_export('keras.callbacks.ProgbarLogger')
class ProgbarLogger(Callback):
  """Callback that prints metrics to stdout.

  Args:
      count_mode: One of `"steps"` or `"samples"`.
          Whether the progress bar should
          count samples seen or steps (batches) seen.
      stateful_metrics: Iterable of string names of metrics that
          should *not* be averaged over an epoch.
          Metrics in this list will be logged as-is.
          All others will be averaged over time (e.g. loss, etc).
          If not provided, defaults to the `Model`'s metrics.

  Raises:
      ValueError: In case of invalid `count_mode`.
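
  Example (a minimal usage sketch; assumes a compiled `model` and training
  data `x`, `y`):

  ```python
  progbar = tf.keras.callbacks.ProgbarLogger(count_mode='steps')
  model.fit(x, y, callbacks=[progbar])
  ```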
  """

  def __init__(self, count_mode='samples', stateful_metrics=None):
    super(ProgbarLogger, self).__init__()
    self._supports_tf_logs = True
    if count_mode == 'samples':
      self.use_steps = False
    elif count_mode == 'steps':
      self.use_steps = True
    else:
      raise ValueError('Unknown `count_mode`: ' + str(count_mode))
    # Defaults to all Model's metrics except for loss.
    self.stateful_metrics = set(stateful_metrics) if stateful_metrics else None

    self.seen = 0
    self.progbar = None
    self.target = None
    self.verbose = 1
    self.epochs = 1

    self._train_step, self._test_step, self._predict_step = None, None, None
    self._call_batch_hooks = True

    self._called_in_fit = False

  def set_params(self, params):
    self.verbose = params['verbose']
    self.epochs = params['epochs']
    if self.use_steps and 'steps' in params:
      self.target = params['steps']
    elif not self.use_steps and 'samples' in params:
      self.target = params['samples']
    else:
      self.target = None  # Will be inferred at the end of the first epoch.

    self._call_batch_hooks = self.verbose == 1
    if self.target is None:
      try:
        self._train_step = self.model._train_counter  # pylint: disable=protected-access
        self._test_step = self.model._test_counter  # pylint: disable=protected-access
        self._predict_step = self.model._predict_counter  # pylint: disable=protected-access
      except AttributeError:
        self._call_batch_hooks = True

  def on_train_begin(self, logs=None):
    # When this logger is called inside `fit`, validation is silent.
    self._called_in_fit = True

  def on_test_begin(self, logs=None):
    if not self._called_in_fit:
      self._reset_progbar()
      self._maybe_init_progbar()

  def on_predict_begin(self, logs=None):
    self._reset_progbar()
    self._maybe_init_progbar()

  def on_epoch_begin(self, epoch, logs=None):
    self._reset_progbar()
    self._maybe_init_progbar()
    if self.verbose and self.epochs > 1:
      print('Epoch %d/%d' % (epoch + 1, self.epochs))

  def on_train_batch_end(self, batch, logs=None):
    self._batch_update_progbar(batch, logs)

  def on_test_batch_end(self, batch, logs=None):
    if not self._called_in_fit:
      self._batch_update_progbar(batch, logs)

  def on_predict_batch_end(self, batch, logs=None):
    # Don't pass prediction results.
    self._batch_update_progbar(batch, None)

  def on_epoch_end(self, epoch, logs=None):
    self._finalize_progbar(logs, self._train_step)

  def on_test_end(self, logs=None):
    if not self._called_in_fit:
      self._finalize_progbar(logs, self._test_step)

  def on_predict_end(self, logs=None):
    self._finalize_progbar(logs, self._predict_step)

  def _reset_progbar(self):
    self.seen = 0
    self.progbar = None

  def _maybe_init_progbar(self):
    if self.stateful_metrics is None:
      if self.model:
        self.stateful_metrics = set(m.name for m in self.model.metrics)
      else:
        self.stateful_metrics = set()

    if self.progbar is None:
      self.progbar = Progbar(
          target=self.target,
          verbose=self.verbose,
          stateful_metrics=self.stateful_metrics,
          unit_name='step' if self.use_steps else 'sample')

  def _implements_train_batch_hooks(self):
    return self._call_batch_hooks

  def _implements_test_batch_hooks(self):
    return self._call_batch_hooks

  def _implements_predict_batch_hooks(self):
    return self._call_batch_hooks

  def _batch_update_progbar(self, batch, logs=None):
    """Updates the progbar."""
    logs = logs or {}
    self._maybe_init_progbar()
    if self.use_steps:
      self.seen = batch + 1  # One-indexed.
    else:
      # v1 path only.
      logs = copy.copy(logs)
      batch_size = logs.pop('size', 0)
      num_steps = logs.pop('num_steps', 1)
      logs.pop('batch', None)
      add_seen = num_steps * batch_size
      self.seen += add_seen

    if self.verbose == 1:
      # Only block async when verbose = 1.
      logs = tf_utils.to_numpy_or_python_type(logs)
      self.progbar.update(self.seen, list(logs.items()), finalize=False)

  def _finalize_progbar(self, logs, counter):
    logs = tf_utils.to_numpy_or_python_type(logs or {})
    if self.target is None:
      if counter is not None:
        counter = counter.numpy()
        if not self.use_steps:
          counter *= logs.get('size', 1)
      self.target = counter or self.seen
      self.progbar.target = self.target
    self.progbar.update(self.target, list(logs.items()), finalize=True)


@keras_export('keras.callbacks.History')
class History(Callback):
  """Callback that records events into a `History` object.

  This callback is automatically applied to
  every Keras model. The `History` object
  gets returned by the `fit` method of models.
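
  Example (a minimal, illustrative sketch; assumes a compiled `model` and
  training data `x`, `y`):

  ```python
  history = model.fit(x, y, epochs=10)
  # `history.history` maps metric names to lists of per-epoch values.
  print(history.history['loss'])
  ```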
  """

  def __init__(self):
    super(History, self).__init__()
    self.history = {}

  def on_train_begin(self, logs=None):
    self.epoch = []

  def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}
    self.epoch.append(epoch)
    for k, v in logs.items():
      self.history.setdefault(k, []).append(v)

    # Set the history attribute on the model after the epoch ends. This will
    # make sure that the state which is set is the latest one.
    self.model.history = self


@keras_export('keras.callbacks.ModelCheckpoint')
class ModelCheckpoint(Callback):
  """Callback to save the Keras model or model weights at some frequency.

  `ModelCheckpoint` callback is used in conjunction with training using
  `model.fit()` to save a model or weights (in a checkpoint file) at some
  interval, so the model or weights can be loaded later to continue the training
  from the state saved.

  A few options this callback provides include:

  - Whether to only keep the model that has achieved the "best performance" so
    far, or whether to save the model at the end of every epoch regardless of
    performance.
  - Definition of 'best'; which quantity to monitor and whether it should be
    maximized or minimized.
  - The frequency it should save at. Currently, the callback supports saving at
    the end of every epoch, or after a fixed number of training batches.
  - Whether only weights are saved, or the whole model is saved.

  Note: If you get `WARNING:tensorflow:Can save best model only with <name>
  available, skipping` see the description of the `monitor` argument for
  details on how to get this right.

  Example:

  ```python
  model.compile(loss=..., optimizer=...,
                metrics=['accuracy'])

  EPOCHS = 10
  checkpoint_filepath = '/tmp/checkpoint'
  model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
      filepath=checkpoint_filepath,
      save_weights_only=True,
      monitor='val_accuracy',
      mode='max',
      save_best_only=True)

  # Model weights are saved at the end of every epoch, if it's the best seen
  # so far.
  model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

  # The model weights (that are considered the best) are loaded into the model.
  model.load_weights(checkpoint_filepath)
  ```
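
  A further illustrative sketch showing batch-frequency saving with an
  `{epoch:02d}` placeholder in the filepath; the directory name and the
  frequency are arbitrary:

  ```python
  batch_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
      filepath='/tmp/ckpt/weights.{epoch:02d}.hdf5',
      save_weights_only=True,
      save_freq=500)  # Save every 500 training batches.

  model.fit(epochs=EPOCHS, callbacks=[batch_checkpoint_callback])
  ```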

  Args:
      filepath: string or `PathLike`, path to save the model file. e.g.
        filepath = os.path.join(working_dir, 'ckpt', file_name). `filepath`
        can contain named formatting options, which will be filled in with the
        value of `epoch` and keys in `logs` (passed in `on_epoch_end`). For
        example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
        then the model
        checkpoints will be saved with the epoch number and the validation loss
        in the filename. The directory of the filepath should not be reused by
        any other callbacks to avoid conflicts.
      monitor: The metric name to monitor. Typically the metrics are set by the
        `Model.compile` method. Note:

        * Prefix the name with `"val_"` to monitor validation metrics.
        * Use `"loss"` or `"val_loss"` to monitor the model's total loss.
        * If you specify metrics as strings, like `"accuracy"`, pass the same
          string (with or without the `"val_"` prefix).
        * If you pass `metrics.Metric` objects, `monitor` should be set to
          `metric.name`
        * If you're not sure about the metric names you can check the contents
          of the `history.history` dictionary returned by
          `history = model.fit()`
        * Multi-output models set additional prefixes on the metric names.

      verbose: verbosity mode, 0 or 1.
      save_best_only: if `save_best_only=True`, it only saves when the model
        is considered the "best" and the latest best model according to the
        quantity monitored will not be overwritten. If `filepath` doesn't
        contain formatting options like `{epoch}` then `filepath` will be
        overwritten by each new better model.
      mode: one of {'auto', 'min', 'max'}. If `save_best_only=True`, the
        decision to overwrite the current save file is made based on either
        the maximization or the minimization of the monitored quantity.
        For `val_acc`, this should be `max`, for `val_loss` this should be
        `min`, etc. In `auto` mode, the direction is automatically inferred
        from the name of the monitored quantity.
      save_weights_only: if True, then only the model's weights will be saved
        (`model.save_weights(filepath)`), else the full model is saved
        (`model.save(filepath)`).
      save_freq: `'epoch'` or integer. When using `'epoch'`, the callback saves
        the model after each epoch. When using integer, the callback saves the
        model at end of this many batches. If the `Model` is compiled with
        `steps_per_execution=N`, then the saving criteria will be
        checked every Nth batch. Note that if the saving isn't aligned to
        epochs, the monitored metric may potentially be less reliable (it
        could reflect as little as 1 batch, since the metrics get reset every
        epoch). Defaults to `'epoch'`.
      options: Optional `tf.train.CheckpointOptions` object if
        `save_weights_only` is true or optional `tf.saved_model.SaveOptions`
        object if `save_weights_only` is false.
      **kwargs: Additional arguments for backwards compatibility. Possible key
        is `period`.
  """

  def __init__(self,
               filepath,
               monitor='val_loss',
               verbose=0,
               save_best_only=False,
               save_weights_only=False,
               mode='auto',
               save_freq='epoch',
               options=None,
               **kwargs):
    super(ModelCheckpoint, self).__init__()
    self._supports_tf_logs = True
    self.monitor = monitor
    self.verbose = verbose
    self.filepath = path_to_string(filepath)
    self.save_best_only = save_best_only
    self.save_weights_only = save_weights_only
    self.save_freq = save_freq
    self.epochs_since_last_save = 0
    self._batches_seen_since_last_saving = 0
    self._last_batch_seen = 0

    if save_weights_only:
      if options is None or isinstance(
          options, checkpoint_options_lib.CheckpointOptions):
        self._options = options or checkpoint_options_lib.CheckpointOptions()
      else:
        raise TypeError('If save_weights_only is True, then `options` must be '
                        'either None or a tf.train.CheckpointOptions')
    else:
      if options is None or isinstance(options, save_options_lib.SaveOptions):
        self._options = options or save_options_lib.SaveOptions()
      else:
        raise TypeError('If save_weights_only is False, then `options` must be '
                        'either None or a tf.saved_model.SaveOptions')

    # Deprecated field `load_weights_on_restart` is for loading the checkpoint
    # file from `filepath` at the start of `model.fit()`
    # TODO(rchao): Remove the arg during next breaking release.
    if 'load_weights_on_restart' in kwargs:
      self.load_weights_on_restart = kwargs['load_weights_on_restart']
      logging.warning('`load_weights_on_restart` argument is deprecated. '
                      'Please use `model.load_weights()` for loading weights '
                      'before the start of `model.fit()`.')
    else:
      self.load_weights_on_restart = False

    # Deprecated field `period` is for the number of epochs between which
    # the model is saved.
    if 'period' in kwargs:
      self.period = kwargs['period']
      logging.warning('`period` argument is deprecated. Please use `save_freq` '
                      'to specify the frequency in number of batches seen.')
    else:
      self.period = 1

    if mode not in ['auto', 'min', 'max']:
      logging.warning('ModelCheckpoint mode %s is unknown, '
                      'fallback to auto mode.', mode)
      mode = 'auto'

    if mode == 'min':
      self.monitor_op = np.less
      self.best = np.Inf
    elif mode == 'max':
      self.monitor_op = np.greater
      self.best = -np.Inf
    else:
      if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
        self.monitor_op = np.greater
        self.best = -np.Inf
      else:
        self.monitor_op = np.less
        self.best = np.Inf

    if self.save_freq != 'epoch' and not isinstance(self.save_freq, int):
      raise ValueError('Unrecognized save_freq: {}'.format(self.save_freq))

    # Only the chief worker writes model checkpoints, but all workers
    # restore checkpoint at on_train_begin().
    self._chief_worker_only = False

  def set_model(self, model):
    self.model = model
    # Use name matching rather than `isinstance` to avoid circular dependencies.
    if (not self.save_weights_only and
        not model._is_graph_network and  # pylint: disable=protected-access
        model.__class__.__name__ != 'Sequential'):
      self.save_weights_only = True

  def on_train_begin(self, logs=None):
    if self.load_weights_on_restart:
      filepath_to_load = (
          self._get_most_recently_modified_file_matching_pattern(self.filepath))
      if (filepath_to_load is not None and
          self._checkpoint_exists(filepath_to_load)):
        try:
          # `filepath` may contain placeholders such as `{epoch:02d}`, and
          # thus it attempts to load the most recently modified file with file
          # name matching the pattern.
          self.model.load_weights(filepath_to_load)
        except (IOError, ValueError) as e:
          raise ValueError('Error loading file from {}. Reason: {}'.format(
              filepath_to_load, e))

  def _implements_train_batch_hooks(self):
    # Only call batch hooks when saving on batch
    return self.save_freq != 'epoch'

  def on_train_batch_end(self, batch, logs=None):
    if self._should_save_on_batch(batch):
      self._save_model(epoch=self._current_epoch, logs=logs)

  def on_epoch_begin(self, epoch, logs=None):
    self._current_epoch = epoch

  def on_epoch_end(self, epoch, logs=None):
    self.epochs_since_last_save += 1
    # pylint: disable=protected-access
    if self.save_freq == 'epoch':
      self._save_model(epoch=epoch, logs=logs)

  def _should_save_on_batch(self, batch):
    """Handles batch-level saving logic, supports steps_per_execution."""
    if self.save_freq == 'epoch':
      return False

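    # With `steps_per_execution > 1`, this hook can be called with batch
    # indices that skip values (e.g. 0, N, 2N, ...); an index less than or
    # equal to the last one seen signals a new epoch. Count the batches run
    # since the last save accordingly and save once `save_freq` batches have
    # accumulated.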
1357    if batch <= self._last_batch_seen:  # New epoch.
1358      add_batches = batch + 1  # batches are zero-indexed.
1359    else:
1360      add_batches = batch - self._last_batch_seen
1361    self._batches_seen_since_last_saving += add_batches
1362    self._last_batch_seen = batch
1363
1364    if self._batches_seen_since_last_saving >= self.save_freq:
1365      self._batches_seen_since_last_saving = 0
1366      return True
1367    return False
1368
1369  def _save_model(self, epoch, logs):
1370    """Saves the model.
1371
1372    Args:
1373        epoch: the epoch this iteration is in.
1374        logs: the `logs` dict passed in to `on_batch_end` or `on_epoch_end`.
1375    """
1376    logs = logs or {}
1377
1378    if isinstance(self.save_freq,
1379                  int) or self.epochs_since_last_save >= self.period:
1380      # Block only when saving interval is reached.
1381      logs = tf_utils.to_numpy_or_python_type(logs)
1382      self.epochs_since_last_save = 0
1383      filepath = self._get_file_path(epoch, logs)
1384
1385      try:
1386        if self.save_best_only:
1387          current = logs.get(self.monitor)
1388          if current is None:
1389            logging.warning('Can save best model only with %s available, '
1390                            'skipping.', self.monitor)
1391          else:
1392            if self.monitor_op(current, self.best):
1393              if self.verbose > 0:
1394                print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
1395                      ' saving model to %s' % (epoch + 1, self.monitor,
1396                                               self.best, current, filepath))
1397              self.best = current
1398              if self.save_weights_only:
1399                self.model.save_weights(
1400                    filepath, overwrite=True, options=self._options)
1401              else:
1402                self.model.save(filepath, overwrite=True, options=self._options)
1403            else:
1404              if self.verbose > 0:
1405                print('\nEpoch %05d: %s did not improve from %0.5f' %
1406                      (epoch + 1, self.monitor, self.best))
1407        else:
1408          if self.verbose > 0:
1409            print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
1410          if self.save_weights_only:
1411            self.model.save_weights(
1412                filepath, overwrite=True, options=self._options)
1413          else:
1414            self.model.save(filepath, overwrite=True, options=self._options)
1415
1416        self._maybe_remove_file()
1417      except IOError as e:
1418        # `e.errno` appears to be `None` so checking the content of `e.args[0]`.
1419        if 'is a directory' in six.ensure_str(e.args[0]).lower():
1420          raise IOError('Please specify a non-directory filepath for '
1421                        'ModelCheckpoint. Filepath used is an existing '
1422                        'directory: {}'.format(filepath))
1423        # Re-throw the error for any other causes.
1424        raise e
1425
1426  def _get_file_path(self, epoch, logs):
1427    """Returns the file path for checkpoint."""
1428    # pylint: disable=protected-access
1429    try:
1430      # `filepath` may contain placeholders such as `{epoch:02d}` and
1431      # `{mape:.2f}`. A mismatch between logged metrics and the path's
1432      # placeholders can cause formatting to fail.
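      # For illustration (hypothetical values): a `filepath` of
      # 'model.{epoch:02d}-{val_loss:.2f}.h5' with `epoch + 1 == 5` and
      # `logs == {'val_loss': 0.25}` formats to 'model.05-0.25.h5'.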
1433      file_path = self.filepath.format(epoch=epoch + 1, **logs)
1434    except KeyError as e:
1435      raise KeyError('Failed to format this callback filepath: "{}". '
1436                     'Reason: {}'.format(self.filepath, e))
1437    self._write_filepath = distributed_file_utils.write_filepath(
1438        file_path, self.model.distribute_strategy)
1439    return self._write_filepath
1440
1441  def _maybe_remove_file(self):
1442    # Remove the checkpoint directory in multi-worker training where this worker
1443    # should not checkpoint. It is a dummy directory previously saved for sync
1444    # distributed training.
1445    distributed_file_utils.remove_temp_dir_with_filepath(
1446        self._write_filepath, self.model.distribute_strategy)
1447
1448  def _checkpoint_exists(self, filepath):
1449    """Returns whether the checkpoint `filepath` refers to exists."""
1450    if filepath.endswith('.h5'):
1451      return file_io.file_exists_v2(filepath)
1452    tf_saved_model_exists = file_io.file_exists_v2(filepath)
1453    tf_weights_only_checkpoint_exists = file_io.file_exists_v2(
1454        filepath + '.index')
1455    return tf_saved_model_exists or tf_weights_only_checkpoint_exists
1456
1457  def _get_most_recently_modified_file_matching_pattern(self, pattern):
1458    """Returns the most recently modified filepath matching pattern.
1459
1460    The pattern may contain Python formatting placeholders. If
1461    `tf.train.latest_checkpoint()` does not return None, use that; otherwise,
1462    fall back to the most recently modified file that matches the pattern.
1463
1464    In the rare case where more than one pattern-matching file shares the
1465    most recent modification time, return the filepath that compares largest
1466    (by the `>` operator, i.e. lexicographically on the formatted names).
1467    This provides a tie-breaker when multiple files are equally recent. Note
1468    that a larger `filepath` can sometimes indicate a later modification time
1469    (for instance, when epoch/batch is used as a formatting option), but not
1470    necessarily (when accuracy or loss is used). The tie-breaker is a best
1471    effort to return the most recent file and to avoid nondeterministic
1472    results.
1473
1474    Modified time of a file is obtained with `os.path.getmtime()`.
1475
1476    This utility function is best demonstrated via an example:
1477
1478    ```python
1479    file_pattern = 'f.batch{batch:02d}epoch{epoch:02d}.h5'
1480    test_dir = self.get_temp_dir()
1481    path_pattern = os.path.join(test_dir, file_pattern)
1482    file_paths = [
1483        os.path.join(test_dir, file_name) for file_name in
1484        ['f.batch03epoch02.h5', 'f.batch02epoch02.h5', 'f.batch01epoch01.h5']
1485    ]
1486    for file_path in file_paths:
1487      # Write something to each of the files
1488    self.assertEqual(
1489        _get_most_recently_modified_file_matching_pattern(path_pattern),
1490        file_paths[-1])
1491    ```
1492
1493    Args:
1494        pattern: The file pattern that may optionally contain a Python
1495            placeholder such as `{epoch:02d}`.
1496
1497    Returns:
1498        The most recently modified file's full filepath matching `pattern`. If
1499        `pattern` does not contain any placeholder, this returns the filepath
1500        that exactly matches `pattern`. Returns `None` if no match is found.
1502    """
1503    dir_name = os.path.dirname(pattern)
1504    base_name = os.path.basename(pattern)
1505    base_name_regex = '^' + re.sub(r'{.*}', r'.*', base_name) + '$'
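    # Note: `{.*}` is greedy, so a pattern such as
    # 'f.batch{batch:02d}epoch{epoch:02d}.h5' collapses everything between the
    # first '{' and the last '}' into a single '.*', yielding '^f.batch.*.h5$'.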
1506
1507    # If tf.train.latest_checkpoint tells us there exists a latest checkpoint,
1508    # use that as it is more robust than `os.path.getmtime()`.
1509    latest_tf_checkpoint = checkpoint_management.latest_checkpoint(dir_name)
1510    if latest_tf_checkpoint is not None and re.match(
1511        base_name_regex, os.path.basename(latest_tf_checkpoint)):
1512      return latest_tf_checkpoint
1513
1514    latest_mod_time = 0
1515    file_path_with_latest_mod_time = None
1516    n_file_with_latest_mod_time = 0
1517    file_path_with_largest_file_name = None
1518
1519    if file_io.file_exists_v2(dir_name):
1520      for file_name in os.listdir(dir_name):
1521        # Only consider if `file_name` matches the pattern.
1522        if re.match(base_name_regex, file_name):
1523          file_path = os.path.join(dir_name, file_name)
1524          mod_time = os.path.getmtime(file_path)
1525          if (file_path_with_largest_file_name is None or
1526              file_path > file_path_with_largest_file_name):
1527            file_path_with_largest_file_name = file_path
1528          if mod_time > latest_mod_time:
1529            latest_mod_time = mod_time
1530            file_path_with_latest_mod_time = file_path
1531            # In the case a file with later modified time is found, reset
1532            # the counter for the number of files with latest modified time.
1533            n_file_with_latest_mod_time = 1
1534          elif mod_time == latest_mod_time:
1535            # In the case a file has modified time tied with the most recent,
1536            # increment the counter for the number of files with latest modified
1537            # time by 1.
1538            n_file_with_latest_mod_time += 1
1539
1540    if n_file_with_latest_mod_time == 1:
1541      # Return the sole file that has most recent modified time.
1542      return file_path_with_latest_mod_time
1543    else:
1544      # If there are more than one file having latest modified time, return
1545      # the file path with the largest file name.
1546      return file_path_with_largest_file_name
1547
1548
1549@keras_export('keras.callbacks.experimental.BackupAndRestore', v1=[])
1550class BackupAndRestore(Callback):
1551  """Callback to back up and restore the training state.
1552
1553  The `BackupAndRestore` callback is intended to recover from interruptions
1554  that happen in the middle of a `model.fit` execution, by backing up the
1555  training state in a temporary checkpoint file (based on TF CheckpointManager)
1556  at the end of each epoch. If training is restarted before completion, the
1557  training state and the model are restored to the most recently saved state
1558  at the beginning of a new `model.fit` run. Note that the user is responsible
1559  for bringing the jobs back up. This callback provides the backup and restore
1560  mechanism used for fault tolerance. The model restored from a previous
1561  checkpoint is expected to be the same as the one used to back it up; if the
1562  user changes the arguments passed to `compile` or `fit`, the checkpoint saved
1563  for fault tolerance can become invalid.
1565
1566  Note:
1567  1. This callback is not compatible with eager execution being disabled.
1568  2. A checkpoint is saved at the end of each epoch. When restoring, any partial
1569  work from the unfinished epoch in which training was interrupted is redone,
1570  so the work done before the interruption doesn't affect the final model state.
1571  3. This works in both single-worker and multi-worker modes; only
1572  MirroredStrategy and MultiWorkerMirroredStrategy are supported for now.
1573
1574  Example:
1575
1576  >>> class InterruptingCallback(tf.keras.callbacks.Callback):
1577  ...   def on_epoch_begin(self, epoch, logs=None):
1578  ...     if epoch == 4:
1579  ...       raise RuntimeError('Interrupting!')
1580  >>> callback = tf.keras.callbacks.experimental.BackupAndRestore(
1581  ... backup_dir="/tmp/backup")
1582  >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
1583  >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
1584  >>> try:
1585  ...   model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10,
1586  ...             batch_size=1, callbacks=[callback, InterruptingCallback()],
1587  ...             verbose=0)
1588  ... except:
1589  ...   pass
1590  >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10,
1591  ...             batch_size=1, callbacks=[callback], verbose=0)
1592  >>> # Only 6 more epochs are run, since the first run was interrupted at
1593  >>> # zero-indexed epoch 4; the second run continues from epoch 4 to 9.
1594  >>> len(history.history['loss'])
1595  6
1596
1597  Args:
1598      backup_dir: String, path to store the checkpoint.
1599        e.g. backup_dir = os.path.join(working_dir, 'backup')
1600        This is the directory in which the system stores temporary files to
1601        recover the model from jobs terminated unexpectedly. The directory
1602        cannot be reused elsewhere to store other files, e.g. by the
1603        BackupAndRestore callback of another training run, or by another
1604        callback (such as ModelCheckpoint) of the same training run.
1605  """
1606
1607  def __init__(self, backup_dir):
1608    super(BackupAndRestore, self).__init__()
1609    self.backup_dir = backup_dir
1610    self._supports_tf_logs = True
1611    self._supported_strategies = (
1612        mirrored_strategy.MirroredStrategy,
1613        collective_all_reduce_strategy.CollectiveAllReduceStrategy,
1614        tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV2)
1615
1616    if not context.executing_eagerly():
1617      if ops.inside_function():
1618        raise ValueError('This Callback\'s method contains Python state and '
1619                         'should be called outside of `tf.function`s.')
1620      else:  # Legacy graph mode:
1621        raise ValueError(
1622            'BackupAndRestore only supports eager mode. In graph '
1623            'mode, consider using ModelCheckpoint to manually save '
1624            'and restore weights with `model.load_weights()` and by '
1625            'providing `initial_epoch` in `model.fit()` for fault tolerance.')
1626
1627    # Only the chief worker writes model checkpoints, but all workers
1628    # restore checkpoint at on_train_begin().
1629    self._chief_worker_only = False
1630
1631  def on_train_begin(self, logs=None):
1632    # TrainingState is used to manage the training state needed for
1633    # failure-recovery of a worker in training.
1634    # pylint: disable=protected-access
1635
1636    if self.model._distribution_strategy and not isinstance(
1637        self.model.distribute_strategy, self._supported_strategies):
1638      raise NotImplementedError(
1639          '%s is not supported yet. '
1640          'Currently BackupAndRestore callback only supports empty strategy, '
1641          'MirroredStrategy, MultiWorkerMirroredStrategy and TPUStrategy.' %
1642          type(self.model.distribute_strategy).__name__)
1643    self.model._training_state = (
1644        worker_training_state.WorkerTrainingState(self.model, self.backup_dir))
1645    self._training_state = self.model._training_state
1646    self._training_state.restore()
1647
1648  def on_train_end(self, logs=None):
1649    # pylint: disable=protected-access
1650    # On exit of training, delete the training state backup file that was saved
1651    # for the purpose of worker recovery.
1652    self._training_state.delete_backup()
1653
1654    # Clean up the training state.
1655    del self._training_state
1656    del self.model._training_state
1657
1658  def on_epoch_end(self, epoch, logs=None):
1659    # Back up the model and current epoch for possible future recovery.
1660    self._training_state.back_up(epoch)
1661
1662
1663@keras_export('keras.callbacks.EarlyStopping')
1664class EarlyStopping(Callback):
1665  """Stop training when a monitored metric has stopped improving.
1666
1667  Assume the goal of training is to minimize the loss. With this, the
1668  metric to be monitored would be `'loss'`, and the mode would be `'min'`. A
1669  `model.fit()` training loop checks at the end of every epoch whether
1670  the loss is still decreasing, taking the `min_delta` and `patience`
1671  settings into account. Once the loss is found to be no longer decreasing,
1672  `model.stop_training` is set to True and training terminates.
1673
1674  The quantity to be monitored needs to be available in the `logs` dict.
1675  To make it so, pass the loss or metrics to `model.compile()`.
1676
1677  Args:
1678    monitor: Quantity to be monitored.
1679    min_delta: Minimum change in the monitored quantity
1680        to qualify as an improvement, i.e. an absolute
1681        change of less than min_delta will count as no
1682        improvement.
1683    patience: Number of epochs with no improvement
1684        after which training will be stopped.
1685    verbose: verbosity mode.
1686    mode: One of `{"auto", "min", "max"}`. In `min` mode,
1687        training will stop when the quantity
1688        monitored has stopped decreasing; in `"max"`
1689        mode it will stop when the quantity
1690        monitored has stopped increasing; in `"auto"`
1691        mode, the direction is automatically inferred
1692        from the name of the monitored quantity.
1693    baseline: Baseline value for the monitored quantity.
1694        Training will stop if the model doesn't show improvement over the
1695        baseline.
1696    restore_best_weights: Whether to restore model weights from
1697        the epoch with the best value of the monitored quantity.
1698        If False, the model weights obtained at the last step of
1699        training are used. An epoch will be restored regardless
1700        of the performance relative to the `baseline`. If no epoch
1701        improves on `baseline`, training will run for `patience`
1702        epochs and restore weights from the best epoch in that set.
1703
1704  Example:
1705
1706  >>> callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
1707  >>> # This callback will stop the training when there is no improvement in
1708  >>> # the loss for three consecutive epochs.
1709  >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
1710  >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
1711  >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
1712  ...                     epochs=10, batch_size=1, callbacks=[callback],
1713  ...                     verbose=0)
1714  >>> len(history.history['loss'])  # Only 4 epochs are run.
1715  4
1716  """
1717
1718  def __init__(self,
1719               monitor='val_loss',
1720               min_delta=0,
1721               patience=0,
1722               verbose=0,
1723               mode='auto',
1724               baseline=None,
1725               restore_best_weights=False):
1726    super(EarlyStopping, self).__init__()
1727
1728    self.monitor = monitor
1729    self.patience = patience
1730    self.verbose = verbose
1731    self.baseline = baseline
1732    self.min_delta = abs(min_delta)
1733    self.wait = 0
1734    self.stopped_epoch = 0
1735    self.restore_best_weights = restore_best_weights
1736    self.best_weights = None
1737
1738    if mode not in ['auto', 'min', 'max']:
1739      logging.warning('EarlyStopping mode %s is unknown, '
1740                      'fallback to auto mode.', mode)
1741      mode = 'auto'
1742
1743    if mode == 'min':
1744      self.monitor_op = np.less
1745    elif mode == 'max':
1746      self.monitor_op = np.greater
1747    else:
1748      if 'acc' in self.monitor:
1749        self.monitor_op = np.greater
1750      else:
1751        self.monitor_op = np.less
1752
1753    if self.monitor_op == np.greater:
1754      self.min_delta *= 1
1755    else:
1756      self.min_delta *= -1
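    # The sign flip lets `_is_improvement` uniformly compute
    # `monitor_value - self.min_delta`: when minimizing, the metric must drop
    # below the reference by more than `min_delta` to count as an improvement;
    # when maximizing, it must exceed the reference by more than `min_delta`.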
1757
1758  def on_train_begin(self, logs=None):
1759    # Allow instances to be re-used
1760    self.wait = 0
1761    self.stopped_epoch = 0
1762    self.best = np.Inf if self.monitor_op == np.less else -np.Inf
1763    self.best_weights = None
1764
1765  def on_epoch_end(self, epoch, logs=None):
1766    current = self.get_monitor_value(logs)
1767    if current is None:
1768      return
1769    if self.restore_best_weights and self.best_weights is None:
1770      # Keep the first epoch's weights in case no epoch ever improves on them.
1771      self.best_weights = self.model.get_weights()
1772
1773    self.wait += 1
1774    if self._is_improvement(current, self.best):
1775      self.best = current
1776      if self.restore_best_weights:
1777        self.best_weights = self.model.get_weights()
1778      # Only restart wait if we beat both the baseline and our previous best.
1779      if self.baseline is None or self._is_improvement(current, self.baseline):
1780        self.wait = 0
1781
1782    if self.wait >= self.patience:
1783      self.stopped_epoch = epoch
1784      self.model.stop_training = True
1785      if self.restore_best_weights and self.best_weights is not None:
1786        if self.verbose > 0:
1787          print('Restoring model weights from the end of the best epoch.')
1788        self.model.set_weights(self.best_weights)
1789
1790  def on_train_end(self, logs=None):
1791    if self.stopped_epoch > 0 and self.verbose > 0:
1792      print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))
1793
1794  def get_monitor_value(self, logs):
1795    logs = logs or {}
1796    monitor_value = logs.get(self.monitor)
1797    if monitor_value is None:
1798      logging.warning('Early stopping conditioned on metric `%s` '
1799                      'which is not available. Available metrics are: %s',
1800                      self.monitor, ','.join(list(logs.keys())))
1801    return monitor_value
1802
1803  def _is_improvement(self, monitor_value, reference_value):
1804    return self.monitor_op(monitor_value - self.min_delta, reference_value)
1805
1806
1807@keras_export('keras.callbacks.RemoteMonitor')
1808class RemoteMonitor(Callback):
1809  """Callback used to stream events to a server.
1810
1811  Requires the `requests` library.
1812  Events are sent to `root + '/publish/epoch/end/'` by default. Calls are
1813  HTTP POST, with a `data` argument which is a
1814  JSON-encoded dictionary of event data.
1815  If `send_as_json=True`, the content type of the request will be
1816  `"application/json"`.
1817  Otherwise the serialized JSON will be sent within a form.
1818
1819  Args:
1820    root: String; root url of the target server.
1821    path: String; path relative to `root` to which the events will be sent.
1822    field: String; JSON field under which the data will be stored.
1823        The field is used only if the payload is sent within a form
1824        (i.e. send_as_json is set to False).
1825    headers: Dictionary; optional custom HTTP headers.
1826    send_as_json: Boolean; whether the request should be
1827        sent as `"application/json"`.
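
  Example (a minimal sketch; assumes a server you run yourself is listening at
  `root`, and that `model`, `x_train` and `y_train` already exist):

  ```python
  monitor = tf.keras.callbacks.RemoteMonitor(
      root='http://localhost:9000',
      path='/publish/epoch/end/',
      field='data',
      send_as_json=True)
  model.fit(x_train, y_train, epochs=2, callbacks=[monitor])
  ```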
1828  """
1829
1830  def __init__(self,
1831               root='http://localhost:9000',
1832               path='/publish/epoch/end/',
1833               field='data',
1834               headers=None,
1835               send_as_json=False):
1836    super(RemoteMonitor, self).__init__()
1837
1838    self.root = root
1839    self.path = path
1840    self.field = field
1841    self.headers = headers
1842    self.send_as_json = send_as_json
1843
1844  def on_epoch_end(self, epoch, logs=None):
1845    if requests is None:
1846      raise ImportError('RemoteMonitor requires the `requests` library.')
1847    logs = logs or {}
1848    send = {}
1849    send['epoch'] = epoch
1850    for k, v in logs.items():
1851      # np.ndarray and np.generic are not scalar types
1852      # therefore we must unwrap their scalar values and
1853      # pass to the json-serializable dict 'send'
1854      if isinstance(v, (np.ndarray, np.generic)):
1855        send[k] = v.item()
1856      else:
1857        send[k] = v
1858    try:
1859      if self.send_as_json:
1860        requests.post(self.root + self.path, json=send, headers=self.headers)
1861      else:
1862        requests.post(
1863            self.root + self.path, {self.field: json.dumps(send)},
1864            headers=self.headers)
1865    except requests.exceptions.RequestException:
1866      logging.warning('Warning: could not reach RemoteMonitor '
1867                      'root server at ' + str(self.root))
1868
1869
1870@keras_export('keras.callbacks.LearningRateScheduler')
1871class LearningRateScheduler(Callback):
1872  """Learning rate scheduler.
1873
1874  At the beginning of every epoch, this callback gets the updated learning
1875  rate value from the `schedule` function provided at `__init__`, passing in
1876  the current epoch and current learning rate, and applies the updated
1877  learning rate to the optimizer.
1878
1879  Args:
1880    schedule: a function that takes an epoch index (integer, indexed from 0)
1881        and current learning rate (float) as inputs and returns a new
1882        learning rate as output (float).
1883    verbose: int. 0: quiet, 1: update messages.
1884
1885  Example:
1886
1887  >>> # This function keeps the initial learning rate for the first ten epochs
1888  >>> # and decreases it exponentially after that.
1889  >>> def scheduler(epoch, lr):
1890  ...   if epoch < 10:
1891  ...     return lr
1892  ...   else:
1893  ...     return lr * tf.math.exp(-0.1)
1894  >>>
1895  >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
1896  >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
1897  >>> round(model.optimizer.lr.numpy(), 5)
1898  0.01
1899
1900  >>> callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
1901  >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
1902  ...                     epochs=15, callbacks=[callback], verbose=0)
1903  >>> round(model.optimizer.lr.numpy(), 5)
1904  0.00607
1905
1906  """
1907
1908  def __init__(self, schedule, verbose=0):
1909    super(LearningRateScheduler, self).__init__()
1910    self.schedule = schedule
1911    self.verbose = verbose
1912
1913  def on_epoch_begin(self, epoch, logs=None):
1914    if not hasattr(self.model.optimizer, 'lr'):
1915      raise ValueError('Optimizer must have a "lr" attribute.')
1916    try:  # new API
1917      lr = float(K.get_value(self.model.optimizer.lr))
1918      lr = self.schedule(epoch, lr)
1919    except TypeError:  # Support for old API for backward compatibility
1920      lr = self.schedule(epoch)
1921    if not isinstance(lr, (ops.Tensor, float, np.float32, np.float64)):
1922      raise ValueError('The output of the "schedule" function '
1923                       'should be float.')
1924    if isinstance(lr, ops.Tensor) and not lr.dtype.is_floating:
1925      raise ValueError('The dtype of Tensor should be float')
1926    K.set_value(self.model.optimizer.lr, K.get_value(lr))
1927    if self.verbose > 0:
1928      print('\nEpoch %05d: LearningRateScheduler reducing learning '
1929            'rate to %s.' % (epoch + 1, lr))
1930
1931  def on_epoch_end(self, epoch, logs=None):
1932    logs = logs or {}
1933    logs['lr'] = K.get_value(self.model.optimizer.lr)
1934
1935
1936def keras_model_summary(name, data, step=None):
1937  """Writes a Keras model as JSON to as a Summary.
1938
1939  Writing the Keras model configuration allows the TensorBoard graph plugin to
1940  render a conceptual graph, as opposed to graph of ops. In case the model fails
1941  to serialize as JSON, it ignores and returns False.
1942
1943  Args:
1944    name: A name for this summary. The summary tag used for TensorBoard will be
1945      this name prefixed by any active name scopes.
1946    data: A Keras Model to write.
1947    step: Explicit `int64`-castable monotonic step value for this summary. If
1948      omitted, this defaults to `tf.summary.experimental.get_step()`, which must
1949      not be None.
1950
1951  Returns:
1952    True on success, or False if no summary was written because no default
1953    summary writer was available.
1954
1955  Raises:
1956    ValueError: if a default writer exists, but no step was provided and
1957      `tf.summary.experimental.get_step()` is None.
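
  Example (a minimal sketch; the log directory is a placeholder and `model` is
  assumed to be an existing Sequential or functional model):

  ```python
  writer = tf.summary.create_file_writer('/tmp/graph_logs')
  with writer.as_default():
    keras_model_summary('keras_model_json', model, step=0)
  ```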
1958  """
1959  summary_metadata = summary_pb2.SummaryMetadata()
1960  # Hard coding a plugin name. Please refer to go/tb-plugin-name-hardcode for
1961  # the rationale.
1962  summary_metadata.plugin_data.plugin_name = 'graph_keras_model'
1963  # version number = 1
1964  summary_metadata.plugin_data.content = b'1'
1965
1966  try:
1967    json_string = data.to_json()
1968  except Exception as exc:  # pylint: disable=broad-except
1969    # A serialization failure should not break the surrounding model code.
1970    logging.warn('Model failed to serialize as JSON. Ignoring... %s', exc)
1971    return False
1972
1973  with summary_ops_v2.summary_scope(name, 'graph_keras_model',
1974                                    [data, step]) as (tag, _):
1975    with ops.device('cpu:0'):
1976      tensor = constant_op.constant(json_string, dtype=dtypes.string)
1977    return summary_ops_v2.write(
1978        tag=tag, tensor=tensor, step=step, metadata=summary_metadata)
1979
1980
1981@keras_export('keras.callbacks.TensorBoard', v1=[])
1982class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
1983  # pylint: disable=line-too-long
1984  """Enable visualizations for TensorBoard.
1985
1986  TensorBoard is a visualization tool provided with TensorFlow.
1987
1988  This callback logs events for TensorBoard, including:
1989
1990  * Metrics summary plots
1991  * Training graph visualization
1992  * Activation histograms
1993  * Sampled profiling
1994
1995  If you have installed TensorFlow with pip, you should be able
1996  to launch TensorBoard from the command line:
1997
1998  ```
1999  tensorboard --logdir=path_to_your_logs
2000  ```
2001
2002  You can find more information about TensorBoard
2003  [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).
2004
2005  Args:
2006      log_dir: the path of the directory where to save the log files to be
2007        parsed by TensorBoard. e.g. log_dir = os.path.join(working_dir, 'logs')
2008        This directory should not be reused by any other callbacks.
2009      histogram_freq: frequency (in epochs) at which to compute activation and
2010        weight histograms for the layers of the model. If set to 0, histograms
2011        won't be computed. Validation data (or split) must be specified for
2012        histogram visualizations.
2013      write_graph: whether to visualize the graph in TensorBoard. The log file
2014        can become quite large when write_graph is set to True.
2015      write_images: whether to write model weights to visualize as image in
2016        TensorBoard.
2017      write_steps_per_second: whether to log the training steps per second into
2018        TensorBoard. This supports both epoch and batch frequency logging.
2019      update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`,
2020        writes the losses and metrics to TensorBoard after each batch. The same
2021        applies for `'epoch'`. If using an integer, let's say `1000`, the
2022        callback will write the metrics and losses to TensorBoard every 1000
2023        batches. Note that writing too frequently to TensorBoard can slow down
2024        your training.
2025      profile_batch: Profile the batch(es) to sample compute characteristics.
2026        profile_batch must be a non-negative integer or a tuple of integers.
2027        A pair of positive integers signifies a range of batches to profile.
2028        By default, it will profile the second batch. Set profile_batch=0
2029        to disable profiling.
2030      embeddings_freq: frequency (in epochs) at which embedding layers will be
2031        visualized. If set to 0, embeddings won't be visualized.
2032      embeddings_metadata: a dictionary which maps layer name to a file name in
2033        which metadata for this embedding layer is saved. See the
2034        [details](
2035          https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional)
2036        about the metadata file format. If the same metadata file is to be
2037        used for all embedding layers, a single string can be passed.
2038
2039  Examples:
2040
2041  Basic usage:
2042
2043  ```python
2044  tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")
2045  model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
2046  # Then run the tensorboard command to view the visualizations.
2047  ```
2048
2049  Custom batch-level summaries in a subclassed Model:
2050
2051  ```python
2052  class MyModel(tf.keras.Model):
2053
2054    def build(self, _):
2055      self.dense = tf.keras.layers.Dense(10)
2056
2057    def call(self, x):
2058      outputs = self.dense(x)
2059      tf.summary.histogram('outputs', outputs)
2060      return outputs
2061
2062  model = MyModel()
2063  model.compile('sgd', 'mse')
2064
2065  # Make sure to set `update_freq=N` to log a batch-level summary every N batches.
2066  # In addition to any `tf.summary` contained in `Model.call`, metrics added in
2067  # `Model.compile` will be logged every N batches.
2068  tb_callback = tf.keras.callbacks.TensorBoard('./logs', update_freq=1)
2069  model.fit(x_train, y_train, callbacks=[tb_callback])
2070  ```
2071
2072  Custom batch-level summaries in a Functional API Model:
2073
2074  ```python
2075  def my_summary(x):
2076    tf.summary.histogram('x', x)
2077    return x
2078
2079  inputs = tf.keras.Input(10)
2080  x = tf.keras.layers.Dense(10)(inputs)
2081  outputs = tf.keras.layers.Lambda(my_summary)(x)
2082  model = tf.keras.Model(inputs, outputs)
2083  model.compile('sgd', 'mse')
2084
2085  # Make sure to set `update_freq=N` to log a batch-level summary every N batches.
2086  # In addition to any `tf.summary` contained in `Model.call`, metrics added in
2087  # `Model.compile` will be logged every N batches.
2088  tb_callback = tf.keras.callbacks.TensorBoard('./logs', update_freq=1)
2089  model.fit(x_train, y_train, callbacks=[tb_callback])
2090  ```
2091
2092  Profiling:
2093
2094  ```python
2095  # Profile a single batch, e.g. the 5th batch.
2096  tensorboard_callback = tf.keras.callbacks.TensorBoard(
2097      log_dir='./logs', profile_batch=5)
2098  model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
2099
2100  # Profile a range of batches, e.g. from 10 to 20.
2101  tensorboard_callback = tf.keras.callbacks.TensorBoard(
2102      log_dir='./logs', profile_batch=(10,20))
2103  model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
2104  ```
2105  """
2106
2107  # pylint: enable=line-too-long
2108
2109  def __init__(self,
2110               log_dir='logs',
2111               histogram_freq=0,
2112               write_graph=True,
2113               write_images=False,
2114               write_steps_per_second=False,
2115               update_freq='epoch',
2116               profile_batch=2,
2117               embeddings_freq=0,
2118               embeddings_metadata=None,
2119               **kwargs):
2120    super(TensorBoard, self).__init__()
2121    self._supports_tf_logs = True
2122    self._validate_kwargs(kwargs)
2123
2124    self.log_dir = path_to_string(log_dir)
2125    self.histogram_freq = histogram_freq
2126    self.write_graph = write_graph
2127    self.write_images = write_images
2128    self.write_steps_per_second = write_steps_per_second
2129    self.update_freq = 1 if update_freq == 'batch' else update_freq
2130    self.embeddings_freq = embeddings_freq
2131    self.embeddings_metadata = embeddings_metadata
2132    self._init_profile_batch(profile_batch)
2133    self._epoch = 0
2134    self._global_train_batch = 0
2135    self._previous_epoch_iterations = 0
2136    self._train_accumulated_time = 0
2137    self._batch_start_time = 0
2138
2139    # Lazily initialized in order to avoid creating event files when
2140    # not needed.
2141    self._writers = {}
2142
2143    # Used to restore any existing `SummaryWriter` after training ends.
2144    self._prev_summary_state = []
2145
2146  def _validate_kwargs(self, kwargs):
2147    """Handle arguments were supported in V1."""
2148    if kwargs.get('write_grads', False):
2149      logging.warning('`write_grads` will be ignored in TensorFlow 2.0 '
2150                      'for the `TensorBoard` Callback.')
2151    if kwargs.get('batch_size', False):
2152      logging.warning('`batch_size` is no longer needed in the '
2153                      '`TensorBoard` Callback and will be ignored '
2154                      'in TensorFlow 2.0.')
2155    if kwargs.get('embeddings_layer_names', False):
2156      logging.warning('`embeddings_layer_names` is not supported in '
2157                      'TensorFlow 2.0. Instead, all `Embedding` layers '
2158                      'will be visualized.')
2159    if kwargs.get('embeddings_data', False):
2160      logging.warning('`embeddings_data` is not supported in TensorFlow '
2161                      '2.0. Instead, all `Embedding` variables will be '
2162                      'visualized.')
2163
2164    unrecognized_kwargs = set(kwargs.keys()) - {
2165        'write_grads', 'embeddings_layer_names', 'embeddings_data', 'batch_size'
2166    }
2167
2168    # Only allow kwargs that were supported in V1.
2169    if unrecognized_kwargs:
2170      raise ValueError('Unrecognized arguments in `TensorBoard` '
2171                       'Callback: ' + str(unrecognized_kwargs))
2172
2173  def set_model(self, model):
2174    """Sets Keras model and writes graph if specified."""
2175    self.model = model
2176    self._log_write_dir = self._get_log_write_dir()
2177
2178    self._train_dir = os.path.join(self._log_write_dir, 'train')
2179    self._train_step = self.model._train_counter  # pylint: disable=protected-access
2180
2181    self._val_dir = os.path.join(self._log_write_dir, 'validation')
2182    self._val_step = self.model._test_counter  # pylint: disable=protected-access
2183
2184    self._writers = {}  # Resets writers.
2185
2186    self._should_write_train_graph = False
2187    if self.write_graph:
2188      self._write_keras_model_summary()
2189      self._should_write_train_graph = True
2190    if self.embeddings_freq:
2191      self._configure_embeddings()
2192
2193  @property
2194  def _train_writer(self):
2195    if 'train' not in self._writers:
2196      self._writers['train'] = summary_ops_v2.create_file_writer_v2(
2197          self._train_dir)
2198    return self._writers['train']
2199
2200  @property
2201  def _val_writer(self):
2202    if 'val' not in self._writers:
2203      self._writers['val'] = summary_ops_v2.create_file_writer_v2(self._val_dir)
2204    return self._writers['val']
2205
2206  def _get_log_write_dir(self):
2207    """For multi-worker, only chief should write, others write to '/tmp'."""
2208    return distributed_file_utils.write_dirpath(self.log_dir,
2209                                                self.model.distribute_strategy)
2210
2211  def _delete_tmp_write_dir(self):
2212    """Deletes tmp write directories for multi-worker."""
2213    distributed_file_utils.remove_temp_dirpath(self.log_dir,
2214                                               self.model.distribute_strategy)
2215
2216  def _write_keras_model_train_graph(self):
2217    """Writes Keras model train_function graph to TensorBoard."""
2218    with self._train_writer.as_default():
2219      with summary_ops_v2.record_if(True):
2220        train_fn = self.model.train_function
2221        # If the train_function is a `tf.function`, we can write out a graph
2222        if hasattr(train_fn, 'function_spec'):
2223          summary_ops_v2.graph(train_fn._concrete_stateful_fn.graph)  # pylint: disable=protected-access
2224
2225  def _write_keras_model_summary(self):
2226    """Writes Keras graph network summary to TensorBoard."""
2227    with self._train_writer.as_default():
2228      with summary_ops_v2.record_if(True):
2229        summary_writable = (
2230            self.model._is_graph_network or  # pylint: disable=protected-access
2231            self.model.__class__.__name__ == 'Sequential')  # pylint: disable=protected-access
2232        if summary_writable:
2233          keras_model_summary('keras', self.model, step=0)
2234
2235  def _configure_embeddings(self):
2236    """Configure the Projector for embeddings."""
2237    # TODO(omalleyt): Add integration tests.
2238    from google.protobuf import text_format
2239    from tensorflow.python.keras.layers import embeddings
2240    from tensorflow.python.keras.protobuf import projector_config_pb2
2241
2242    config = projector_config_pb2.ProjectorConfig()
2243    for layer in self.model.layers:
2244      if isinstance(layer, embeddings.Embedding):
2245        embedding = config.embeddings.add()
2246        # Embeddings are always the first layer, so this naming should be
2247        # consistent in any Keras model checkpoints.
2248        name = 'layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE'
2249        embedding.tensor_name = name
2250
2251        if self.embeddings_metadata is not None:
2252          if isinstance(self.embeddings_metadata, str):
2253            embedding.metadata_path = self.embeddings_metadata
2254          else:
2255            if layer.name in self.embeddings_metadata.keys():
2256              embedding.metadata_path = self.embeddings_metadata.pop(layer.name)
2257
2258    if self.embeddings_metadata and not isinstance(self.embeddings_metadata,
2259                                                   str):
2260      raise ValueError('Unrecognized `Embedding` layer names passed to '
2261                       '`keras.callbacks.TensorBoard` `embeddings_metadata` '
2262                       'argument: ' + str(self.embeddings_metadata.keys()))
2263
2264    config_pbtxt = text_format.MessageToString(config)
2265    path = os.path.join(self._log_write_dir, 'projector_config.pbtxt')
2266    with gfile.Open(path, 'w') as f:
2267      f.write(config_pbtxt)
2268
2269  def _push_writer(self, writer, step):
2270    """Sets the default writer for custom batch-level summaries."""
2271    if self.update_freq == 'epoch':
2272      return
2273
2274    should_record = lambda: math_ops.equal(step % self.update_freq, 0)
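    # With an integer `update_freq`, batch-level summaries are only recorded on
    # every `update_freq`-th step; `record_if(should_record)` gates the writes.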
2275    # TODO(b/151339474): Fix deadlock when not using .value() here.
2276    summary_context = (writer.as_default(step.value()),
2277                       summary_ops_v2.record_if(should_record))
2278    self._prev_summary_state.append(summary_context)
2279    summary_context[0].__enter__()
2280    summary_context[1].__enter__()
2281
2282  def _pop_writer(self):
2283    """Pops the current writer."""
2284    if self.update_freq == 'epoch':
2285      return
2286
2287    # See _push_writer for the content of previous_context, which is a pair of
2288    # context managers.
2289    previous_context = self._prev_summary_state.pop()
2290    previous_context[1].__exit__(*sys.exc_info())
2291    previous_context[0].__exit__(*sys.exc_info())
2292
2293  def _close_writers(self):
2294    for writer in self._writers.values():
2295      writer.close()
2296
2297  def _init_profile_batch(self, profile_batch):
2298    """Validate profile_batch value and set the range of batches to profile.
2299
2300    Args:
2301      profile_batch: The range of batches to profile. Should be a non-negative
2302        integer or a comma-separated string of a pair of positive integers. A
2303        pair of positive integers signifies a range of batches to profile.
2304
2305    Returns:
2306      A pair of non-negative integers specifying the start and stop batch to
2307      profile.
2308
2309    Raises:
2310      ValueError: If profile_batch is not an integer or a comma-separated pair
2311                  of positive integers.
2312
2313    """
2314    profile_batch_error_message = (
2315        'profile_batch must be a non-negative integer or 2-tuple of positive '
2316        'integers. A pair of positive integers signifies a range of batches '
2317        'to profile. Found: {}'.format(profile_batch))
2318
2319    # Support legacy way of specifying "start,stop" or "start" as str.
2320    if isinstance(profile_batch, six.string_types):
2321      profile_batch = str(profile_batch).split(',')
2322      profile_batch = nest.map_structure(int, profile_batch)
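      # e.g. the legacy string '10,20' becomes the list [10, 20] at this point.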
2323
2324    if isinstance(profile_batch, int):
2325      self._start_batch = profile_batch
2326      self._stop_batch = profile_batch
2327    elif isinstance(profile_batch, (tuple, list)) and len(profile_batch) == 2:
2328      self._start_batch, self._stop_batch = profile_batch
2329    else:
2330      raise ValueError(profile_batch_error_message)
2331
2332    if self._start_batch < 0 or self._stop_batch < self._start_batch:
2333      raise ValueError(profile_batch_error_message)
2334
2335    if self._start_batch > 0:
2336      # Warm up the profiler to improve profiling accuracy.
2337      profiler.start('')
2338      profiler.stop(save=False)
2339    # True when a trace is running.
2340    self._is_tracing = False
2341
2342    # Setting `profile_batch=0` disables profiling.
2343    self._should_trace = not (self._start_batch == 0 and self._stop_batch == 0)
2344
2345  def on_train_begin(self, logs=None):
2346    self._global_train_batch = 0
2347    self._previous_epoch_iterations = 0
2348    self._train_accumulated_time = 0
2349    self._push_writer(self._train_writer, self._train_step)
2350
2351  def on_train_end(self, logs=None):
2352    self._pop_writer()
2353
2354    if self._is_tracing:
2355      self._stop_trace()
2356
2357    self._close_writers()
2358    self._delete_tmp_write_dir()
2359
2360  def on_test_begin(self, logs=None):
2361    self._push_writer(self._val_writer, self._val_step)
2362
2363  def on_test_end(self, logs=None):
2364    self._pop_writer()
2365
2366  def _implements_train_batch_hooks(self):
2367    # Only call batch hooks when tracing or write_steps_per_second are enabled
2368    return self._should_trace or self.write_steps_per_second
2369
2370  def on_train_batch_begin(self, batch, logs=None):
2371    self._global_train_batch += 1
2372    if self.write_steps_per_second:
2373      self._batch_start_time = time.time()
2374    if not self._should_trace:
2375      return
2376
2377    if self._global_train_batch == self._start_batch:
2378      self._start_trace()
2379
2380  def on_train_batch_end(self, batch, logs=None):
2381    if self._should_write_train_graph:
2382      self._write_keras_model_train_graph()
2383      self._should_write_train_graph = False
2384    if self.write_steps_per_second:
2385      batch_run_time = time.time() - self._batch_start_time
2386      self._train_accumulated_time += batch_run_time
2387      summary_ops_v2.scalar('batch_steps_per_second', 1. / batch_run_time)
2388    if not self._should_trace:
2389      return
2390
2391    if self._is_tracing and self._global_train_batch >= self._stop_batch:
2392      self._stop_trace()
2393
2394  def on_epoch_begin(self, epoch, logs=None):
2395    # Keeps track of epoch for profiling.
2396    self._epoch = epoch
2397    if self.write_steps_per_second:
2398      self._previous_epoch_iterations = self.model.optimizer.iterations.numpy()
2399      self._train_accumulated_time = 0
2400
2401  def on_epoch_end(self, epoch, logs=None):
2402    """Runs metrics and histogram summaries at epoch end."""
2403    self._log_epoch_metrics(epoch, logs)
2404
2405    if self.histogram_freq and epoch % self.histogram_freq == 0:
2406      self._log_weights(epoch)
2407
2408    if self.embeddings_freq and epoch % self.embeddings_freq == 0:
2409      self._log_embeddings(epoch)
2410
2411  def _start_trace(self):
2412    summary_ops_v2.trace_on(graph=True, profiler=False)
2413    profiler.start(logdir=self._train_dir)
2414    self._is_tracing = True
2415
2416  def _stop_trace(self, batch=None):
2417    """Logs the trace graph to TensorBoard."""
2418    if batch is None:
2419      batch = self._stop_batch
2420    with self._train_writer.as_default():
2421      with summary_ops_v2.record_if(True):
2422        # TODO(b/126388999): Remove step info in the summary name.
2423        summary_ops_v2.trace_export(name='batch_%d' % batch, step=batch)
2424    profiler.stop()
2425    self._is_tracing = False
2426
2427  def _collect_learning_rate(self, logs):
2428    lr_schedule = getattr(self.model.optimizer, 'lr', None)
2429    if isinstance(lr_schedule, learning_rate_schedule.LearningRateSchedule):
2430      logs['learning_rate'] = lr_schedule(self.model.optimizer.iterations)
2431    return logs
2432
2433  def _compute_steps_per_second(self):
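    # Steps per second = optimizer iterations run during this epoch divided by
    # the wall-clock time accumulated across the epoch's batches.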
2434    current_iteration = self.model.optimizer.iterations.numpy()
2435    steps_per_second = ((current_iteration - self._previous_epoch_iterations) /
2436                        (self._train_accumulated_time))
2437    return steps_per_second
2438
2439  def _log_epoch_metrics(self, epoch, logs):
2440    """Writes epoch metrics out as scalar summaries.
2441
2442    Args:
2443        epoch: Int. The global step to use for TensorBoard.
2444        logs: Dict. Keys are scalar summary names, values are scalars.
2445    """
2446    if not logs:
2447      return
2448
2449    train_logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
2450    val_logs = {k: v for k, v in logs.items() if k.startswith('val_')}
2451    train_logs = self._collect_learning_rate(train_logs)
2452    if self.write_steps_per_second:
2453      train_logs['steps_per_second'] = self._compute_steps_per_second()
2454
2455    with summary_ops_v2.record_if(True):
2456      if train_logs:
2457        with self._train_writer.as_default():
2458          for name, value in train_logs.items():
2459            summary_ops_v2.scalar('epoch_' + name, value, step=epoch)
2460      if val_logs:
2461        with self._val_writer.as_default():
2462          for name, value in val_logs.items():
2463            name = name[4:]  # Remove 'val_' prefix.
2464            summary_ops_v2.scalar('epoch_' + name, value, step=epoch)
2465
2466  def _log_weights(self, epoch):
2467    """Logs the weights of the Model to TensorBoard."""
2468    with self._train_writer.as_default():
2469      with summary_ops_v2.record_if(True):
2470        for layer in self.model.layers:
2471          for weight in layer.weights:
2472            weight_name = weight.name.replace(':', '_')
2473            summary_ops_v2.histogram(weight_name, weight, step=epoch)
2474            if self.write_images:
2475              self._log_weight_as_image(weight, weight_name, epoch)
2476        self._train_writer.flush()
2477
2478  def _log_weight_as_image(self, weight, weight_name, epoch):
2479    """Logs a weight as a TensorBoard image."""
2480    w_img = array_ops.squeeze(weight)
2481    shape = K.int_shape(w_img)
2482    if len(shape) == 1:  # Bias case
2483      w_img = array_ops.reshape(w_img, [1, shape[0], 1, 1])
2484    elif len(shape) == 2:  # Dense layer kernel case
2485      if shape[0] > shape[1]:
2486        w_img = array_ops.transpose(w_img)
2487        shape = K.int_shape(w_img)
2488      w_img = array_ops.reshape(w_img, [1, shape[0], shape[1], 1])
2489    elif len(shape) == 3:  # ConvNet case
2490      if K.image_data_format() == 'channels_last':
2491        # Switch to channels_first to display every kernel as a separate
2492        # image.
2493        w_img = array_ops.transpose(w_img, perm=[2, 0, 1])
2494        shape = K.int_shape(w_img)
2495      w_img = array_ops.reshape(w_img, [shape[0], shape[1], shape[2], 1])
2496
2497    shape = K.int_shape(w_img)
2498    # Not possible to handle 3D convnets etc.
2499    if len(shape) == 4 and shape[-1] in [1, 3, 4]:
2500      summary_ops_v2.image(weight_name, w_img, step=epoch)
2501
2502  def _log_embeddings(self, epoch):
2503    embeddings_ckpt = os.path.join(self._log_write_dir, 'train',
2504                                   'keras_embedding.ckpt-{}'.format(epoch))
2505    self.model.save_weights(embeddings_ckpt)
2506
2507
2508@keras_export('keras.callbacks.ReduceLROnPlateau')
2509class ReduceLROnPlateau(Callback):
2510  """Reduce learning rate when a metric has stopped improving.
2511
2512  Models often benefit from reducing the learning rate by a factor
2513  of 2-10 once learning stagnates. This callback monitors a
2514  quantity and if no improvement is seen for a 'patience' number
2515  of epochs, the learning rate is reduced.
2516
2517  Example:
2518
2519  ```python
2520  reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
2521                                patience=5, min_lr=0.001)
2522  model.fit(X_train, Y_train, callbacks=[reduce_lr])
2523  ```
2524
2525  Args:
2526      monitor: quantity to be monitored.
2527      factor: factor by which the learning rate will be reduced.
2528        `new_lr = lr * factor`.
2529      patience: number of epochs with no improvement after which learning rate
2530        will be reduced.
2531      verbose: int. 0: quiet, 1: update messages.
2532      mode: one of `{'auto', 'min', 'max'}`. In `'min'` mode,
2533        the learning rate will be reduced when the
2534        quantity monitored has stopped decreasing; in `'max'` mode it will be
2535        reduced when the quantity monitored has stopped increasing; in `'auto'`
2536        mode, the direction is automatically inferred from the name of the
2537        monitored quantity.
2538      min_delta: threshold for measuring the new optimum, to only focus on
2539        significant changes.
2540      cooldown: number of epochs to wait before resuming normal operation after
2541        lr has been reduced.
2542      min_lr: lower bound on the learning rate.
2543  """
2544
2545  def __init__(self,
2546               monitor='val_loss',
2547               factor=0.1,
2548               patience=10,
2549               verbose=0,
2550               mode='auto',
2551               min_delta=1e-4,
2552               cooldown=0,
2553               min_lr=0,
2554               **kwargs):
2555    super(ReduceLROnPlateau, self).__init__()
2556
2557    self.monitor = monitor
2558    if factor >= 1.0:
2559      raise ValueError('ReduceLROnPlateau does not support a factor >= 1.0.')
2560    if 'epsilon' in kwargs:
2561      min_delta = kwargs.pop('epsilon')
2562      logging.warning('`epsilon` argument is deprecated and '
2563                      'will be removed, use `min_delta` instead.')
2564    self.factor = factor
2565    self.min_lr = min_lr
2566    self.min_delta = min_delta
2567    self.patience = patience
2568    self.verbose = verbose
2569    self.cooldown = cooldown
2570    self.cooldown_counter = 0  # Cooldown counter.
2571    self.wait = 0
2572    self.best = 0
2573    self.mode = mode
2574    self.monitor_op = None
2575    self._reset()
2576
2577  def _reset(self):
2578    """Resets wait counter and cooldown counter.
2579    """
2580    if self.mode not in ['auto', 'min', 'max']:
2581      logging.warning('Learning rate reduction mode %s is unknown, '
2582                      'fallback to auto mode.', self.mode)
2583      self.mode = 'auto'
2584    if (self.mode == 'min' or
2585        (self.mode == 'auto' and 'acc' not in self.monitor)):
2586      self.monitor_op = lambda a, b: np.less(a, b - self.min_delta)
2587      self.best = np.Inf
2588    else:
2589      self.monitor_op = lambda a, b: np.greater(a, b + self.min_delta)
2590      self.best = -np.Inf
2591    self.cooldown_counter = 0
2592    self.wait = 0
2593
2594  def on_train_begin(self, logs=None):
2595    self._reset()
2596
2597  def on_epoch_end(self, epoch, logs=None):
2598    logs = logs or {}
2599    logs['lr'] = K.get_value(self.model.optimizer.lr)
2600    current = logs.get(self.monitor)
2601    if current is None:
2602      logging.warning('Learning rate reduction is conditioned on metric `%s` '
2603                      'which is not available. Available metrics are: %s',
2604                      self.monitor, ','.join(list(logs.keys())))
2605
2606    else:
2607      if self.in_cooldown():
2608        self.cooldown_counter -= 1
2609        self.wait = 0
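      # While in cooldown, `wait` is pinned to 0, so the patience countdown
      # only resumes once the cooldown period has elapsed.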
2610
2611      if self.monitor_op(current, self.best):
2612        self.best = current
2613        self.wait = 0
2614      elif not self.in_cooldown():
2615        self.wait += 1
2616        if self.wait >= self.patience:
2617          old_lr = K.get_value(self.model.optimizer.lr)
2618          if old_lr > np.float32(self.min_lr):
2619            new_lr = old_lr * self.factor
2620            new_lr = max(new_lr, self.min_lr)
2621            K.set_value(self.model.optimizer.lr, new_lr)
2622            if self.verbose > 0:
2623              print('\nEpoch %05d: ReduceLROnPlateau reducing learning '
2624                    'rate to %s.' % (epoch + 1, new_lr))
2625            self.cooldown_counter = self.cooldown
2626            self.wait = 0
2627
2628  def in_cooldown(self):
2629    return self.cooldown_counter > 0
2630
2631
2632@keras_export('keras.callbacks.CSVLogger')
2633class CSVLogger(Callback):
2634  """Callback that streams epoch results to a CSV file.
2635
2636  Supports all values that can be represented as a string,
2637  including 1D iterables such as `np.ndarray`.
2638
2639  Example:
2640
2641  ```python
2642  csv_logger = CSVLogger('training.log')
2643  model.fit(X_train, Y_train, callbacks=[csv_logger])
2644  ```
2645
2646  Args:
2647      filename: Filename of the CSV file, e.g. `'run/log.csv'`.
2648      separator: String used to separate elements in the CSV file.
2649      append: Boolean. True: append if file exists (useful for continuing
2650          training). False: overwrite existing file.
2651  """
2652
2653  def __init__(self, filename, separator=',', append=False):
2654    self.sep = separator
2655    self.filename = path_to_string(filename)
2656    self.append = append
2657    self.writer = None
2658    self.keys = None
2659    self.append_header = True
2660    if six.PY2:
2661      self.file_flags = 'b'
2662      self._open_args = {}
2663    else:
2664      self.file_flags = ''
2665      self._open_args = {'newline': '\n'}
2666    super(CSVLogger, self).__init__()
2667
2668  def on_train_begin(self, logs=None):
2669    if self.append:
2670      if file_io.file_exists_v2(self.filename):
2671        with open(self.filename, 'r' + self.file_flags) as f:
2672          self.append_header = not bool(len(f.readline()))
2673      mode = 'a'
2674    else:
2675      mode = 'w'
2676    self.csv_file = io.open(self.filename,
2677                            mode + self.file_flags,
2678                            **self._open_args)
2679
2680  def on_epoch_end(self, epoch, logs=None):
2681    logs = logs or {}
2682
2683    def handle_value(k):
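      # Strings and scalars pass through unchanged; 1D iterables (e.g. a
      # per-class metric array) are rendered as a quoted list such as
      # '"[0.1, 0.2]"' so the resulting CSV stays parseable.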
2684      is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
2685      if isinstance(k, six.string_types):
2686        return k
2687      elif isinstance(k, collections.abc.Iterable) and not is_zero_dim_ndarray:
2688        return '"[%s]"' % (', '.join(map(str, k)))
2689      else:
2690        return k
2691
2692    if self.keys is None:
2693      self.keys = sorted(logs.keys())
2694
2695    if self.model.stop_training:
2696      # We set NA so that csv parsers do not fail for this last epoch.
2697      logs = dict((k, logs[k]) if k in logs else (k, 'NA') for k in self.keys)
2698
2699    if not self.writer:
2700
2701      class CustomDialect(csv.excel):
2702        delimiter = self.sep
2703
2704      fieldnames = ['epoch'] + self.keys
2705
2706      self.writer = csv.DictWriter(
2707          self.csv_file,
2708          fieldnames=fieldnames,
2709          dialect=CustomDialect)
2710      if self.append_header:
2711        self.writer.writeheader()
2712
2713    row_dict = collections.OrderedDict({'epoch': epoch})
2714    row_dict.update((key, handle_value(logs[key])) for key in self.keys)
2715    self.writer.writerow(row_dict)
2716    self.csv_file.flush()
2717
2718  def on_train_end(self, logs=None):
2719    self.csv_file.close()
2720    self.writer = None
2721
2722
2723@keras_export('keras.callbacks.LambdaCallback')
2724class LambdaCallback(Callback):
2725  r"""Callback for creating simple, custom callbacks on-the-fly.
2726
2727  This callback is constructed with anonymous functions that will be called
2728  at the appropriate time (during `Model.{fit | evaluate | predict}`).
2729  Note that the callbacks expect positional arguments, as follows:
2730
2731  - `on_epoch_begin` and `on_epoch_end` expect two positional arguments:
2732    `epoch`, `logs`
2733  - `on_batch_begin` and `on_batch_end` expect two positional arguments:
2734    `batch`, `logs`
2735  - `on_train_begin` and `on_train_end` expect one positional argument:
2736    `logs`
2737
2738  Args:
2739      on_epoch_begin: called at the beginning of every epoch.
2740      on_epoch_end: called at the end of every epoch.
2741      on_batch_begin: called at the beginning of every batch.
2742      on_batch_end: called at the end of every batch.
2743      on_train_begin: called at the beginning of model training.
2744      on_train_end: called at the end of model training.
2745
2746  Example:
2747
2748  ```python
2749  # Print the batch number at the beginning of every batch.
2750  batch_print_callback = LambdaCallback(
2751      on_batch_begin=lambda batch,logs: print(batch))
2752
2753  # Stream the epoch loss to a file in JSON format. The file content
2754  # is not well-formed JSON but rather has a JSON object per line.
2755  import json
2756  json_log = open('loss_log.json', mode='wt', buffering=1)
2757  json_logging_callback = LambdaCallback(
2758      on_epoch_end=lambda epoch, logs: json_log.write(
2759          json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'),
2760      on_train_end=lambda logs: json_log.close()
2761  )
2762
2763  # Terminate some processes after having finished model training.
2764  processes = ...
2765  cleanup_callback = LambdaCallback(
2766      on_train_end=lambda logs: [
2767          p.terminate() for p in processes if p.is_alive()])
2768
2769  model.fit(...,
2770            callbacks=[batch_print_callback,
2771                       json_logging_callback,
2772                       cleanup_callback])
2773  ```
2774  """
2775
2776  def __init__(self,
2777               on_epoch_begin=None,
2778               on_epoch_end=None,
2779               on_batch_begin=None,
2780               on_batch_end=None,
2781               on_train_begin=None,
2782               on_train_end=None,
2783               **kwargs):
2784    super(LambdaCallback, self).__init__()
2785    self.__dict__.update(kwargs)
2786    if on_epoch_begin is not None:
2787      self.on_epoch_begin = on_epoch_begin
2788    else:
2789      self.on_epoch_begin = lambda epoch, logs: None
2790    if on_epoch_end is not None:
2791      self.on_epoch_end = on_epoch_end
2792    else:
2793      self.on_epoch_end = lambda epoch, logs: None
2794    if on_batch_begin is not None:
2795      self.on_batch_begin = on_batch_begin
2796    else:
2797      self.on_batch_begin = lambda batch, logs: None
2798    if on_batch_end is not None:
2799      self.on_batch_end = on_batch_end
2800    else:
2801      self.on_batch_end = lambda batch, logs: None
2802    if on_train_begin is not None:
2803      self.on_train_begin = on_train_begin
2804    else:
2805      self.on_train_begin = lambda logs: None
2806    if on_train_end is not None:
2807      self.on_train_end = on_train_end
2808    else:
2809      self.on_train_end = lambda logs: None
2810